Skip to content

Commit 1c0bb41

Browse files
committed
PR 205 feedback fixes
1 parent 16bb824 commit 1c0bb41

5 files changed

Lines changed: 36 additions & 19 deletions

File tree

run_grid_search.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,15 @@ def parse_args() -> argparse.Namespace:
450450
"--recycling-steps",
451451
type=int,
452452
default=None,
453-
help="Number of recycling steps for model inference (if not specified, uses model default)",
453+
help="Number of recycling steps for model inference. If not specified, "
454+
"uses model default, which can be found in each model's wrapper.py file",
455+
)
456+
parser.add_argument(
457+
"--num-diffusion-steps",
458+
type=int,
459+
default=200,
460+
help="Number of diffusion steps for model inference. If not specified, "
461+
"uses model default, which can be found in each model's wrapper.py file",
454462
)
455463

456464
# Trajectory scaling arguments

src/sampleworks/models/boltz/wrapper.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def process_structure_for_boltz(
331331
out_dir: str | Path | None = None,
332332
num_workers: int = 8,
333333
ensemble_size: int = 1,
334-
recycling_steps: int = 3,
334+
recycling_steps: int | None = 3,
335335
) -> dict:
336336
"""Annotate an Atomworks structure with Boltz-specific configuration.
337337
@@ -346,14 +346,20 @@ def process_structure_for_boltz(
346346
Number of parallel workers for preprocessing.
347347
ensemble_size : int
348348
Number of samples to generate (batch dimension of x_init).
349-
recycling_steps : int
349+
recycling_steps : int | None
350350
Number of recycling steps to perform during featurization Pairformer pass.
351+
Will set to 3 if None.
351352
352353
Returns
353354
-------
354355
dict
355356
Structure dict with "_boltz_config" key added.
356357
"""
358+
# Other models define a default deeper in their code,
359+
# but Boltz requires an integer value, so fix it here.
360+
if recycling_steps is None:
361+
recycling_steps = 3
362+
357363
config = BoltzConfig(
358364
out_dir=out_dir or structure.get("metadata", {}).get("id", "boltz_output"),
359365
num_workers=num_workers,
@@ -603,7 +609,7 @@ def _setup_data_module(
603609

604610
processed_dir = out_dir / "processed"
605611
processed = BoltzProcessedInput(
606-
manifest=Manifest.load(processed_dir / "manifest.json"), # type: ignore (Boltz repo doesn't have the right type hints?)
612+
manifest=Manifest.load(processed_dir / "manifest.json"),
607613
targets_dir=processed_dir / "structures",
608614
msa_dir=processed_dir / "msa",
609615
constraints_dir=(processed_dir / "constraints")
@@ -784,26 +790,26 @@ def _pairformer_pass(
784790
if self.model.use_templates:
785791
if self.model.is_template_compiled:
786792
template_module = (
787-
self.model.template_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
793+
self.model.template_module._orig_mod
788794
)
789795
else:
790796
template_module = self.model.template_module
791797

792-
z = z + template_module(z, features, pair_mask, use_kernels=self.model.use_kernels) # type: ignore (Object will be callable here)
798+
z = z + template_module(z, features, pair_mask, use_kernels=self.model.use_kernels)
793799

794800
if self.model.is_msa_compiled:
795-
msa_module = self.model.msa_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
801+
msa_module = self.model.msa_module._orig_mod
796802
else:
797803
msa_module = self.model.msa_module
798804

799-
z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels) # type: ignore (Object will be callable here)
805+
z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels)
800806

801807
if self.model.is_pairformer_compiled:
802-
pairformer_module = self.model.pairformer_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
808+
pairformer_module = self.model.pairformer_module._orig_mod
803809
else:
804810
pairformer_module = self.model.pairformer_module
805811

806-
s, z = pairformer_module(s, z, mask=mask, pair_mask=pair_mask) # type: ignore (Object will be callable here)
812+
s, z = pairformer_module(s, z, mask=mask, pair_mask=pair_mask)
807813

808814
q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
809815
self.model.diffusion_conditioning(
@@ -1068,7 +1074,7 @@ def _setup_data_module(
10681074

10691075
processed_dir = out_dir / "processed"
10701076
processed = BoltzProcessedInput(
1071-
manifest=Manifest.load(processed_dir / "manifest.json"), # type: ignore (Boltz repo doesn't have the right type hints?)
1077+
manifest=Manifest.load(processed_dir / "manifest.json"),
10721078
targets_dir=processed_dir / "structures",
10731079
msa_dir=processed_dir / "msa",
10741080
constraints_dir=(processed_dir / "constraints")
@@ -1357,7 +1363,7 @@ def _pairformer_pass(
13571363
)
13581364

13591365
if self.model.is_pairformer_compiled:
1360-
pairformer_module = self.model.pairformer_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
1366+
pairformer_module = self.model.pairformer_module._orig_mod
13611367
else:
13621368
pairformer_module = self.model.pairformer_module
13631369

@@ -1367,7 +1373,7 @@ def _pairformer_pass(
13671373
mask=mask,
13681374
pair_mask=pair_mask,
13691375
use_kernels=self.model.use_kernels,
1370-
) # type: ignore (Object will be callable here)
1376+
)
13711377

13721378
return {
13731379
"s": s,

src/sampleworks/models/protenix/wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,9 @@ def step(
687687
s_inputs=s_inputs,
688688
s_trunk=s_trunk,
689689
z_trunk=z_trunk,
690-
pair_z=pair_z, # ty: ignore[invalid-argument-type]
691-
p_lm=p_lm, # ty: ignore[invalid-argument-type]
692-
c_l=c_l, # ty: ignore[invalid-argument-type]
690+
pair_z=pair_z,
691+
p_lm=p_lm,
692+
c_l=c_l,
693693
)
694694

695695
# TODO: is there a way to handle this more cleanly?

src/sampleworks/models/rf3/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
326326
) # since we're not batching, the loader returns a list of length 1
327327

328328
# (Hydra instantiation of pipeline means it is going to be hard to type check here)
329-
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable]
329+
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input())
330330
pipeline_output = trainer.fabric.to_device(pipeline_output)
331331

332332
features = trainer._assemble_network_inputs(pipeline_output)

src/sampleworks/utils/guidance_script_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from sampleworks.core.scalers.pure_guidance import PureGuidance
2828
from sampleworks.core.scalers.step_scalers import (
2929
DataSpaceDPSScaler,
30-
NoScalingScaler,
3130
NoiseSpaceDPSScaler,
3231
NoScalingScaler,
3332
)
@@ -232,7 +231,7 @@ def get_reward_function_and_structure(
232231
logger.debug(f"Loading structure from {structure_path}")
233232
safe_structure_path = resolve_mixed_hetatm_atom_altlocs(Path(structure_path))
234233
structure = parse(
235-
Path(safe_structure_path),
234+
safe_structure_path,
236235
hydrogen_policy="remove",
237236
add_missing_atoms=False,
238237
ccd_mirror_path=None,
@@ -428,14 +427,17 @@ def _run_guidance(
428427
is_boltz = "Boltz" in wrapper_class_name
429428

430429
# Annotate structure with model-specific configuration (including recycling_steps)
430+
# See https://github.com/diff-use/sampleworks/issues/192 for a plan to organize this better.
431431
recycling_steps = getattr(args, "recycling_steps", None)
432432
if "Protenix" in wrapper_class_name:
433433
from sampleworks.models.protenix.wrapper import annotate_structure_for_protenix
434+
434435
structure = annotate_structure_for_protenix(
435436
structure, ensemble_size=args.ensemble_size, recycling_steps=recycling_steps
436437
)
437438
elif "RF3" in wrapper_class_name:
438439
from sampleworks.models.rf3.wrapper import annotate_structure_for_rf3
440+
439441
structure = annotate_structure_for_rf3(
440442
structure,
441443
ensemble_size=args.ensemble_size,
@@ -446,6 +448,7 @@ def _run_guidance(
446448
)
447449
elif "Boltz" in wrapper_class_name:
448450
from sampleworks.models.boltz.wrapper import process_structure_for_boltz
451+
449452
structure = process_structure_for_boltz(
450453
structure, ensemble_size=args.ensemble_size, recycling_steps=recycling_steps
451454
)

0 commit comments

Comments
 (0)