@@ -331,7 +331,7 @@ def process_structure_for_boltz(
331331 out_dir : str | Path | None = None ,
332332 num_workers : int = 8 ,
333333 ensemble_size : int = 1 ,
334- recycling_steps : int = 3 ,
334+ recycling_steps : int | None = 3 ,
335335) -> dict :
336336 """Annotate an Atomworks structure with Boltz-specific configuration.
337337
@@ -346,14 +346,20 @@ def process_structure_for_boltz(
346346 Number of parallel workers for preprocessing.
347347 ensemble_size : int
348348 Number of samples to generate (batch dimension of x_init).
349- recycling_steps : int
349+ recycling_steps : int | None
350350 Number of recycling steps to perform during featurization Pairformer pass.
351+ Will set to 3 if None.
351352
352353 Returns
353354 -------
354355 dict
355356 Structure dict with "_boltz_config" key added.
356357 """
358+ # Other models define a default deeper in their code,
359+ # but Boltz requires an integer value, so fix it here.
360+ if recycling_steps is None :
361+ recycling_steps = 3
362+
357363 config = BoltzConfig (
358364 out_dir = out_dir or structure .get ("metadata" , {}).get ("id" , "boltz_output" ),
359365 num_workers = num_workers ,
@@ -603,7 +609,7 @@ def _setup_data_module(
603609
604610 processed_dir = out_dir / "processed"
605611 processed = BoltzProcessedInput (
606- manifest = Manifest .load (processed_dir / "manifest.json" ), # type: ignore (Boltz repo doesn't have the right type hints?)
612+ manifest = Manifest .load (processed_dir / "manifest.json" ),
607613 targets_dir = processed_dir / "structures" ,
608614 msa_dir = processed_dir / "msa" ,
609615 constraints_dir = (processed_dir / "constraints" )
@@ -783,27 +789,25 @@ def _pairformer_pass(
783789
784790 if self .model .use_templates :
785791 if self .model .is_template_compiled :
786- template_module = (
787- self .model .template_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
788- )
792+ template_module = self .model .template_module ._orig_mod
789793 else :
790794 template_module = self .model .template_module
791795
792- z = z + template_module (z , features , pair_mask , use_kernels = self .model .use_kernels ) # type: ignore (Object will be callable here)
796+ z = z + template_module (z , features , pair_mask , use_kernels = self .model .use_kernels )
793797
794798 if self .model .is_msa_compiled :
795- msa_module = self .model .msa_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
799+ msa_module = self .model .msa_module ._orig_mod
796800 else :
797801 msa_module = self .model .msa_module
798802
799- z = z + msa_module (z , s_inputs , features , use_kernels = self .model .use_kernels ) # type: ignore (Object will be callable here)
803+ z = z + msa_module (z , s_inputs , features , use_kernels = self .model .use_kernels )
800804
801805 if self .model .is_pairformer_compiled :
802- pairformer_module = self .model .pairformer_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
806+ pairformer_module = self .model .pairformer_module ._orig_mod
803807 else :
804808 pairformer_module = self .model .pairformer_module
805809
806- s , z = pairformer_module (s , z , mask = mask , pair_mask = pair_mask ) # type: ignore (Object will be callable here)
810+ s , z = pairformer_module (s , z , mask = mask , pair_mask = pair_mask )
807811
808812 q , c , to_keys , atom_enc_bias , atom_dec_bias , token_trans_bias = (
809813 self .model .diffusion_conditioning (
@@ -1068,7 +1072,7 @@ def _setup_data_module(
10681072
10691073 processed_dir = out_dir / "processed"
10701074 processed = BoltzProcessedInput (
1071- manifest = Manifest .load (processed_dir / "manifest.json" ), # type: ignore (Boltz repo doesn't have the right type hints?)
1075+ manifest = Manifest .load (processed_dir / "manifest.json" ),
10721076 targets_dir = processed_dir / "structures" ,
10731077 msa_dir = processed_dir / "msa" ,
10741078 constraints_dir = (processed_dir / "constraints" )
@@ -1357,7 +1361,7 @@ def _pairformer_pass(
13571361 )
13581362
13591363 if self .model .is_pairformer_compiled :
1360- pairformer_module = self .model .pairformer_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
1364+ pairformer_module = self .model .pairformer_module ._orig_mod
13611365 else :
13621366 pairformer_module = self .model .pairformer_module
13631367
@@ -1367,7 +1371,7 @@ def _pairformer_pass(
13671371 mask = mask ,
13681372 pair_mask = pair_mask ,
13691373 use_kernels = self .model .use_kernels ,
1370- ) # type: ignore (Object will be callable here)
1374+ )
13711375
13721376 return {
13731377 "s" : s ,
0 commit comments