11"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""
22
3- import re
4- import string
3+ import warnings
4+ from logging import warning
55from typing import Any
66
77from pydantic import BaseModel , Field , TypeAdapter , model_validator
1313from openprotein .protein import Protein
1414
1515from . import api
16+ from .complex import id_generator
1617from .future import FoldComplexResultFuture
1718from .models import FoldModel
1819
# Chain IDs are either 1-5 ASCII uppercase letters or 1-5 ASCII digits.
# `[0-9]` is used instead of `\d` because `\d` in a str pattern also matches
# non-ASCII Unicode decimal digits (e.g. Arabic-Indic), which are not valid IDs.
# Anchors are unnecessary because the pattern is only applied with `fullmatch`.
valid_id_pattern = re.compile(r"[A-Z]{1,5}|[0-9]{1,5}")


def is_valid_id(id_str: str) -> bool:
    """
    Check whether ``id_str`` is a valid ID.

    A valid ID consists of either 1-5 ASCII uppercase letters or 1-5 ASCII
    digits; mixing letters and digits is not allowed.

    Parameters
    ----------
    id_str : str
        Candidate ID to validate.

    Returns
    -------
    bool
        True if ``id_str`` matches the ID pattern, False otherwise. Empty
        strings and non-string values are reported as invalid.
    """
    # Reject non-strings explicitly so the predicate never raises.
    if not isinstance(id_str, str):
        return False
    # The {1,5} quantifiers already enforce the non-empty / max-length rules.
    return valid_id_pattern.fullmatch(id_str) is not None
29-
30-
31- def id_generator (used_ids : list [str ] | None = None , max_alpha_len = 5 , max_numeric = 99999 ):
32- """
33- Yields new chain IDs, skipping any in 'used_ids'.
34- First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
35- """
36- used = set (tuple (used_ids or []))
37- letters = list (string .ascii_uppercase )
38-
39- # --- Alphabetic IDs ---
40- curr_len = 1
41- curr_indices = [0 ] * curr_len # start at 'A'
42-
43- def bump_indices ():
44- # lexicographically increment curr_indices; return False on overflow
45- for i in reversed (range (len (curr_indices ))):
46- if curr_indices [i ] < len (letters ) - 1 :
47- curr_indices [i ] += 1
48- for j in range (i + 1 , len (curr_indices )):
49- curr_indices [j ] = 0
50- return True
51- return False
52-
53- while curr_len <= max_alpha_len :
54- candidate = "" .join (letters [i ] for i in curr_indices )
55- if candidate not in used :
56- used .add (candidate )
57- yield candidate
58- # bump
59- if not bump_indices ():
60- curr_len += 1
61- if curr_len > max_alpha_len :
62- break
63- curr_indices = [0 ] * curr_len
64-
65- # --- Numeric IDs ---
66- num = 1
67- while num <= max_numeric :
68- candidate = str (num )
69- num += 1
70- if candidate not in used :
71- used .add (candidate )
72- yield candidate
73-
74- # exhausted
75- raise RuntimeError ("exhausted all possible IDs" )
76-
7720
7821class BoltzModel (FoldModel ):
7922 """
@@ -97,8 +40,8 @@ def fold(
9740 rnas : list [RNA ] | None = None ,
9841 ligands : list [Ligand ] | None = None ,
9942 diffusion_samples : int = 1 ,
100- recycling_steps : int = 3 ,
101- sampling_steps : int = 200 ,
43+ num_recycles : int = 3 ,
44+ num_steps : int = 200 ,
10245 step_scale : float = 1.638 ,
10346 use_potentials : bool = False ,
10447 constraints : list [dict ] | None = None ,
@@ -119,9 +62,9 @@ def fold(
11962 List of ligands to include in folded output.
12063 diffusion_samples: int
12164 Number of diffusion samples to use
122- recycling_steps : int
65+ num_recycles : int
12366 Number of recycling steps to use
124- sampling_steps : int
67+ num_steps : int
12568 Number of sampling steps to use
12669 step_scale : float
12770 Scaling factor for diffusion steps.
@@ -133,6 +76,17 @@ def fold(
13376 FoldComplexResultFuture
13477 Future for the folding complex result.
13578 """
79+ # migrate old parameter
80+ if (recycling_steps := kwargs .get ("recycling_steps" )) is not None :
81+ num_recycles = recycling_steps
82+ warnings .warn (
83+ "`recycling_steps` has been updated to `num_recycles`. The parameter will be auto-corrected for now but raise an exception in the future."
84+ )
85+ if (sampling_steps := kwargs .get ("sampling_steps" )) is not None :
86+ num_steps = sampling_steps
87+ warnings .warn (
88+ "`sampling_steps` has been updated to `num_steps`. The parameter will be auto-corrected for now but raise an exception in the future."
89+ )
13690 # validate constraints
13791 if constraints is not None :
13892 TypeAdapter (list [BoltzConstraint ]).validate_python (constraints )
@@ -247,8 +201,8 @@ def fold(
247201 model_id = self .model_id ,
248202 sequences = sequences ,
249203 diffusion_samples = diffusion_samples ,
250- recycling_steps = recycling_steps ,
251- sampling_steps = sampling_steps ,
204+ num_recycles = num_recycles ,
205+ num_steps = num_steps ,
252206 step_scale = step_scale ,
253207 constraints = constraints ,
254208 use_potentials = use_potentials ,
@@ -276,8 +230,8 @@ def fold(
276230 rnas : list [RNA ] | None = None ,
277231 ligands : list [Ligand ] | None = None ,
278232 diffusion_samples : int = 1 ,
279- recycling_steps : int = 3 ,
280- sampling_steps : int = 200 ,
233+ num_recycles : int = 3 ,
234+ num_steps : int = 200 ,
281235 step_scale : float = 1.638 ,
282236 use_potentials : bool = False ,
283237 constraints : list [dict ] | None = None ,
@@ -300,9 +254,9 @@ def fold(
300254 List of ligands to include in folded output.
301255 diffusion_samples: int
302256 Number of diffusion samples to use
303- recycling_steps : int
257+ num_recycles : int
304258 Number of recycling steps to use
305- sampling_steps : int
259+ num_steps : int
306260 Number of sampling steps to use
307261 step_scale : float
308262 Scaling factor for diffusion steps.
@@ -360,8 +314,8 @@ def fold(
360314 rnas = rnas ,
361315 ligands = ligands ,
362316 diffusion_samples = diffusion_samples ,
363- recycling_steps = recycling_steps ,
364- sampling_steps = sampling_steps ,
317+ num_recycles = num_recycles ,
318+ num_steps = num_steps ,
365319 step_scale = step_scale ,
366320 use_potentials = use_potentials ,
367321 constraints = constraints ,
@@ -385,8 +339,8 @@ def fold(
385339 rnas : list [RNA ] | None = None ,
386340 ligands : list [Ligand ] | None = None ,
387341 diffusion_samples : int = 1 ,
388- recycling_steps : int = 3 ,
389- sampling_steps : int = 200 ,
342+ num_recycles : int = 3 ,
343+ num_steps : int = 200 ,
390344 step_scale : float = 1.638 ,
391345 constraints : list [dict ] | None = None ,
392346 ) -> FoldComplexResultFuture :
@@ -405,9 +359,9 @@ def fold(
405359 List of ligands to include in folded output.
406360 diffusion_samples: int
407361 Number of diffusion samples to use
408- recycling_steps : int
362+ num_recycles : int
409363 Number of recycling steps to use
410- sampling_steps : int
364+ num_steps : int
411365 Number of sampling steps to use
412366 step_scale : float
413367 Scaling factor for diffusion steps.
@@ -426,8 +380,8 @@ def fold(
426380 rnas = rnas ,
427381 ligands = ligands ,
428382 diffusion_samples = diffusion_samples ,
429- recycling_steps = recycling_steps ,
430- sampling_steps = sampling_steps ,
383+ num_recycles = num_recycles ,
384+ num_steps = num_steps ,
431385 step_scale = step_scale ,
432386 use_potentials = True ,
433387 constraints = constraints ,
@@ -448,8 +402,8 @@ def fold(
448402 rnas : list [RNA ] | None = None ,
449403 ligands : list [Ligand ] | None = None ,
450404 diffusion_samples : int = 1 ,
451- recycling_steps : int = 3 ,
452- sampling_steps : int = 200 ,
405+ num_recycles : int = 3 ,
406+ num_steps : int = 200 ,
453407 step_scale : float = 1.638 ,
454408 use_potentials : bool = False ,
455409 constraints : list [dict ] | None = None ,
@@ -469,9 +423,9 @@ def fold(
469423 List of ligands to include in folded output.
470424 diffusion_samples: int
471425 Number of diffusion samples to use
472- recycling_steps : int
426+ num_recycles : int
473427 Number of recycling steps to use
474- sampling_steps : int
428+ num_steps : int
475429 Number of sampling steps to use
476430 step_scale : float
477431 Scaling factor for diffusion steps.
@@ -492,8 +446,8 @@ def fold(
492446 rnas = rnas ,
493447 ligands = ligands ,
494448 diffusion_samples = diffusion_samples ,
495- recycling_steps = recycling_steps ,
496- sampling_steps = sampling_steps ,
449+ num_recycles = num_recycles ,
450+ num_steps = num_steps ,
497451 step_scale = step_scale ,
498452 use_potentials = use_potentials ,
499453 constraints = constraints ,
0 commit comments