@@ -331,7 +331,7 @@ def process_structure_for_boltz(
331331 out_dir : str | Path | None = None ,
332332 num_workers : int = 8 ,
333333 ensemble_size : int = 1 ,
334- recycling_steps : int = 3 ,
334+ recycling_steps : int | None = 3 ,
335335) -> dict :
336336 """Annotate an Atomworks structure with Boltz-specific configuration.
337337
@@ -346,14 +346,20 @@ def process_structure_for_boltz(
346346 Number of parallel workers for preprocessing.
347347 ensemble_size : int
348348 Number of samples to generate (batch dimension of x_init).
349- recycling_steps : int
349+ recycling_steps : int | None
350350 Number of recycling steps to perform during featurization Pairformer pass.
351+ Will set to 3 if None.
351352
352353 Returns
353354 -------
354355 dict
355356 Structure dict with "_boltz_config" key added.
356357 """
358+ # Other models define a default deeper in their code,
359+ # but Boltz requires an integer value, so fix it here.
360+ if recycling_steps is None :
361+ recycling_steps = 3
362+
357363 config = BoltzConfig (
358364 out_dir = out_dir or structure .get ("metadata" , {}).get ("id" , "boltz_output" ),
359365 num_workers = num_workers ,
@@ -603,7 +609,7 @@ def _setup_data_module(
603609
604610 processed_dir = out_dir / "processed"
605611 processed = BoltzProcessedInput (
606- manifest = Manifest .load (processed_dir / "manifest.json" ), # type: ignore (Boltz repo doesn't have the right type hints?)
612+ manifest = Manifest .load (processed_dir / "manifest.json" ),
607613 targets_dir = processed_dir / "structures" ,
608614 msa_dir = processed_dir / "msa" ,
609615 constraints_dir = (processed_dir / "constraints" )
@@ -784,26 +790,26 @@ def _pairformer_pass(
784790 if self .model .use_templates :
785791 if self .model .is_template_compiled :
786792 template_module = (
787- self .model .template_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
793+ self .model .template_module ._orig_mod
788794 )
789795 else :
790796 template_module = self .model .template_module
791797
792- z = z + template_module (z , features , pair_mask , use_kernels = self .model .use_kernels ) # type: ignore (Object will be callable here)
798+ z = z + template_module (z , features , pair_mask , use_kernels = self .model .use_kernels )
793799
794800 if self .model .is_msa_compiled :
795- msa_module = self .model .msa_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
801+ msa_module = self .model .msa_module ._orig_mod
796802 else :
797803 msa_module = self .model .msa_module
798804
799- z = z + msa_module (z , s_inputs , features , use_kernels = self .model .use_kernels ) # type: ignore (Object will be callable here)
805+ z = z + msa_module (z , s_inputs , features , use_kernels = self .model .use_kernels )
800806
801807 if self .model .is_pairformer_compiled :
802- pairformer_module = self .model .pairformer_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
808+ pairformer_module = self .model .pairformer_module ._orig_mod
803809 else :
804810 pairformer_module = self .model .pairformer_module
805811
806- s , z = pairformer_module (s , z , mask = mask , pair_mask = pair_mask ) # type: ignore (Object will be callable here)
812+ s , z = pairformer_module (s , z , mask = mask , pair_mask = pair_mask )
807813
808814 q , c , to_keys , atom_enc_bias , atom_dec_bias , token_trans_bias = (
809815 self .model .diffusion_conditioning (
@@ -1068,7 +1074,7 @@ def _setup_data_module(
10681074
10691075 processed_dir = out_dir / "processed"
10701076 processed = BoltzProcessedInput (
1071- manifest = Manifest .load (processed_dir / "manifest.json" ), # type: ignore (Boltz repo doesn't have the right type hints?)
1077+ manifest = Manifest .load (processed_dir / "manifest.json" ),
10721078 targets_dir = processed_dir / "structures" ,
10731079 msa_dir = processed_dir / "msa" ,
10741080 constraints_dir = (processed_dir / "constraints" )
@@ -1357,7 +1363,7 @@ def _pairformer_pass(
13571363 )
13581364
13591365 if self .model .is_pairformer_compiled :
1360- pairformer_module = self .model .pairformer_module ._orig_mod # type: ignore (compiled torch module has this attribute, type checker doesn't know)
1366+ pairformer_module = self .model .pairformer_module ._orig_mod
13611367 else :
13621368 pairformer_module = self .model .pairformer_module
13631369
@@ -1367,7 +1373,7 @@ def _pairformer_pass(
13671373 mask = mask ,
13681374 pair_mask = pair_mask ,
13691375 use_kernels = self .model .use_kernels ,
1370- ) # type: ignore (Object will be callable here)
1376+ )
13711377
13721378 return {
13731379 "s" : s ,
0 commit comments