Skip to content

Commit ec53661

Browse files
committed
PR 205 feedback fixes
1 parent 16bb824 commit ec53661

5 files changed

Lines changed: 41 additions & 24 deletions

File tree

run_grid_search.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,15 @@ def parse_args() -> argparse.Namespace:
450450
"--recycling-steps",
451451
type=int,
452452
default=None,
453-
help="Number of recycling steps for model inference (if not specified, uses model default)",
453+
help="Number of recycling steps for model inference. If not specified, "
454+
"uses model default, which can be found in each model's wrapper.py file",
455+
)
456+
parser.add_argument(
457+
"--num-diffusion-steps",
458+
type=int,
459+
default=200,
460+
help="Number of diffusion steps for model inference. If not specified, "
461+
"defaults to 200",
454462
)
455463

456464
# Trajectory scaling arguments
@@ -460,9 +468,6 @@ def parse_args() -> argparse.Namespace:
460468
parser.add_argument(
461469
"--ensemble-sizes", default="1 2 4 8", help="Space-separated ensemble sizes"
462470
)
463-
parser.add_argument(
464-
"--num-diffusion-steps", type=int, default=200, help="Number of diffusion steps"
465-
)
466471
parser.add_argument(
467472
"--gradient-weights",
468473
default="0.01 0.1 0.2",

src/sampleworks/models/boltz/wrapper.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def process_structure_for_boltz(
331331
out_dir: str | Path | None = None,
332332
num_workers: int = 8,
333333
ensemble_size: int = 1,
334-
recycling_steps: int = 3,
334+
recycling_steps: int | None = 3,
335335
) -> dict:
336336
"""Annotate an Atomworks structure with Boltz-specific configuration.
337337
@@ -346,14 +346,20 @@ def process_structure_for_boltz(
346346
Number of parallel workers for preprocessing.
347347
ensemble_size : int
348348
Number of samples to generate (batch dimension of x_init).
349-
recycling_steps : int
349+
recycling_steps : int | None
350350
Number of recycling steps to perform during featurization Pairformer pass.
351+
Defaults to 3 if None.
351352
352353
Returns
353354
-------
354355
dict
355356
Structure dict with "_boltz_config" key added.
356357
"""
358+
# Other models define a default deeper in their code,
359+
# but Boltz requires an integer value, so apply the default of 3 here.
360+
if recycling_steps is None:
361+
recycling_steps = 3
362+
357363
config = BoltzConfig(
358364
out_dir=out_dir or structure.get("metadata", {}).get("id", "boltz_output"),
359365
num_workers=num_workers,
@@ -603,7 +609,7 @@ def _setup_data_module(
603609

604610
processed_dir = out_dir / "processed"
605611
processed = BoltzProcessedInput(
606-
manifest=Manifest.load(processed_dir / "manifest.json"), # type: ignore (Boltz repo doesn't have the right type hints?)
612+
manifest=Manifest.load(processed_dir / "manifest.json"),
607613
targets_dir=processed_dir / "structures",
608614
msa_dir=processed_dir / "msa",
609615
constraints_dir=(processed_dir / "constraints")
@@ -783,27 +789,25 @@ def _pairformer_pass(
783789

784790
if self.model.use_templates:
785791
if self.model.is_template_compiled:
786-
template_module = (
787-
self.model.template_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
788-
)
792+
template_module = self.model.template_module._orig_mod
789793
else:
790794
template_module = self.model.template_module
791795

792-
z = z + template_module(z, features, pair_mask, use_kernels=self.model.use_kernels) # type: ignore (Object will be callable here)
796+
z = z + template_module(z, features, pair_mask, use_kernels=self.model.use_kernels)
793797

794798
if self.model.is_msa_compiled:
795-
msa_module = self.model.msa_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
799+
msa_module = self.model.msa_module._orig_mod
796800
else:
797801
msa_module = self.model.msa_module
798802

799-
z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels) # type: ignore (Object will be callable here)
803+
z = z + msa_module(z, s_inputs, features, use_kernels=self.model.use_kernels)
800804

801805
if self.model.is_pairformer_compiled:
802-
pairformer_module = self.model.pairformer_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
806+
pairformer_module = self.model.pairformer_module._orig_mod
803807
else:
804808
pairformer_module = self.model.pairformer_module
805809

806-
s, z = pairformer_module(s, z, mask=mask, pair_mask=pair_mask) # type: ignore (Object will be callable here)
810+
s, z = pairformer_module(s, z, mask=mask, pair_mask=pair_mask)
807811

808812
q, c, to_keys, atom_enc_bias, atom_dec_bias, token_trans_bias = (
809813
self.model.diffusion_conditioning(
@@ -1068,7 +1072,7 @@ def _setup_data_module(
10681072

10691073
processed_dir = out_dir / "processed"
10701074
processed = BoltzProcessedInput(
1071-
manifest=Manifest.load(processed_dir / "manifest.json"), # type: ignore (Boltz repo doesn't have the right type hints?)
1075+
manifest=Manifest.load(processed_dir / "manifest.json"),
10721076
targets_dir=processed_dir / "structures",
10731077
msa_dir=processed_dir / "msa",
10741078
constraints_dir=(processed_dir / "constraints")
@@ -1357,7 +1361,7 @@ def _pairformer_pass(
13571361
)
13581362

13591363
if self.model.is_pairformer_compiled:
1360-
pairformer_module = self.model.pairformer_module._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
1364+
pairformer_module = self.model.pairformer_module._orig_mod
13611365
else:
13621366
pairformer_module = self.model.pairformer_module
13631367

@@ -1367,7 +1371,7 @@ def _pairformer_pass(
13671371
mask=mask,
13681372
pair_mask=pair_mask,
13691373
use_kernels=self.model.use_kernels,
1370-
) # type: ignore (Object will be callable here)
1374+
)
13711375

13721376
return {
13731377
"s": s,

src/sampleworks/models/protenix/wrapper.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,9 @@ def step(
687687
s_inputs=s_inputs,
688688
s_trunk=s_trunk,
689689
z_trunk=z_trunk,
690-
pair_z=pair_z, # ty: ignore[invalid-argument-type]
691-
p_lm=p_lm, # ty: ignore[invalid-argument-type]
692-
c_l=c_l, # ty: ignore[invalid-argument-type]
690+
pair_z=pair_z,
691+
p_lm=p_lm,
692+
c_l=c_l,
693693
)
694694

695695
# TODO: is there a way to handle this more cleanly?

src/sampleworks/models/rf3/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def featurize(self, structure: dict) -> GenerativeModelInput[RF3Conditioning]:
326326
) # since we're not batching, the loader returns a list of length 1
327327

328328
# (Hydra instantiation of pipeline means it is going to be hard to type check here)
329-
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # ty: ignore[call-non-callable]
329+
pipeline_output = self.inference_engine.pipeline(input_spec.to_pipeline_input()) # type: ignore[call-non-callable]
330330
pipeline_output = trainer.fabric.to_device(pipeline_output)
331331

332332
features = trainer._assemble_network_inputs(pipeline_output)

src/sampleworks/utils/guidance_script_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from sampleworks.core.scalers.pure_guidance import PureGuidance
2828
from sampleworks.core.scalers.step_scalers import (
2929
DataSpaceDPSScaler,
30-
NoScalingScaler,
3130
NoiseSpaceDPSScaler,
3231
NoScalingScaler,
3332
)
@@ -232,7 +231,7 @@ def get_reward_function_and_structure(
232231
logger.debug(f"Loading structure from {structure_path}")
233232
safe_structure_path = resolve_mixed_hetatm_atom_altlocs(Path(structure_path))
234233
structure = parse(
235-
Path(safe_structure_path),
234+
safe_structure_path,
236235
hydrogen_policy="remove",
237236
add_missing_atoms=False,
238237
ccd_mirror_path=None,
@@ -428,14 +427,22 @@ def _run_guidance(
428427
is_boltz = "Boltz" in wrapper_class_name
429428

430429
# Annotate structure with model-specific configuration (including recycling_steps)
430+
# See https://github.com/diff-use/sampleworks/issues/192 for a plan to organize this better.
431431
recycling_steps = getattr(args, "recycling_steps", None)
432+
if recycling_steps is not None and recycling_steps <= 0:
433+
raise ValueError("recycling_steps must be > 0")
434+
if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0:
435+
raise ValueError("num_diffusion_steps must be > 0")
436+
432437
if "Protenix" in wrapper_class_name:
433438
from sampleworks.models.protenix.wrapper import annotate_structure_for_protenix
439+
434440
structure = annotate_structure_for_protenix(
435441
structure, ensemble_size=args.ensemble_size, recycling_steps=recycling_steps
436442
)
437443
elif "RF3" in wrapper_class_name:
438444
from sampleworks.models.rf3.wrapper import annotate_structure_for_rf3
445+
439446
structure = annotate_structure_for_rf3(
440447
structure,
441448
ensemble_size=args.ensemble_size,
@@ -446,6 +453,7 @@ def _run_guidance(
446453
)
447454
elif "Boltz" in wrapper_class_name:
448455
from sampleworks.models.boltz.wrapper import process_structure_for_boltz
456+
449457
structure = process_structure_for_boltz(
450458
structure, ensemble_size=args.ensemble_size, recycling_steps=recycling_steps
451459
)

0 commit comments

Comments
 (0)