Skip to content

Commit f2eb03a

Browse files
committed
Release v0.8.8
- Fixes an issue with FoldResultFuture.
- Adds RosettaFold-3 (API not included).
- Adds MiniFold (API not included).
1 parent 1481b9c commit f2eb03a

File tree

10 files changed

+328
-100
lines changed

10 files changed

+328
-100
lines changed

openprotein/align/api.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -400,12 +400,6 @@ def prompt_post(
400400
"Either 'num_sequences' or 'num_residues' must be set, but not both."
401401
)
402402

403-
if num_sequences is not None and not (0 <= num_sequences < 100):
404-
raise InvalidParameterError("The 'num_sequences' must be between 0 and 100.")
405-
406-
if num_residues is not None and not (0 <= num_residues < 24577):
407-
raise InvalidParameterError("The 'num_residues' must be between 0 and 24577.")
408-
409403
if random_seed is None:
410404
random_seed = random.randrange(2**32)
411405

openprotein/embeddings/poet.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,11 @@ def generate(
295295
EmbeddingsGenerateFuture
296296
Future object representing the status and information about the generation job.
297297
"""
298-
prompt_id = prompt if isinstance(prompt, str) else prompt.id
298+
if prompt is not None:
299+
kwargs["prompt_id"] = prompt if isinstance(prompt, str) else prompt.id
300+
else:
301+
# NB: this is for handling PoET-2
302+
assert self.model_id != "poet"
299303
return EmbeddingsGenerateFuture.create(
300304
session=self.session,
301305
job=api.request_generate_post(
@@ -307,7 +311,6 @@ def generate(
307311
topp=topp,
308312
max_length=max_length,
309313
random_seed=seed,
310-
prompt_id=prompt_id,
311314
**kwargs,
312315
),
313316
)

openprotein/embeddings/poet2.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,7 @@ def single_site(
287287

288288
def generate(
289289
self,
290-
prompt: str | Prompt,
290+
prompt: str | Prompt | None,
291291
query: str | bytes | Protein | Query | None = None,
292292
use_query_structure_in_decoder: bool = True,
293293
num_samples: int = 100,
@@ -304,7 +304,7 @@ def generate(
304304
305305
Parameters
306306
----------
307-
prompt : str or Prompt
307+
prompt : str or Prompt or None, optional
308308
Prompt from an align workflow to condition PoET model.
309309
query : str or bytes or Protein or Query or None, optional
310310
Query to use with prompt.
@@ -351,7 +351,8 @@ def generate(
351351
f"equal to the number of prompts ({prompt.num_replicates})"
352352
)
353353
return super().generate(
354-
prompt=prompt,
354+
# NB: poet(-1) cannot use null prompt, so we don't change its .generate's type signature
355+
prompt=prompt, # type: ignore
355356
num_samples=num_samples,
356357
temperature=temperature,
357358
topk=topk,

openprotein/fold/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from .schemas import FoldJob, FoldMetadata
88
from .models import FoldModel
99
from .esmfold import ESMFoldModel
10+
from .minifold import MiniFoldModel
1011
from .alphafold2 import AlphaFold2Model
1112
from .boltz import (
1213
Boltz1Model,
@@ -17,5 +18,6 @@
1718
BoltzConstraint,
1819
BoltzProperty,
1920
)
21+
from .rosettafold3 import RosettaFold3Model
2022
from .future import FoldResultFuture, FoldComplexResultFuture
2123
from .fold import FoldAPI

openprotein/fold/boltz.py

Lines changed: 38 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
"""Community-based Boltz models for complex structure prediction with ligands/dna/rna."""
22

3-
import re
4-
import string
3+
import warnings
4+
from logging import warning
55
from typing import Any
66

77
from pydantic import BaseModel, Field, TypeAdapter, model_validator
@@ -13,67 +13,10 @@
1313
from openprotein.protein import Protein
1414

1515
from . import api
16+
from .complex import id_generator
1617
from .future import FoldComplexResultFuture
1718
from .models import FoldModel
1819

19-
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^\d{1,5}$")
20-
21-
22-
def is_valid_id(id_str: str) -> bool:
23-
"""
24-
Check if the id_str matches the valid pattern for IDs (1-5 uppercase or 1-5 digits).
25-
"""
26-
if not id_str or len(id_str) > 5:
27-
return False
28-
return bool(valid_id_pattern.fullmatch(id_str))
29-
30-
31-
def id_generator(used_ids: list[str] | None = None, max_alpha_len=5, max_numeric=99999):
32-
"""
33-
Yields new chain IDs, skipping any in 'used_ids'.
34-
First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
35-
"""
36-
used = set(tuple(used_ids or []))
37-
letters = list(string.ascii_uppercase)
38-
39-
# --- Alphabetic IDs ---
40-
curr_len = 1
41-
curr_indices = [0] * curr_len # start at 'A'
42-
43-
def bump_indices():
44-
# lexicographically increment curr_indices; return False on overflow
45-
for i in reversed(range(len(curr_indices))):
46-
if curr_indices[i] < len(letters) - 1:
47-
curr_indices[i] += 1
48-
for j in range(i + 1, len(curr_indices)):
49-
curr_indices[j] = 0
50-
return True
51-
return False
52-
53-
while curr_len <= max_alpha_len:
54-
candidate = "".join(letters[i] for i in curr_indices)
55-
if candidate not in used:
56-
used.add(candidate)
57-
yield candidate
58-
# bump
59-
if not bump_indices():
60-
curr_len += 1
61-
if curr_len > max_alpha_len:
62-
break
63-
curr_indices = [0] * curr_len
64-
65-
# --- Numeric IDs ---
66-
num = 1
67-
while num <= max_numeric:
68-
candidate = str(num)
69-
num += 1
70-
if candidate not in used:
71-
used.add(candidate)
72-
yield candidate
73-
74-
# exhausted
75-
raise RuntimeError("exhausted all possible IDs")
76-
7720

7821
class BoltzModel(FoldModel):
7922
"""
@@ -97,8 +40,8 @@ def fold(
9740
rnas: list[RNA] | None = None,
9841
ligands: list[Ligand] | None = None,
9942
diffusion_samples: int = 1,
100-
recycling_steps: int = 3,
101-
sampling_steps: int = 200,
43+
num_recycles: int = 3,
44+
num_steps: int = 200,
10245
step_scale: float = 1.638,
10346
use_potentials: bool = False,
10447
constraints: list[dict] | None = None,
@@ -119,9 +62,9 @@ def fold(
11962
List of ligands to include in folded output.
12063
diffusion_samples: int
12164
Number of diffusion samples to use
122-
recycling_steps : int
65+
num_recycles : int
12366
Number of recycling steps to use
124-
sampling_steps : int
67+
num_steps : int
12568
Number of sampling steps to use
12669
step_scale : float
12770
Scaling factor for diffusion steps.
@@ -133,6 +76,17 @@ def fold(
13376
FoldComplexResultFuture
13477
Future for the folding complex result.
13578
"""
79+
# migrate old parameter
80+
if (recycling_steps := kwargs.get("recycling_steps")) is not None:
81+
num_recycles = recycling_steps
82+
warnings.warn(
83+
"`recycling_steps` has been updated to `num_recycles`. The parameter will be auto-corrected for now but raise an exception in the future."
84+
)
85+
if (sampling_steps := kwargs.get("sampling_steps")) is not None:
86+
num_steps = sampling_steps
87+
warnings.warn(
88+
"`sampling_steps` has been updated to `num_steps`. The parameter will be auto-corrected for now but raise an exception in the future."
89+
)
13690
# validate constraints
13791
if constraints is not None:
13892
TypeAdapter(list[BoltzConstraint]).validate_python(constraints)
@@ -247,8 +201,8 @@ def fold(
247201
model_id=self.model_id,
248202
sequences=sequences,
249203
diffusion_samples=diffusion_samples,
250-
recycling_steps=recycling_steps,
251-
sampling_steps=sampling_steps,
204+
num_recycles=num_recycles,
205+
num_steps=num_steps,
252206
step_scale=step_scale,
253207
constraints=constraints,
254208
use_potentials=use_potentials,
@@ -276,8 +230,8 @@ def fold(
276230
rnas: list[RNA] | None = None,
277231
ligands: list[Ligand] | None = None,
278232
diffusion_samples: int = 1,
279-
recycling_steps: int = 3,
280-
sampling_steps: int = 200,
233+
num_recycles: int = 3,
234+
num_steps: int = 200,
281235
step_scale: float = 1.638,
282236
use_potentials: bool = False,
283237
constraints: list[dict] | None = None,
@@ -300,9 +254,9 @@ def fold(
300254
List of ligands to include in folded output.
301255
diffusion_samples: int
302256
Number of diffusion samples to use
303-
recycling_steps : int
257+
num_recycles : int
304258
Number of recycling steps to use
305-
sampling_steps : int
259+
num_steps : int
306260
Number of sampling steps to use
307261
step_scale : float
308262
Scaling factor for diffusion steps.
@@ -360,8 +314,8 @@ def fold(
360314
rnas=rnas,
361315
ligands=ligands,
362316
diffusion_samples=diffusion_samples,
363-
recycling_steps=recycling_steps,
364-
sampling_steps=sampling_steps,
317+
num_recycles=num_recycles,
318+
num_steps=num_steps,
365319
step_scale=step_scale,
366320
use_potentials=use_potentials,
367321
constraints=constraints,
@@ -385,8 +339,8 @@ def fold(
385339
rnas: list[RNA] | None = None,
386340
ligands: list[Ligand] | None = None,
387341
diffusion_samples: int = 1,
388-
recycling_steps: int = 3,
389-
sampling_steps: int = 200,
342+
num_recycles: int = 3,
343+
num_steps: int = 200,
390344
step_scale: float = 1.638,
391345
constraints: list[dict] | None = None,
392346
) -> FoldComplexResultFuture:
@@ -405,9 +359,9 @@ def fold(
405359
List of ligands to include in folded output.
406360
diffusion_samples: int
407361
Number of diffusion samples to use
408-
recycling_steps : int
362+
num_recycles : int
409363
Number of recycling steps to use
410-
sampling_steps : int
364+
num_steps : int
411365
Number of sampling steps to use
412366
step_scale : float
413367
Scaling factor for diffusion steps.
@@ -426,8 +380,8 @@ def fold(
426380
rnas=rnas,
427381
ligands=ligands,
428382
diffusion_samples=diffusion_samples,
429-
recycling_steps=recycling_steps,
430-
sampling_steps=sampling_steps,
383+
num_recycles=num_recycles,
384+
num_steps=num_steps,
431385
step_scale=step_scale,
432386
use_potentials=True,
433387
constraints=constraints,
@@ -448,8 +402,8 @@ def fold(
448402
rnas: list[RNA] | None = None,
449403
ligands: list[Ligand] | None = None,
450404
diffusion_samples: int = 1,
451-
recycling_steps: int = 3,
452-
sampling_steps: int = 200,
405+
num_recycles: int = 3,
406+
num_steps: int = 200,
453407
step_scale: float = 1.638,
454408
use_potentials: bool = False,
455409
constraints: list[dict] | None = None,
@@ -469,9 +423,9 @@ def fold(
469423
List of ligands to include in folded output.
470424
diffusion_samples: int
471425
Number of diffusion samples to use
472-
recycling_steps : int
426+
num_recycles : int
473427
Number of recycling steps to use
474-
sampling_steps : int
428+
num_steps : int
475429
Number of sampling steps to use
476430
step_scale : float
477431
Scaling factor for diffusion steps.
@@ -492,8 +446,8 @@ def fold(
492446
rnas=rnas,
493447
ligands=ligands,
494448
diffusion_samples=diffusion_samples,
495-
recycling_steps=recycling_steps,
496-
sampling_steps=sampling_steps,
449+
num_recycles=num_recycles,
450+
num_steps=num_steps,
497451
step_scale=step_scale,
498452
use_potentials=use_potentials,
499453
constraints=constraints,

openprotein/fold/complex.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,60 @@
1+
import re
2+
import string
3+
4+
# Chain IDs are either 1-5 ASCII uppercase letters or 1-5 ASCII digits.
# ``[0-9]`` is used instead of ``\d`` because, for ``str`` patterns, ``\d``
# also matches non-ASCII Unicode decimal digits (e.g. Arabic-Indic digits).
valid_id_pattern = re.compile(r"^[A-Z]{1,5}$|^[0-9]{1,5}$")


def is_valid_id(id_str: str) -> bool:
    """
    Check whether ``id_str`` is a valid ID.

    A valid ID consists of 1-5 ASCII uppercase letters or 1-5 ASCII digits;
    letters and digits cannot be mixed.

    Parameters
    ----------
    id_str : str
        Candidate ID to check.

    Returns
    -------
    bool
        True if ``id_str`` matches the ID pattern, False otherwise
        (including for empty or ``None`` input).
    """
    # Falsy input (empty string / None) is never a valid ID; checking it first
    # also keeps ``fullmatch`` from being called on ``None``.
    if not id_str:
        return False
    # ``fullmatch`` already enforces the 1-5 character limit and rejects a
    # trailing newline, so no separate length check is needed.
    return valid_id_pattern.fullmatch(id_str) is not None
14+
15+
16+
def id_generator(used_ids: list[str] | None = None, max_alpha_len=5, max_numeric=99999):
17+
"""
18+
Yields new chain IDs, skipping any in 'used_ids'.
19+
First A..Z, AA..ZZ, … up to max_alpha_len, then '1','2',… up to max_numeric.
20+
"""
21+
used = set(tuple(used_ids or []))
22+
letters = list(string.ascii_uppercase)
23+
24+
# --- Alphabetic IDs ---
25+
curr_len = 1
26+
curr_indices = [0] * curr_len # start at 'A'
27+
28+
def bump_indices():
29+
# lexicographically increment curr_indices; return False on overflow
30+
for i in reversed(range(len(curr_indices))):
31+
if curr_indices[i] < len(letters) - 1:
32+
curr_indices[i] += 1
33+
for j in range(i + 1, len(curr_indices)):
34+
curr_indices[j] = 0
35+
return True
36+
return False
37+
38+
while curr_len <= max_alpha_len:
39+
candidate = "".join(letters[i] for i in curr_indices)
40+
if candidate not in used:
41+
used.add(candidate)
42+
yield candidate
43+
# bump
44+
if not bump_indices():
45+
curr_len += 1
46+
if curr_len > max_alpha_len:
47+
break
48+
curr_indices = [0] * curr_len
49+
50+
# --- Numeric IDs ---
51+
num = 1
52+
while num <= max_numeric:
53+
candidate = str(num)
54+
num += 1
55+
if candidate not in used:
56+
used.add(candidate)
57+
yield candidate
58+
59+
# exhausted
60+
raise RuntimeError("exhausted all possible IDs")

0 commit comments

Comments
 (0)