Skip to content

Commit 2039efa

Browse files
committed
Take out compute_emission_probability from njit functions
1 parent 5cd02d1 commit 2039efa

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

python/tests/beagle_numba.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ def compute_forward_matrix(
377377
# Get allele at genotyped position i on reference haplotype j.
378378
ref_a = ref_h[i, j]
379379
# Get emission probability.
380-
em_prob = compute_emission_probability(
381-
mismatch_probs[i], query_a == ref_a, num_alleles=num_alleles
382-
)
380+
em_prob = mismatch_probs[i]
381+
if query_a == ref_a:
382+
em_prob = 1.0 - (num_alleles - 1) * mismatch_probs[i]
383383
fwd_mat[i, j] = em_prob
384384
if i > 0:
385385
fwd_mat[i, j] *= scale * fwd_mat[i - 1, j] + shift
@@ -421,11 +421,9 @@ def compute_backward_matrix(
421421
query_a = query_h[iP1]
422422
for j in range(h):
423423
ref_a = ref_h[iP1, j]
424-
em_prob = compute_emission_probability(
425-
mismatch_probs[iP1],
426-
query_a == ref_a,
427-
num_alleles=num_alleles,
428-
)
424+
em_prob = mismatch_probs[iP1]
425+
if query_a == ref_a:
426+
em_prob = 1.0 - (num_alleles - 1) * mismatch_probs[iP1]
429427
bwd_mat[iP1, j] *= em_prob
430428
site_sum = np.sum(bwd_mat[iP1, :])
431429
scale = (1 - trans_probs[iP1]) / site_sum

0 commit comments

Comments
 (0)