@@ -476,8 +476,8 @@ def test_vmap_output_shape(self, reward_function_1vme, test_coordinates_1vme, de
476476
477477 unique_combinations , inverse_indices = (
478478 reward_function_1vme .precompute_unique_combinations (
479- elements_batch [0 , 0 ], # ty: ignore[no-matching-overload, invalid-argument-type]
480- b_factors_batch [0 , 0 ], # ty: ignore[no-matching-overload, invalid-argument-type]
479+ elements_batch [0 , 0 ],
480+ b_factors_batch [0 , 0 ],
481481 )
482482 )
483483
@@ -496,7 +496,7 @@ def test_vmap_output_shape(self, reward_function_1vme, test_coordinates_1vme, de
496496 op = rf_partial ,
497497 )
498498
499- assert result .shape == torch .Size ([num_particles ]) # ty: ignore[unresolved-attribute]
499+ assert result .shape == torch .Size ([num_particles ])
500500
501501 def test_vmap_consistency (self , reward_function_1vme , test_coordinates_1vme , device ):
502502 """Test vmap results match sequential calls."""
@@ -520,8 +520,8 @@ def test_vmap_consistency(self, reward_function_1vme, test_coordinates_1vme, dev
520520 occupancies_batch = einx .rearrange ("n -> p e n" , occupancies , p = num_particles , e = 1 )
521521
522522 unique_combinations , inverse_indices = reward_function_1vme .precompute_unique_combinations (
523- elements_batch [0 , 0 ], # ty: ignore[no-matching-overload, invalid-argument-type]
524- b_factors_batch [0 , 0 ], # ty: ignore[no-matching-overload, invalid-argument-type]
523+ elements_batch [0 , 0 ],
524+ b_factors_batch [0 , 0 ],
525525 )
526526
527527 rf_partial = partial (
@@ -549,7 +549,7 @@ def test_vmap_consistency(self, reward_function_1vme, test_coordinates_1vme, dev
549549 )
550550 result_sequential .append (loss .item ())
551551
552- result_sequential = torch .tensor (result_sequential , device = result_vmap .device ) # ty: ignore[unresolved-attribute]
552+ result_sequential = torch .tensor (result_sequential , device = result_vmap .device )
553553
554554 # GPU vmap and sequential loops accumulate floating-point reductions in
555555 # different orders, yielding abs diffs up to ~1.3e-4 and rel diffs up to
0 commit comments