Skip to content

Commit ba29c9e

Browse files
committed
chore(types): fix easy ty warnings shared across boltz/protenix/rf3 envs
1 parent 9ca8b1b commit ba29c9e

7 files changed

Lines changed: 18 additions & 12 deletions

File tree

scripts/eval/classify_altloc_regions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def main(args: argparse.Namespace) -> None:
415415
)
416416
)
417417

418-
out_df = pd.DataFrame(all_rows, columns=OUTPUT_COLUMNS)
418+
out_df = pd.DataFrame(all_rows, columns=pd.Index(OUTPUT_COLUMNS))
419419
args.output_file.parent.mkdir(parents=True, exist_ok=True)
420420
out_df.to_csv(args.output_file, index=False)
421421
logger.info(f"Wrote {len(out_df)} classified spans to {args.output_file}")

src/sampleworks/metrics/metric.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ def instantiate_metric_manager(
2626
A dictionary where keys are metric names and values are
2727
Hydra configurations for the metrics.
2828
"""
29-
metrics = {}
29+
metrics: dict[str, Metric] = {}
3030
for name, cfg in metrics_cfg.items():
31+
if not isinstance(name, str):
32+
raise TypeError(f"metrics_cfg key must be a str, got {type(name).__name__}: {name!r}")
3133
metric = hydra.utils.instantiate(cfg)
3234
if not isinstance(metric, Metric):
3335
raise TypeError(f"{name} must be a Metric instance")

src/sampleworks/utils/guidance_script_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def save_losses(losses, output_dir):
164164

165165
def get_model_and_device(
166166
device_str: str,
167-
model_checkpoint_path: str,
167+
model_checkpoint_path: str | None,
168168
model_type: str,
169169
method: str | None = None,
170170
model: Any = None,

src/sampleworks/utils/mmseqs2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def download(ID, path):
297297
update_M = False
298298
if M not in a3m_lines:
299299
a3m_lines[M] = []
300+
if M is None:
301+
# Well-formed A3M files always begin with a `>` header that sets M,
302+
# so this only triggers on malformed input. Skip rather than KeyError.
303+
continue
300304
a3m_lines[M].append(line)
301305

302306
# we run all the sequences together in one string, so this is really just a list of file dumps

tests/cli/test_guidance_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_model_specific_args_rf3_msa(self):
6666
"/data/msa.a3m",
6767
] + COMMON_ARGS
6868
config = GuidanceConfig.from_cli(argv)
69-
assert config.msa_path == "/data/msa.a3m"
69+
assert config.msa_path == "/data/msa.a3m" # ty: ignore[unresolved-attribute]
7070

7171
def test_guidance_specific_args_fk(self):
7272
argv = [

tests/rewards/test_real_space_density_reward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ def test_vmap_output_shape(self, reward_function_1vme, test_coordinates_1vme, de
476476

477477
unique_combinations, inverse_indices = (
478478
reward_function_1vme.precompute_unique_combinations(
479-
elements_batch[0, 0], # ty: ignore[no-matching-overload, invalid-argument-type]
480-
b_factors_batch[0, 0], # ty: ignore[no-matching-overload, invalid-argument-type]
479+
elements_batch[0, 0],
480+
b_factors_batch[0, 0],
481481
)
482482
)
483483

@@ -496,7 +496,7 @@ def test_vmap_output_shape(self, reward_function_1vme, test_coordinates_1vme, de
496496
op=rf_partial,
497497
)
498498

499-
assert result.shape == torch.Size([num_particles]) # ty: ignore[unresolved-attribute]
499+
assert result.shape == torch.Size([num_particles])
500500

501501
def test_vmap_consistency(self, reward_function_1vme, test_coordinates_1vme, device):
502502
"""Test vmap results match sequential calls."""
@@ -520,8 +520,8 @@ def test_vmap_consistency(self, reward_function_1vme, test_coordinates_1vme, dev
520520
occupancies_batch = einx.rearrange("n -> p e n", occupancies, p=num_particles, e=1)
521521

522522
unique_combinations, inverse_indices = reward_function_1vme.precompute_unique_combinations(
523-
elements_batch[0, 0], # ty: ignore[no-matching-overload, invalid-argument-type]
524-
b_factors_batch[0, 0], # ty: ignore[no-matching-overload, invalid-argument-type]
523+
elements_batch[0, 0],
524+
b_factors_batch[0, 0],
525525
)
526526

527527
rf_partial = partial(
@@ -549,7 +549,7 @@ def test_vmap_consistency(self, reward_function_1vme, test_coordinates_1vme, dev
549549
)
550550
result_sequential.append(loss.item())
551551

552-
result_sequential = torch.tensor(result_sequential, device=result_vmap.device) # ty: ignore[unresolved-attribute]
552+
result_sequential = torch.tensor(result_sequential, device=result_vmap.device)
553553

554554
# GPU vmap and sequential loops accumulate floating-point reductions in
555555
# different orders, yielding abs diffs up to ~1.3e-4 and rel diffs up to

tests/utils/test_framework_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def test_raises_on_incompatible_batch_sizes(self):
527527
def test_raises_on_scalar_input(self):
528528
scalar = np.float64(1.0)
529529
with pytest.raises(ValueError, match="ndim >= 1"):
530-
match_batch(scalar, target_batch_size=2) # type: ignore[no-matching-overload]
530+
match_batch(scalar, target_batch_size=2) # ty: ignore[no-matching-overload]
531531

532532
def test_1d_array(self):
533533
array = np.array([42.0])
@@ -549,4 +549,4 @@ def test_preserves_dtype(self):
549549

550550
def test_raises_on_unsupported_type(self):
551551
with pytest.raises(TypeError, match="unsupported array type"):
552-
match_batch([1, 2, 3], target_batch_size=2) # type: ignore[no-matching-overload]
552+
match_batch([1, 2, 3], target_batch_size=2) # ty: ignore[no-matching-overload]

0 commit comments

Comments
 (0)