Skip to content

Commit

Permalink
fix: fix test case for adj estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
TomeHirata committed Sep 5, 2024
1 parent 5a5f9b2 commit 48fdf66
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 0 additions & 2 deletions dte_adj/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,6 @@ def _compute_cumulative_distribution(
cumulative_distribution = np.zeros(n_loc)
superset_prediction = np.zeros((n_records, n_loc))
treatment_mask = treatment_arms == target_treatment_arm
if self.is_multi_task:
confounding_in_arm = confoundings[treatment_mask]
n_records_in_arm = len(confounding_in_arm)
if self.is_multi_task:
Expand Down Expand Up @@ -521,7 +520,6 @@ def _compute_cumulative_distribution(
subset_train_mask = (~superset_mask) & treatment_mask
subset_test_mask_inner = superset_mask[treatment_mask]
confounding_train = confoundings[subset_train_mask]
confounding_subset_test = confoundings[subset_test_mask]
binominal_train = binominal[subset_train_mask]
if len(np.unique(binominal_train)) == 1:
subset_prediction[subset_test_mask_inner] = binominal_train[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_adjusted_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ def test_compute_cumulative_distribution(self):
self.assertAlmostEqual(cumulative_distribution[i], (i + 1) / 10, places=2)

for i in range(20):
for j in range(1, 10):
for j in range(1, 8):
self.assertAlmostEqual(superset_prediction[i, j], 0.5, places=2)

0 comments on commit 48fdf66

Please sign in to comment.