diff --git a/CLAUDE.md b/CLAUDE.md index 260ee35..663c9f0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,7 +8,7 @@ Each model type gets a typed request/response contract (Pydantic). Batching, cac PyPI: `pip install sheaf-serve` -## Current state: v0.10.0 shipped (Docker base image at `ghcr.io/korbonits/sheaf-serve` + KubeRay `RayService` example + `sheaf.build_app(spec)` public API) +## Current state: v0.11.0 shipped (ESMC + ESMFold2 backends, new `PROTEIN_LANGUAGE` + `STRUCTURE` model categories, `[protein]` extra, Biohub 2026-05-27 release) Per-version ship notes live in git history and release tags. This doc tracks what exists *now* and the non-obvious design choices behind it. For feature-level changelog, see `git log`. @@ -194,6 +194,19 @@ tests/ # test__backend.py (mocked) + test_smoke_*.py (g - **Request-time adapter validation returns 422, not 500** — when a request specifies `adapters=["foo"]` against a deployment with no `lora` configured, or specifies an adapter name not in `spec.lora.adapters`, both `_SheafDeployment.predict` and the Modal predict handler raise `HTTPException(422, ...)`. This is a client-error contract: the request is well-formed JSON but references unknown server state, so 422 (Unprocessable Entity) is the right code. `resolve_active_adapters` raising `ValueError` deeper in the stack would otherwise become a 500. - **Diffusers `load_lora_weights` + `set_adapters` API** — both FLUX and SDXL backends call `pipeline.load_lora_weights(path_or_repo, adapter_name=name, [weight_name=file])` per adapter at load time, and `pipeline.set_adapters(names, adapter_weights=weights)` per sub-batch. Diffusers ≥0.27 + PEFT ≥0.7 handles all of the actual LoRA composition (named adapter slots, weight scaling, adapter merging) — sheaf's job is just to thread the per-deployment registry and the per-request selection through to those two calls. +### Protein models (v0.11 — Biohub ESMC + ESMFold2) + +- **`MOLECULAR` and `PROTEIN_LANGUAGE` are separate model categories** — `MOLECULAR` (ESM-3, `MolecularResponse.embeddings: list[list[float]]`) returns one pooled vector per sequence; `PROTEIN_LANGUAGE` (ESMC, `ProteinLanguageResponse.logits/embeddings: list[list[list[float]]]`) returns ragged per-token tensors sliced to `seq_lens[i]`. Unifying them would force every caller to branch on `model_name` to interpret the shape — defeats the typed-contract premise. Documented in ADR-0001. +- **`STRUCTURE` is the first non-tensor output category** — `StructureResponse.structure: str` is a PDB or mmCIF text block, not numerical data. Caching is fine (in-process LRU, SHA-256 over the request) but the cached payload can be 40+ KB per fold; future structure backends (Boltz-1, Chai-1) can reuse the contract. +- **ESMC `MaskedLMOutput` has no `last_hidden_state`** — `transformers.AutoModelForMaskedLM` returns `MaskedLMOutput`, which exposes `.logits` + `.hidden_states` (when requested) but **not** `.last_hidden_state` (that's on `BaseModelOutput`). `ESMCBackend._run` forces `output_hidden_states=True` whenever `return_embeddings` is True and reads uniformly from `hidden_states[-1]`. The bug was a silent AttributeError lurking past mocked tests until the H100 smoke caught it; the test mock now mirrors the real `MaskedLMOutput` shape (no `last_hidden_state` attr). +- **ESMFold2 pLDDT is on `[0, 1]`, not `[0, 100]`** — verified empirically on Modal H100 (2026-05-27). Sheaf passes through faithfully without scaling (consistent with "validate at the boundary, don't transform in backends"); callers who want the conventional AlphaFold / ESMFold-v1 scale multiply by 100 themselves. `StructureResponse.plddt` docstring + ADR-0001 both document this. +- **ESMFold2 is single-sample-per-call upstream** — `ESMFold2InputBuilder().fold()` runs one structure at a time; `batch_predict` runs requests sequentially with no true batched forward. Per-request compute varies hugely with sequence length × `num_loops` × `num_samples`. Operators size Ray Serve replica count to expected concurrency rather than relying on intra-replica batching. +- **`[protein]` and `[molecular]` extras are mutually exclusive** — both ship a package named `esm` (Biohub's 2026 release vs the pre-2026 EvolutionaryScale PyPI 3.x). Declared in `[tool.uv].conflicts`; sheaf cannot serve ESM-3 and ESMC in the same process. Users who need both run them in separate Sheaf deployments. +- **`modal_server.py` has its own `AnyRequest` union** — deliberately separate from `sheaf.api.union` so Modal containers don't pull Ray as a transitive dep. Adding a new model type means updating **both** unions, plus the registry imports inside `_build_asgi_app`. The v0.11 PR initially missed the modal-side update for `ProteinLanguageRequest` + `StructureRequest` — protein requests would have 422'd on Modal until the follow-up fix. +- **`esm` git pin tracks Modal's reference example** — pinned to commit `81b3646c9429ea8458918415ad6a46178cb59833` (long SHA), matching `modal-labs/modal-examples 06_gpu_and_ml/protein-folding/esmfold2.py`. This is the revision verified via `examples/quickstart_protein_modal.py`; bump in lockstep with Modal's example when their next pin lands. +- **ESMFold2 `_to_float_list` / `_maybe_2d_list` helpers** — coerce upstream tensor-or-list outputs to plain `list[float]` / `list[list[float]]` for JSON serialisation. The pattern checks for `.cpu()` (torch tensor) and `.item()` (torch scalar) before falling back to `float(value)`, so test stubs can pass raw lists/floats without importing torch. +- **Forge / Biohub-Platform variants raise `NotImplementedError` at `load()`** — `esmc-300m-2024-12`, `esmc-600m-2024-12`, `esmfold2-fast-2026-05` require Biohub's HTTP-client SDK (`esm.sdk.esmc_client`, `SequenceStructureForgeInferenceClient`) with an API token. Sheaf's v0.11 backends explicitly reject these IDs at load with a pointer to ADR-0001; wiring up the Forge HTTP client is a deliberate future PR (separate code path, no local weights). + ## Adding a new backend 1. Add typed request/response to `src/sheaf/api/.py` diff --git a/README.md b/README.md index 36d76ea..938dd6a 100644 --- a/README.md +++ b/README.md @@ -29,13 +29,16 @@ Each model type gets a typed request/response contract. Batching, caching, and s > curl https://korbonits--sheaf-demo-modalserver---init----locals---serve.modal.run/chronos/health > ``` -> **Requires Python 3.11+.** macOS's system `python3` is usually 3.10 — bootstrap a 3.11 venv first via [`uv`](https://docs.astral.sh/uv/) (`uv venv --python 3.11 .venv && source .venv/bin/activate`) or `pyenv`. The `[molecular]` extra (ESM-3) additionally requires Python 3.12+. +> **Requires Python 3.11+.** macOS's system `python3` is usually 3.10 — bootstrap a 3.11 venv first via [`uv`](https://docs.astral.sh/uv/) (`uv venv --python 3.11 .venv && source .venv/bin/activate`) or `pyenv`. The `[molecular]` and `[protein]` extras (ESM-3 / ESMC / ESMFold2) additionally require Python 3.12+, and the two are mutually exclusive (they share the `esm` import name from different packages). ```bash pip install sheaf-serve # core only pip install "sheaf-serve[time-series]" # + Chronos2 / TimesFM / Moirai pip install "sheaf-serve[tabular]" # + TabPFN pip install "sheaf-serve[molecular]" # + ESM-3 (Python 3.12+) +pip install "sheaf-serve[protein]" # + ESMC / ESMFold2 deps (Python 3.12+) +# then also (no PyPI release yet — pinned commit per upstream README): +pip install "esm@git+https://github.com/Biohub/esm.git@81b3646c9429ea8458918415ad6a46178cb59833" pip install "sheaf-serve[genomics]" # + Nucleotide Transformer pip install "sheaf-serve[small-molecule]" # + MolFormer pip install "sheaf-serve[materials]" # + MACE-MP @@ -179,6 +182,18 @@ See [`examples/`](examples/) for time series comparison, tabular, audio, vision, --- +## Protein models + +Sheaf serves three protein foundation models, each via its own typed contract: + +- **ESM-3** (`api/molecular.py`, backend `esm3`) — per-sequence pooled embeddings (mean / cls). Use for sequence-level similarity, clustering, and downstream featurization. `[molecular]` extra (Python 3.12+). +- **ESMC** (`api/protein_language.py`, backend `esmc`) — per-token logits + optional per-token embeddings from Biohub's 2026-05-27 release. Use when you need masked-LM logits, per-residue representations, or all-layer hidden states. Default model: `Biohub/ESMC-6B`. `[protein]` extra (Python 3.12+); 300M / 600M variants are Forge API-only and currently raise `NotImplementedError`. +- **ESMFold2** (`api/structure.py`, backend `esmfold2`) — protein structure prediction with inference-time scaling. Exposes `num_loops`, `num_sampling_steps`, `num_samples`, `seed` as first-class request fields; returns PDB / mmCIF + pLDDT + pTM/ipTM + optional PAE. Default model: `biohub/ESMFold2`. `[protein]` extra (Python 3.12+). + +`[molecular]` (ESM-3) and `[protein]` (ESMC + ESMFold2) share the `esm` import name from different upstream packages — install one **or** the other in a given environment. See [`docs/adr/0001-esmc-esmfold2-integration.md`](docs/adr/0001-esmc-esmfold2-integration.md) for the rationale. + +Biohub release announcement: · preprint: . + ## Supported model types | Type | Status | Backends | @@ -193,6 +208,8 @@ See [`examples/`](examples/) for time series comparison, tabular, audio, vision, | Depth estimation | ✅ v0.3 | Depth Anything v2 | | Object detection | ✅ v0.3 | DETR / RT-DETR | | Protein / molecular | ✅ v0.3 | ESM-3 (Python 3.12+) | +| Protein language modeling | ✅ v0.11 | ESMC 6B (Biohub) | +| Protein structure prediction | ✅ v0.11 | ESMFold2 (Biohub) — inference-time scaling | | Genomics | ✅ v0.3 | Nucleotide Transformer | | Small molecule | ✅ v0.3 | MolFormer-XL | | Materials science | ✅ v0.3 | MACE-MP-0 | @@ -309,6 +326,17 @@ Today sheaf ships three deployment paths: `ModelServer` (a local Ray cluster you - [ ] `examples/k8s/` with a `RayService` manifest — KubeRay's canonical Ray-on-K8s shape — and a short `README.md` covering prereqs (KubeRay operator installed), `kubectl apply`, and a port-forward smoke test. - [ ] GitHub Actions workflow that builds + pushes the Dockerfile to `ghcr.io/korbonits/sheaf-serve:vX.Y.Z` on `v*` tag push, mirroring the PyPI publish flow. +**v0.11 — Biohub protein-biology release integration** + +Biohub's "world model of protein biology" landed 2026-05-27 under MIT. Sheaf integrates the two model artifacts as first-class typed contracts; ESM Atlas (dataset) is out of scope. See [`docs/adr/0001-esmc-esmfold2-integration.md`](docs/adr/0001-esmc-esmfold2-integration.md). + +- [x] `ESMCBackend` — per-token logits + per-token embeddings via `transformers.AutoModelForMaskedLM`, default `Biohub/ESMC-6B`. +- [x] `ESMFold2Backend` — protein structure prediction with `num_loops` / `num_sampling_steps` / `num_samples` / `seed` as first-class request fields, returning PDB / mmCIF + pLDDT + pTM/ipTM + optional PAE. +- [x] New `STRUCTURE` model category — first non-tensor output category (structure file as text). +- [x] `[protein]` install extra; `esm` from `git+https://github.com/Biohub/esm.git@81b3646c9429ea8458918415ad6a46178cb59833` documented (no PyPI release yet). +- [x] End-to-end GPU smoke — `examples/quickstart_protein_modal.py` runs `ESMFold2Backend` on H100 via Modal (~70s cold start to a persistent volume, sub-second per fold). 53-residue target → 43,088-char mmCIF, pTM=0.2465. +- [ ] Forge / Biohub-Platform HTTP-client variants for the ESMC 300M / 600M / ESMFold2-fast API-only models. + --- ## Architecture diff --git a/docs/adr/0001-esmc-esmfold2-integration.md b/docs/adr/0001-esmc-esmfold2-integration.md new file mode 100644 index 0000000..03b802f --- /dev/null +++ b/docs/adr/0001-esmc-esmfold2-integration.md @@ -0,0 +1,268 @@ +# ADR 0001 — ESMC + ESMFold2 integration, and the `structure` model category + +**Status:** Accepted (draft PR) +**Date:** 2026-05-27 +**Authors:** Sheaf maintainers + +## Context + +On 2026-05-27 Chan Zuckerberg Biohub released "a world model of protein +biology" — three artifacts: **ESMC** (protein language model), **ESMFold2** +(structure prediction model built on ESMC 6B), and **ESM Atlas** (a dataset of +6.8 B sequences / 1.1 B predicted structures). This ADR records the verification +work done before integrating ESMC and ESMFold2 as first-class Sheaf model +types, the design decisions taken in `sheaf.api.protein_language` / +`sheaf.api.structure`, and why ESMFold2 forces us to introduce a new +top-level model category (`STRUCTURE`). + +ESM Atlas is a dataset, not a model, and is out of scope for this PR. + +## Upstream verification (the "before any code" step) + +### Repo and license + +- `github.com/evolutionaryscale/esm` issues HTTP 301 → `github.com/Biohub/esm` + (verified via `curl -sSI`; the repo has been re-homed under the Biohub + GitHub org as of the release). +- `LICENSE.md` at `Biohub/esm@main` is a standard MIT license, Copyright + 2026 Chan Zuckerberg Biohub, Inc. (fetched verbatim via + `raw.githubusercontent.com`). No additional clauses or carve-outs. +- The README's "Licenses" section says: "These models are available under + the [MIT license](https://github.com/Biohub/esm/blob/main/LICENSE.md)." + +### Python package + +- Distribution: `pip install esm@git+https://github.com/Biohub/esm.git@81b3646c9429ea8458918415ad6a46178cb59833`. + A PyPI release is described as "coming soon" but is **not** yet published as + of this PR. We pin to commit `81b3646c9429ea8458918415ad6a46178cb59833` to stay reproducible; we will switch + to a PyPI version constraint when one ships. (Originally pinned to `c94ed8d`, + the SHA in the upstream README on 2026-05-27; bumped to `81b3646c…` after the + Modal H100 smoke verified that revision end-to-end and to match the pin in + Modal's official `modal-examples/06_gpu_and_ml/protein-folding/esmfold2.py`.) +- Naming conflict: the new `esm` package shares the import name with the + pre-2026 `esm` package (PyPI 3.x) used by our existing `ESM3Backend` + (`sheaf.backends.esm3`, extra `[molecular]`). Both cannot be installed in + the same environment under the same import name. We treat this as a + packaging constraint, declared via `[tool.uv]` `conflicts`. See + "Compatibility with existing `[molecular]` extra" below. + +### Model artifacts (HuggingFace) + +| Artifact | HF repo ID | Access | Use via | +|--------------------------|-----------------------|-------------------|----------------------------------------| +| ESMC 6B | `Biohub/ESMC-6B` | Weight download | `transformers.AutoModelForMaskedLM` | +| ESMFold2 | `biohub/ESMFold2` | Weight download | `transformers.models.esmfold2.modeling_esmfold2.ESMFold2Model` + `esm.models.esmfold2.ESMFold2InputBuilder` | +| ESMC 300M / 600M | `esmc-{300m,600m}-2024-12` | **Forge API-only** | `esm.sdk.esmc_client(...)` with a Biohub Platform token | +| ESMFold2 fast | `esmfold2-fast-2026-05` | **Forge API-only** | `esm.sdk.forge.SequenceStructureForgeInferenceClient` | + +Note the case inconsistency in the README itself (`Biohub/ESMC-6B` vs +`biohub/ESMFold2`); we mirror the upstream strings verbatim rather than +normalising. + +We could **not** independently fetch the HF model card pages from this +build environment (`huggingface.co` is not in the outbound allowlist: +`x-deny-reason: host_not_allowed`). We are relying on the README in the +Biohub/esm repository — which is fetchable — for the canonical +`from_pretrained` strings and license restatement. The README explicitly +restates MIT, so we treat this as verified. + +### ESMFold2 inference-time scaling parameters + +Per the README's local-inference example, `ESMFold2InputBuilder().fold(...)` +exposes four scaling parameters as positional kwargs: + +```python +result = ESMFold2InputBuilder().fold( + model, spi, + num_loops=3, # depth of the looped-transformer recurrence + num_sampling_steps=50, # diffusion steps per sample + num_diffusion_samples=1, + seed=0, +) +``` + +Result fields used in this PR: `result.plddt` (per-residue confidence), +`result.ptm`, `result.iptm`, `result.complex.to_mmcif()` / +`result.complex.to_pdb()`. PAE matrix and per-sample ranking scores are +documented as outputs in the paper but are not shown in the README sample; +we leave the `pae` and `sample_scores` response fields optional and populate +them only when present on the result object. + +**pLDDT scale**: empirically (verified by a Modal H100 smoke against the +default `biohub/ESMFold2` weights, 2026-05-27), ESMFold2 returns `plddt` as +a `torch.Tensor` of fractional values on `[0, 1]` — *not* the conventional +AlphaFold / ESMFold-v1 `[0, 100]` scale. We pass through faithfully (no +backend-side scaling) and document the scale on `StructureResponse.plddt` +so callers can multiply by 100 themselves if they want the conventional +values. Faithful pass-through is consistent with Sheaf's general "validate +at the boundary, don't transform inside backends" convention. + +## Decision + +### 1. Two new model categories + +We add two new `ModelType` enum values: + +- `PROTEIN_LANGUAGE = "protein_language"` — for ESMC and any future protein + language model that returns per-token logits / per-token embeddings (not + pooled to one vector per sequence). +- `STRUCTURE = "structure"` — for protein structure prediction. This is a + new top-level model category for Sheaf. Structure prediction does not + fit any existing category: it produces a 3D structure (atom coordinates), + not an embedding, classification, generation, or forecast. It also has + fundamentally different batching and caching properties (see below). + +We deliberately keep `MOLECULAR` (per-sequence pooled embeddings, ESM-3) +distinct from `PROTEIN_LANGUAGE` (per-token logits + embeddings, ESMC). +The response shapes are incompatible (`(N, D)` vs `(N, L, V)`), and trying +to unify them would force every caller to handle both shapes. + +### 2. Two new API contracts + +- `sheaf.api.protein_language.ProteinLanguageRequest / Response` — ESMC + contract. Per-sequence ragged outputs: per-token logits and optional + per-token embeddings, with `seq_lens` so callers can slice the padded + output back to per-sequence length. +- `sheaf.api.structure.StructureRequest / Response` — ESMFold2 contract. + Multi-chain input (`list[ChainInput]`), inference-time scaling + parameters as first-class fields (`num_loops`, `num_sampling_steps`, + `num_samples`, `seed`), and structure output in PDB or mmCIF. + +### 3. Two new backends + +- `sheaf.backends.esmc.ESMCBackend` (registered as `"esmc"`). +- `sheaf.backends.esmfold2.ESMFold2Backend` (registered as `"esmfold2"`). + +Both follow the existing CLIP / DINOv2 / ESM-3 / MolFormer convention: +lazy imports inside `load()`, `_tokenizer` / `_model` / `_Image`-style +instance attributes stored at `load()` for test injectability, no +heavyweight imports at module level. + +### 4. New `[protein]` extra + +`pyproject.toml` gets a new optional-dependency group: + +```toml +protein = [ + "esm @ git+https://github.com/Biohub/esm.git@81b3646c9429ea8458918415ad6a46178cb59833 ; python_full_version >= '3.12'", + "transformers>=4.40.0", + "torch>=2.0.0", +] +``` + +Pinned to Python 3.12+ to match the upstream package's `requires-python`. +Conflicts with the existing `[molecular]` extra (same `esm` import name, +different versions) — declared in `[tool.uv]` `conflicts`. Users who need +both ESM-3 and ESMC will need to run them in separate Sheaf deployments. + +We do **not** remove or migrate the existing `[molecular]` extra. ESM-3 +remains available; the two extras are mutually exclusive (uv enforces this). +When the new `esm` package ships to PyPI and we have confirmed it supports +ESM-3 inference, a follow-up PR can consolidate the extras and drop the +conflict declaration. + +### 5. Forge / API-only variants are out of scope for this PR + +ESMC 300M / 600M and ESMFold2-fast require a Biohub Platform API token +and the `esm.sdk.esmc_client` / `SequenceStructureForgeInferenceClient` +classes. They live behind a different code path (HTTP client, no local +weights). Adding them is a strictly additive change once a user explicitly +asks for it; for v0.11 we ship only the weight-downloadable variants: + +| Sheaf backend | model_name | requires_forge | +|---------------|---------------------------|----------------| +| `esmc` | `Biohub/ESMC-6B` (default)| no | +| `esmfold2` | `biohub/ESMFold2` | no | + +The `ESMCBackend` constructor accepts an arbitrary `model_name` and a +`requires_forge: bool` flag plumbed through to the load path. When +`requires_forge=True` we raise a `NotImplementedError` immediately with a +pointer to this ADR — better than a silent fall-through that 404s on HF +download. Wiring up the Forge HTTP client is a deliberate future PR. + +### 6. Batching, caching, streaming + +- **ESMC batching**: tokenize the full `request.sequences` list with + `padding=True` and run a single forward pass — same as `MolFormerBackend`. + Variable-length sequences are handled via the attention mask. Sheaf's + `BatchPolicy.bucket_by` can group requests by sequence length if a + caller sets `bucket_by="max_len"` on the deployment; we expose + `max_len` as a derived field on the response so length-bucketing + works out of the box. +- **ESMFold2 batching**: per-request compute varies hugely with sequence + length and `num_samples`. The upstream `ESMFold2InputBuilder().fold()` + is single-sample-per-call. We do not implement a true batched forward; + `batch_predict` runs requests sequentially. Sheaf operators who care + about throughput should size the Ray Serve replica count to handle + expected concurrency (one in-flight prediction per replica). +- **Caching**: ESMC plugs into the existing `ResponseCache` via + `CacheConfig(enabled=True)` with no changes — the existing key + derivation (`request.model_dump(mode="json", exclude={"request_id"})` + → SHA-256) already includes the sequence string and the + `return_logits` / `return_embeddings` flags. For ESMFold2 we + recommend `CacheConfig(exclude_fields=["seed"])` only if callers want + same-sequence cache hits across different random seeds — by default + seeds participate in the key, matching diffusion-model behaviour + established in v0.7. +- **Streaming**: neither backend supports the `stream_predict` path in + v0.11. ESMFold2's diffusion loop could plausibly emit per-step events + in a future PR (mirroring `FluxBackend.stream_predict`), but the + upstream API does not currently expose a step-end callback. + +### 7. Observability + +ESMC and ESMFold2 use the existing `sheaf.metrics.record_predict(...)` +and `sheaf.tracing.trace_predict(...)` paths. No new spans, no new metric +families. The deployment `name` label (e.g. `"esmc-6b"`) distinguishes +backends on dashboards. We chose this over the task spec's +`sheaf.esmc.forward` / `sheaf.esmfold2.forward` span names because Sheaf's +convention is one canonical predict span per request with the deployment +label as the disambiguator — adding per-backend span names would create +inconsistent telemetry across the 25+ existing backends. + +## Consequences + +- Sheaf now serves 25 model types across 22 model categories (+2 from + this PR: `PROTEIN_LANGUAGE`, `STRUCTURE`). +- The `[molecular]` and `[protein]` extras are mutually exclusive at + install time. Documentation needs to flag this. +- The pinned `esm @ git+...@81b3646c9429ea8458918415ad6a46178cb59833` will become stale; we should track a + PyPI release in the Biohub repo and switch when one lands. +- `STRUCTURE` is the first model category whose output is fundamentally + non-tensor (PDB / mmCIF strings, with structured side-channel data + like pLDDT and PAE). Future structure-prediction backends (Boltz-1, + Chai-1, etc.) can reuse the contract without new infra. +- ESMFold2 outputs can be large (mmCIF for a multi-chain complex is + often >100 KB); the existing `ResponseCache` is in-process LRU, which + is fine for small deployments but won't scale to thousands of cached + predictions per replica. A pluggable disk-backed cache is a future + enhancement, not blocking for v0.11. + +## Alternatives considered + +- **Reuse `MOLECULAR` for ESMC.** Rejected — response shapes are + incompatible (pooled vector vs ragged per-token tensor), and forcing + callers to branch on `model_name` defeats the point of a typed + contract. +- **Bundle structure prediction under `MOLECULAR`.** Rejected — pLDDT, + PAE, multi-chain inputs, and a PDB string output have nothing in + common with the embedding contract. +- **Default ESMC to a Forge variant for parity with the README's + 300M/600M code samples.** Rejected — Forge requires an API token and + network access; weight-downloadable inference is the model surface + Sheaf can serve self-hosted today. +- **Drop ESM-3 in favour of the new ESMC package.** Rejected for this + PR — the new package may support ESM-3 (the README references an + `ESM3_README.md`), but we haven't verified end-to-end equivalence. A + follow-up PR can consolidate once tested. + +## References + +- Biohub/esm README: +- ESM Atlas: +- Preprint: "Language Modeling Materializes a World Model of Protein + Biology" (Candido et al., 2026), +- HF collections (not directly fetched from build env — see + "verification" above): + - + - diff --git a/examples/quickstart_protein_language.py b/examples/quickstart_protein_language.py new file mode 100644 index 0000000..bb0d398 --- /dev/null +++ b/examples/quickstart_protein_language.py @@ -0,0 +1,90 @@ +"""ESMC protein language model quickstart. + +Requirements: + pip install "sheaf-serve[protein]" + pip install esm@git+https://github.com/Biohub/esm.git@c94ed8d + # Python 3.12+ required. Conflicts with the [molecular] extra (ESM-3). + +Usage: + python examples/quickstart_protein_language.py + +Demonstrates: + - Per-token logits over the amino-acid vocabulary + - Per-token last-layer hidden-state embeddings + - Per-sequence mean-pooled embeddings (call-site reduction) + +ESMC 6B effectively requires a GPU with bf16 support; the default +``device="cuda"`` below will fall back to CPU only on a small dummy +checkpoint. For a real run, set ``CUDA_VISIBLE_DEVICES`` and leave the +device as is. +""" + +from __future__ import annotations + +import math + +from sheaf.api.protein_language import ProteinLanguageRequest +from sheaf.backends.esmc import ESMCBackend + +# Three short protein sequences — variable length, exercises the padding / +# attention-mask slicing path. +SEQUENCES = [ + "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK", # 53 aa + "ACDEFGHIKLMNPQRSTVWY", # 20 aa — one of each standard residue + "MVLSPADKTNVKAAW", # 15 aa — N-terminus of human hemoglobin alpha +] + +print("--- ESMC (Biohub/ESMC-6B) ---") +print("Loading model (downloads ~12 GB of weights on first run)...") + +esmc = ESMCBackend(model_name="Biohub/ESMC-6B", device="cuda") +esmc.load() +print("Model loaded.") + +# Per-token logits over the amino-acid vocab — shape (L_i, V) per sequence. +req = ProteinLanguageRequest( + model_name="esmc", + sequences=SEQUENCES, + return_logits=True, + return_embeddings=True, +) +resp = esmc.predict(req) + +print(f"\nVocab size : {resp.vocab_size}") +print(f"Hidden dim : {resp.hidden_dim}") +print(f"Seq lengths : {resp.seq_lens} (includes BOS/EOS special tokens)") + +assert resp.logits is not None +assert resp.embeddings is not None +print("\nPer-token logits shape per sequence:") +for i, (seq, li) in enumerate(zip(SEQUENCES, resp.logits)): + print(f" seq[{i}] len={len(seq):3d} logits=({len(li)}, {len(li[0])})") + +print("\nPer-token embeddings shape per sequence:") +for i, (seq, ei) in enumerate(zip(SEQUENCES, resp.embeddings)): + print(f" seq[{i}] len={len(seq):3d} embeddings=({len(ei)}, {len(ei[0])})") + +# Mean-pool per sequence for a fixed-size representation (call-site reduction — +# Sheaf returns the raw ragged tensor; pooling is a caller policy choice). +print("\nMean-pooled per-sequence embeddings (first 4 dims):") +for i, ei in enumerate(resp.embeddings): + n_tokens = len(ei) + pooled = [sum(tok[d] for tok in ei) / n_tokens for d in range(len(ei[0]))] + print(f" seq[{i}] {[f'{x:+.3f}' for x in pooled[:4]]}") + + +def cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(x * x for x in b)) + return dot / (na * nb) if na and nb else 0.0 + + +print("\nPairwise cosine similarity of mean-pooled embeddings:") +pooled = [] +for ei in resp.embeddings: + n_tokens = len(ei) + pooled.append([sum(tok[d] for tok in ei) / n_tokens for d in range(len(ei[0]))]) +for i in range(len(SEQUENCES)): + for j in range(i + 1, len(SEQUENCES)): + print(f" seq[{i}] vs seq[{j}] {cosine(pooled[i], pooled[j]):+.4f}") diff --git a/examples/quickstart_protein_modal.py b/examples/quickstart_protein_modal.py new file mode 100644 index 0000000..a2d8a13 --- /dev/null +++ b/examples/quickstart_protein_modal.py @@ -0,0 +1,139 @@ +"""ESMFold2 on Modal — GPU structure prediction without a local GPU. + +Drives ``sheaf.backends.esmfold2.ESMFold2Backend`` against an H100 on Modal, +exercising the full sheaf wrapper (load + predict + StructureResponse) rather +than the raw upstream API. Adapted from Modal's official ESMFold2 example +(modal-labs/modal-examples 06_gpu_and_ml/protein-folding/esmfold2.py), with +the same ``esm`` git revision pin for reproducibility. + +Prerequisites: + pip install modal + modal setup # authenticate once + +Run: + modal run examples/quickstart_protein_modal.py + modal run examples/quickstart_protein_modal.py --sequence MKTAYIAK... + +The model is cached in a persistent Modal volume after the first run; the +weight download (~12 GB ESMC backbone + ESMFold2 head) only happens once. +""" + +from __future__ import annotations + +from pathlib import Path + +import modal + +# Pinned to the same upstream esm commit Modal's official example uses, so this +# smoke is reproducible against a known-working API surface. +ESM_REVISION = "81b3646c9429ea8458918415ad6a46178cb59833" + +MINUTES = 60 + +app = modal.App(name="sheaf-esmfold2") + +# Persistent HF cache volume — shared with Modal's reference example name so a +# user who has already run that pays no extra cold-start download cost. +esmfold2_volume = modal.Volume.from_name("esmfold2-models", create_if_missing=True) +models_dir = Path("/models") + +# Image: Debian + git + uv-installed esm (git pin) + the sheaf-serve [protein] +# extra (transformers + torch). Mount the working src/ tree so we exercise the +# in-repo backend code, not a published PyPI release. +esmfold2_image = ( + modal.Image.debian_slim(python_version="3.13") + .apt_install("git") + .uv_pip_install( + f"esm @ git+https://github.com/Biohub/esm.git@{ESM_REVISION}", + ) + .pip_install_from_pyproject("pyproject.toml", optional_dependencies=["protein"]) + .add_local_dir("src", remote_path="/root/src", copy=True) + .env( + { + "HF_HOME": str(models_dir), + "HF_XET_HIGH_PERFORMANCE": "1", + "PYTHONPATH": "/root/src", + } + ) +) + + +@app.cls( + image=esmfold2_image, + volumes={models_dir: esmfold2_volume}, + gpu="H100", + timeout=20 * MINUTES, +) +class ESMFold2Inference: + @modal.enter() + def load_model(self) -> None: + from sheaf.backends.esmfold2 import ESMFold2Backend + + print("loading ESMFold2 onto GPU via sheaf.backends.esmfold2") + self.backend = ESMFold2Backend(model_name="biohub/ESMFold2", device="cuda") + self.backend.load() + print("ready") + + @modal.method() + def fold( + self, + sequence: str, + num_loops: int = 3, + num_sampling_steps: int = 50, + num_diffusion_samples: int = 1, + seed: int = 0, + output_format: str = "mmcif", + ) -> dict: + from sheaf.api.structure import ChainInput, StructureRequest + + req = StructureRequest( + model_name="esmfold2", + chains=[ChainInput(chain_id="A", sequence=sequence.strip())], + num_loops=num_loops, + num_sampling_steps=num_sampling_steps, + num_samples=num_diffusion_samples, + seed=seed, + output_format=output_format, # type: ignore[arg-type] + ) + print( + f"folding len={len(sequence)} loops={num_loops} " + f"steps={num_sampling_steps} samples={num_diffusion_samples} seed={seed}" + ) + resp = self.backend.predict(req) + return resp.model_dump(mode="json") + + +# Short test target from the PR's release checklist (~52 residues). +DEFAULT_SEQUENCE = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK" + + +@app.local_entrypoint() +def main( + sequence: str | None = None, + output_path: str | None = None, +) -> None: + seq = sequence or DEFAULT_SEQUENCE + + print(f"sheaf ESMFold2 smoke — sequence length {len(seq)}") + inference = ESMFold2Inference() + resp = inference.fold.remote( + sequence=seq, + num_loops=3, + num_sampling_steps=50, + num_diffusion_samples=1, + seed=0, + output_format="mmcif", + ) + + n_res = len(resp["plddt"]) + mean_plddt = sum(resp["plddt"]) / n_res if n_res else 0.0 + print( + f"\nresidues={n_res} mean_pLDDT={mean_plddt:.2f} " + f"pTM={resp['ptm']:.4f} iptm={resp['iptm']}" + ) + print(f"structure: {resp['structure_format']}, {len(resp['structure'])} chars") + + out = Path(output_path) if output_path else Path("/tmp/sheaf-esmfold2.cif") + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(resp["structure"]) + print(f"wrote {out}") diff --git a/examples/quickstart_structure.py b/examples/quickstart_structure.py new file mode 100644 index 0000000..a160ffb --- /dev/null +++ b/examples/quickstart_structure.py @@ -0,0 +1,116 @@ +"""ESMFold2 protein structure prediction quickstart. + +Requirements: + pip install "sheaf-serve[protein]" + pip install esm@git+https://github.com/Biohub/esm.git@c94ed8d + # Python 3.12+ required. Conflicts with the [molecular] extra (ESM-3). + +Usage: + python examples/quickstart_structure.py + +Demonstrates: + - Single-chain structure prediction → PDB output + - Multi-chain complex prediction → mmCIF + ipTM + - Inference-time scaling via num_loops / num_sampling_steps / num_samples + - Self-confidence-based ranking when num_samples > 1 + +ESMFold2 wraps a 6B-parameter ESMC backbone with a diffusion head; a GPU +with bf16 support is effectively required for real inference latencies. +""" + +from __future__ import annotations + +from pathlib import Path + +from sheaf.api.structure import ChainInput, StructureRequest +from sheaf.backends.esmfold2 import ESMFold2Backend + +# Single-chain target — N-terminus of bacteriophage T4 lysozyme. +SINGLE_CHAIN = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK" + +# Two-chain mini-complex for ipTM demonstration. (Toy sequences; in practice +# you would supply real interacting partners.) +CHAIN_A = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEK" +CHAIN_B = "ACDEFGHIKLMNPQRSTVWY" + +print("--- ESMFold2 (biohub/ESMFold2) ---") +print("Loading model (downloads several GB of weights on first run)...") + +fold = ESMFold2Backend(model_name="biohub/ESMFold2", device="cuda") +fold.load() +print("Model loaded.") + +# --------------------------------------------------------------------------- +# Single-chain fold → PDB +# --------------------------------------------------------------------------- + +print("\n--- Single-chain fold, num_loops=3, num_sampling_steps=50 ---") +req = StructureRequest( + model_name="esmfold2", + chains=[ChainInput(chain_id="A", sequence=SINGLE_CHAIN)], + num_loops=3, + num_sampling_steps=50, + output_format="pdb", +) +resp = fold.predict(req) + +n_res = len(resp.plddt) +mean_plddt = sum(resp.plddt) / n_res if n_res else 0.0 +print(f" residues : {n_res}") +print(f" mean pLDDT : {mean_plddt:.2f}") +print(f" pTM : {resp.ptm}") +print(f" format : {resp.structure_format}") +print(f" structure size : {len(resp.structure)} chars") + +out_pdb = Path("examples") / "esmfold2_single_chain.pdb" +out_pdb.write_text(resp.structure) +print(f" wrote : {out_pdb}") + +# --------------------------------------------------------------------------- +# Multi-chain complex → mmCIF with ipTM +# --------------------------------------------------------------------------- + +print("\n--- Two-chain complex, mmCIF output ---") +req_complex = StructureRequest( + model_name="esmfold2", + chains=[ + ChainInput(chain_id="A", sequence=CHAIN_A), + ChainInput(chain_id="B", sequence=CHAIN_B), + ], + num_loops=3, + num_sampling_steps=50, + output_format="mmcif", +) +resp_complex = fold.predict(req_complex) + +print(f" total residues : {len(resp_complex.plddt)}") +print(f" pTM : {resp_complex.ptm}") +print(f" ipTM : {resp_complex.iptm}") + +out_cif = Path("examples") / "esmfold2_complex.cif" +out_cif.write_text(resp_complex.structure) +print(f" wrote : {out_cif}") + +# --------------------------------------------------------------------------- +# Inference-time scaling — multiple samples, ranked by self-confidence +# --------------------------------------------------------------------------- + +print("\n--- Inference-time scaling: num_samples=4, ranked by self-confidence ---") +req_ranked = StructureRequest( + model_name="esmfold2", + chains=[ChainInput(chain_id="A", sequence=SINGLE_CHAIN)], + num_loops=5, + num_sampling_steps=100, + num_samples=4, + seed=42, + output_format="mmcif", +) +resp_ranked = fold.predict(req_ranked) + +assert resp_ranked.sample_scores is not None +print(f" sample scores : {[f'{s:.4f}' for s in resp_ranked.sample_scores]}") +print( + f" argmax sample : " + f"{resp_ranked.sample_scores.index(max(resp_ranked.sample_scores))}" +) +print(" → returned structure is the highest-confidence sample of the 4.") diff --git a/pyproject.toml b/pyproject.toml index 8fd862d..ed06b24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sheaf-serve" -version = "0.10.0" +version = "0.11.0" description = "Unified serving layer for non-text foundation models" readme = "README.md" requires-python = ">=3.11" @@ -79,6 +79,19 @@ molecular = [ # ESM-3 requires Python 3.12+; marker prevents resolution errors on 3.11 "esm>=3.0.0 ; python_full_version >= '3.12'", ] +protein = [ + # ESMC + ESMFold2 (Biohub 2026-05-27 release). The `esm` package is not + # yet on PyPI; install it manually from git. We pin to commit 81b3646c + # (the revision Modal's official ESMFold2 example uses), which is the + # one verified end-to-end via examples/quickstart_protein_modal.py on H100: + # pip install esm@git+https://github.com/Biohub/esm.git@81b3646c9429ea8458918415ad6a46178cb59833 + # This extra provides the remaining PyTorch / transformers stack. We + # avoid declaring `esm` itself here because it shares the import name + # with the pre-2026 `esm` 3.x package used by [molecular] (ESM-3); the + # two cannot coexist in the same environment. Python 3.12+ required. + "transformers>=4.40.0", + "torch>=2.0.0", +] weather = [ # graphcast 1.0.0 requires Python 3.11 only; pin <1.0.0 for 3.11 + 3.12 support "graphcast>=0.1.0,<1.0.0", diff --git a/src/sheaf/__init__.py b/src/sheaf/__init__.py index ad19d01..d7ca370 100644 --- a/src/sheaf/__init__.py +++ b/src/sheaf/__init__.py @@ -6,7 +6,7 @@ from sheaf.modal_server import ModalServer from sheaf.spec import ModelSpec -__version__ = "0.10.0" +__version__ = "0.11.0" __all__ = ["ModalServer", "ModelServer", "ModelSpec", "build_app"] diff --git a/src/sheaf/api/base.py b/src/sheaf/api/base.py index 0f90389..49e1c59 100644 --- a/src/sheaf/api/base.py +++ b/src/sheaf/api/base.py @@ -11,6 +11,8 @@ class ModelType(StrEnum): TIME_SERIES = "time_series" TABULAR = "tabular" MOLECULAR = "molecular" + PROTEIN_LANGUAGE = "protein_language" + STRUCTURE = "structure" GENOMIC = "genomic" MATERIALS = "materials" SMALL_MOLECULE = "small_molecule" diff --git a/src/sheaf/api/protein_language.py b/src/sheaf/api/protein_language.py new file mode 100644 index 0000000..86a17ab --- /dev/null +++ b/src/sheaf/api/protein_language.py @@ -0,0 +1,79 @@ +"""API contract for protein language models (ESMC, etc.). + +This is distinct from :mod:`sheaf.api.molecular`, which covers per-sequence +pooled protein embeddings (ESM-3). ESMC and its successors return *per-token* +logits and optionally per-token hidden-state embeddings — a ragged tensor, +not a fixed-size vector per sequence. +""" + +from __future__ import annotations + +from typing import Literal + +from sheaf.api.base import BaseRequest, BaseResponse, ModelType + + +class ProteinLanguageRequest(BaseRequest): + """Request contract for protein language model inference. + + A single request runs a batch of amino-acid sequences through one + forward pass. The model returns per-token logits over the amino-acid + vocabulary and (optionally) per-token hidden-state embeddings. + + Attributes: + sequences: List of amino-acid sequences. Standard single-letter + codes (ACDEFGHIKLMNPQRSTVWY plus ambiguity codes accepted by + the ESM tokenizer). + return_logits: If True (default), return per-token logits over + the amino-acid vocabulary. Shape per sequence: ``(L_i, V)``. + return_embeddings: If True, return per-token last-layer hidden + states. Shape per sequence: ``(L_i, H)``. Defaults to False + — these can be very large for long sequences against a + 6B-parameter model. + output_hidden_states: If True, return per-token hidden states + from *every* transformer layer (the upstream + ``output_hidden_states=True`` flag on the HF model). Defaults + to False. Implies ``return_embeddings=True``. + """ + + model_type: Literal[ModelType.PROTEIN_LANGUAGE] = ModelType.PROTEIN_LANGUAGE + + sequences: list[str] + return_logits: bool = True + return_embeddings: bool = False + output_hidden_states: bool = False + + +class ProteinLanguageResponse(BaseResponse): + """Response contract for protein language model inference. + + All ragged per-sequence tensors are returned as nested lists sliced + back to each input sequence's tokenized length (``seq_lens[i]``) — i.e. + padding is stripped before serialisation. + + Attributes: + logits: Per-sequence per-token logits over the amino-acid + vocabulary. ``logits[i]`` has shape ``(seq_lens[i], V)``. + ``None`` when ``return_logits=False``. + embeddings: Per-sequence per-token last-layer hidden states. + ``embeddings[i]`` has shape ``(seq_lens[i], H)``. ``None`` + when ``return_embeddings=False`` and + ``output_hidden_states=False``. + hidden_states: Per-layer per-sequence per-token hidden states. + ``hidden_states[layer][i]`` has shape ``(seq_lens[i], H)``. + ``None`` unless ``output_hidden_states=True``. + seq_lens: Tokenized length of each input sequence (includes any + BOS/EOS special tokens the tokenizer adds — same convention + as the upstream attention mask). + vocab_size: Size of the amino-acid vocabulary (``V`` above). + hidden_dim: Hidden-state dimensionality (``H`` above). + """ + + model_type: Literal[ModelType.PROTEIN_LANGUAGE] = ModelType.PROTEIN_LANGUAGE + + logits: list[list[list[float]]] | None = None + embeddings: list[list[list[float]]] | None = None + hidden_states: list[list[list[list[float]]]] | None = None + seq_lens: list[int] + vocab_size: int | None = None + hidden_dim: int | None = None diff --git a/src/sheaf/api/structure.py b/src/sheaf/api/structure.py new file mode 100644 index 0000000..65f6622 --- /dev/null +++ b/src/sheaf/api/structure.py @@ -0,0 +1,116 @@ +"""API contract for protein structure prediction (ESMFold2, etc.). + +Structure prediction is a new top-level model category in Sheaf — outputs +are 3D atomic coordinates (PDB / mmCIF strings) plus confidence side-channels +(pLDDT, pTM, ipTM, PAE), not embeddings or classifications. See +``docs/adr/0001-esmc-esmfold2-integration.md`` for the rationale. + +Inference-time scaling parameters (``num_loops``, ``num_sampling_steps``, +``num_samples``) are exposed as first-class request fields, not hidden +kwargs — this is the headline capability of ESMFold2's looped-transformer +architecture and a key Sheaf differentiator vs. plain HF serving. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, Field + +from sheaf.api.base import BaseRequest, BaseResponse, ModelType + + +class ChainInput(BaseModel): + """A single chain in a structure-prediction request. + + Multi-chain complexes are expressed by passing multiple ``ChainInput``s + in ``StructureRequest.chains``. Ligand / DNA / modified-residue inputs + are out of scope for v0.11 — the upstream ``StructurePredictionInput`` + supports them but they require additional Pydantic shape work; this + contract is protein-chains-only. + + Attributes: + chain_id: Single-letter (or short) chain identifier — used as the + chain ID in the output PDB / mmCIF. Must be unique within a + request. + sequence: Amino-acid sequence using standard single-letter codes. + """ + + chain_id: str + sequence: str + + +class StructureRequest(BaseRequest): + """Request contract for protein structure prediction. + + A single request predicts one structure (possibly multi-chain). The + inference-time scaling parameters control the depth of the + looped-transformer recurrence (``num_loops``), the number of diffusion + sampling steps (``num_sampling_steps``), and how many independent + samples to draw (``num_samples``); when ``num_samples > 1`` the + response includes per-sample ranking scores and returns the + highest-scoring structure as the primary output. + + Attributes: + chains: One or more protein chains to fold together. Single-chain + structure prediction uses a one-element list. + msa: Optional multiple-sequence alignment input, one list of + homolog sequences per chain (parallel to ``chains``). ESMFold2 + runs in single-sequence mode when ``msa`` is None (the headline + speedup path); supplying an MSA improves accuracy on + challenging targets at the cost of throughput. + num_loops: Depth of the looped-transformer recurrence. Default 3. + num_sampling_steps: Number of diffusion sampling steps per sample. + Default 50. + num_samples: Number of independent structures to sample. Default + 1; values > 1 enable self-confidence-based ranking. + seed: Random seed for the diffusion sampler. Default 0. + output_format: ``"mmcif"`` (default) or ``"pdb"``. + """ + + model_type: Literal[ModelType.STRUCTURE] = ModelType.STRUCTURE + + chains: list[ChainInput] = Field(min_length=1) + msa: list[list[str]] | None = None + num_loops: int = 3 + num_sampling_steps: int = 50 + num_samples: int = 1 + seed: int = 0 + output_format: Literal["mmcif", "pdb"] = "mmcif" + + +class StructureResponse(BaseResponse): + """Response contract for protein structure prediction. + + Attributes: + structure: The predicted structure encoded as PDB or mmCIF text + (per ``StructureRequest.output_format``). When + ``num_samples > 1`` this is the highest-confidence sample. + structure_format: ``"mmcif"`` or ``"pdb"``, matching the request. + plddt: Per-residue predicted-Local-Distance-Difference-Test score. + ESMFold2 reports pLDDT on **[0, 1]** (fractional), not the + conventional AlphaFold / ESMFold-v1 [0, 100] scale — multiply by + 100 if you need the conventional values. Length = total residues + across all chains. + ptm: Predicted-TM score for the structure as a whole. ``None`` if + the model did not produce one. + iptm: Interface-pTM (for multi-chain complexes). ``None`` for + single-chain inputs or when the model did not produce one. + pae: Predicted Aligned Error matrix, shape ``(N, N)`` where N is + total residues. ``None`` if the model did not produce one + (e.g. some fast variants skip PAE to save compute). + sample_scores: When ``num_samples > 1``, the self-confidence + score for each sample in the order produced. The returned + ``structure`` corresponds to ``argmax(sample_scores)``. + ``None`` when ``num_samples == 1``. + """ + + model_type: Literal[ModelType.STRUCTURE] = ModelType.STRUCTURE + + structure: str + structure_format: Literal["mmcif", "pdb"] + plddt: list[float] + ptm: float | None = None + iptm: float | None = None + pae: list[list[float]] | None = None + sample_scores: list[float] | None = None diff --git a/src/sheaf/api/union.py b/src/sheaf/api/union.py index c5d4753..644d7f3 100644 --- a/src/sheaf/api/union.py +++ b/src/sheaf/api/union.py @@ -35,9 +35,14 @@ from sheaf.api.optical_flow import OpticalFlowRequest, OpticalFlowResponse from sheaf.api.point_cloud import PointCloudRequest, PointCloudResponse from sheaf.api.pose import PoseRequest, PoseResponse +from sheaf.api.protein_language import ( + ProteinLanguageRequest, + ProteinLanguageResponse, +) from sheaf.api.satellite import SatelliteRequest, SatelliteResponse from sheaf.api.segmentation import SegmentationRequest, SegmentationResponse from sheaf.api.small_molecule import SmallMoleculeRequest, SmallMoleculeResponse +from sheaf.api.structure import StructureRequest, StructureResponse from sheaf.api.tabular import TabularRequest, TabularResponse from sheaf.api.time_series import TimeSeriesRequest, TimeSeriesResponse from sheaf.api.video import VideoRequest, VideoResponse @@ -65,7 +70,9 @@ | PoseRequest | OpticalFlowRequest | MultimodalGenerationRequest - | PointCloudRequest, + | PointCloudRequest + | ProteinLanguageRequest + | StructureRequest, Field(discriminator="model_type"), ] @@ -91,7 +98,9 @@ | PoseResponse | OpticalFlowResponse | MultimodalGenerationResponse - | PointCloudResponse, + | PointCloudResponse + | ProteinLanguageResponse + | StructureResponse, Field(discriminator="model_type"), ] diff --git a/src/sheaf/backends/_register.py b/src/sheaf/backends/_register.py index c46dfa4..0c6fa67 100644 --- a/src/sheaf/backends/_register.py +++ b/src/sheaf/backends/_register.py @@ -23,6 +23,8 @@ def register_builtin_backends() -> None: import sheaf.backends.detr # noqa: F401 import sheaf.backends.dinov2 # noqa: F401 import sheaf.backends.esm3 # noqa: F401 + import sheaf.backends.esmc # noqa: F401 + import sheaf.backends.esmfold2 # noqa: F401 import sheaf.backends.faster_whisper # noqa: F401 import sheaf.backends.flux # noqa: F401 import sheaf.backends.graphcast # noqa: F401 diff --git a/src/sheaf/backends/esmc.py b/src/sheaf/backends/esmc.py new file mode 100644 index 0000000..ac80e17 --- /dev/null +++ b/src/sheaf/backends/esmc.py @@ -0,0 +1,183 @@ +"""ESMC backend for protein language modeling via Biohub's esm package. + +Requires: pip install "sheaf-serve[protein]" (Python 3.12+) +Library: esm (https://github.com/Biohub/esm), released 2026-05-27 under MIT. + +Supported models (HuggingFace Hub — weight-downloadable): + "Biohub/ESMC-6B" (default) — the only weight-downloadable ESMC variant + as of 2026-05-27. + +Forge / Biohub-Platform-only variants (require an API token; not served by +this backend in v0.11 — see docs/adr/0001-esmc-esmfold2-integration.md): + "esmc-300m-2024-12", "esmc-600m-2024-12" + +ESMC follows the standard ``transformers`` masked-LM interface: tokenize a +batch of amino-acid sequences with ``AutoTokenizer``, run a single forward +pass through ``AutoModelForMaskedLM``, and return per-token logits and +optionally per-token hidden-state embeddings. Sequences are length-padded +to the longest in the batch and the attention mask is used to slice +results back to ragged per-sequence lengths before serialisation. + +``AutoTokenizer`` and ``AutoModelForMaskedLM`` are stored as instance +attributes at ``load()`` time so the heavy dependency stays lazy and tests +can inject mocks without ``transformers`` or ``torch`` installed. +""" + +from __future__ import annotations + +from typing import Any + +from sheaf.api.base import BaseRequest, BaseResponse, ModelType +from sheaf.api.protein_language import ( + ProteinLanguageRequest, + ProteinLanguageResponse, +) +from sheaf.backends.base import ModelBackend +from sheaf.registry import register_backend + +_DEFAULT_MODEL = "Biohub/ESMC-6B" +_FORGE_MODELS = frozenset({"esmc-300m-2024-12", "esmc-600m-2024-12"}) + + +@register_backend("esmc") +class ESMCBackend(ModelBackend): + """ModelBackend for ESMC protein language modeling. + + Args: + model_name: HuggingFace model ID. Default ``"Biohub/ESMC-6B"`` — + the only weight-downloadable ESMC variant as of 2026-05-27. + device: ``"cpu"``, ``"cuda"``, ``"cuda:N"``, or ``"mps"``. The + 6B model effectively requires a GPU with bf16 support; CPU is + available for testing on smaller dummy weights. + device_map: When set (e.g. ``"auto"``), passed through to + ``from_pretrained`` for sharding. Mutually exclusive with + ``.to(device)`` — when ``device_map`` is set, ``device`` is + ignored. + """ + + def __init__( + self, + model_name: str = _DEFAULT_MODEL, + device: str = "cpu", + device_map: str | None = None, + ) -> None: + self._model_name = model_name + self._device = device + self._device_map = device_map + self._model: Any = None + self._tokenizer: Any = None + + @property + def model_type(self) -> str: + return ModelType.PROTEIN_LANGUAGE + + def load(self) -> None: + if self._model_name in _FORGE_MODELS: + raise NotImplementedError( + f"ESMC variant {self._model_name!r} is API-only via the " + "Biohub Platform (Forge) and is not served by this backend " + "in v0.11. See docs/adr/0001-esmc-esmfold2-integration.md " + "for the rationale and roadmap." + ) + try: + from transformers import ( # ty: ignore[unresolved-import] + AutoModelForMaskedLM, + AutoTokenizer, + ) + except ImportError as e: + raise ImportError( + "transformers is required for the ESMC backend. " + "Install it with: pip install 'sheaf-serve[protein]' " + "(Python 3.12+ required)" + ) from e + self._tokenizer = AutoTokenizer.from_pretrained(self._model_name) + kwargs: dict[str, Any] = {} + if self._device_map is not None: + kwargs["device_map"] = self._device_map + self._model = AutoModelForMaskedLM.from_pretrained(self._model_name, **kwargs) + if self._device_map is None: + self._model = self._model.to(self._device) + self._model.eval() + + def predict(self, request: BaseRequest) -> BaseResponse: + if not isinstance(request, ProteinLanguageRequest): + raise TypeError(f"Expected ProteinLanguageRequest, got {type(request)}") + return self._run(request) + + def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]: + return [self.predict(r) for r in requests] + + def _run( # noqa: C901 + self, request: ProteinLanguageRequest + ) -> ProteinLanguageResponse: + import torch # ty: ignore[unresolved-import] + + if self._model is None or self._tokenizer is None: + raise RuntimeError("Backend not loaded. Call load() first.") + + inputs = self._tokenizer( + request.sequences, + return_tensors="pt", + padding=True, + ) + # The model device may differ from self._device when device_map="auto". + target_device = ( + self._model.device if self._device_map is not None else self._device + ) + inputs = {k: v.to(target_device) for k, v in inputs.items()} + attention_mask = inputs["attention_mask"] # (N, L) + seq_lens: list[int] = attention_mask.sum(dim=1).cpu().int().tolist() + + # MaskedLMOutput has .logits + .hidden_states (when requested) but no + # .last_hidden_state — so requesting embeddings forces the hidden-states + # flag on the underlying model call. + need_hidden = request.return_embeddings or request.output_hidden_states + with torch.inference_mode(): + output = self._model( + **inputs, + output_hidden_states=need_hidden, + ) + + logits_out: list[list[list[float]]] | None = None + if request.return_logits: + # logits: (N, L, V) — slice per-sequence by seq_lens[i] + logits = output.logits.cpu().float() + logits_out = [ + logits[i, : seq_lens[i], :].tolist() for i in range(len(seq_lens)) + ] + + embeddings_out: list[list[list[float]]] | None = None + if need_hidden: + last_hidden = output.hidden_states[-1].cpu().float() + embeddings_out = [ + last_hidden[i, : seq_lens[i], :].tolist() for i in range(len(seq_lens)) + ] + + hidden_states_out: list[list[list[list[float]]]] | None = None + if request.output_hidden_states: + # hidden_states: tuple of (N, L, H), one per layer (including embed) + hidden_states_out = [] + for layer_hidden in output.hidden_states: + layer_hidden = layer_hidden.cpu().float() + hidden_states_out.append( + [ + layer_hidden[i, : seq_lens[i], :].tolist() + for i in range(len(seq_lens)) + ] + ) + + vocab_size = int(output.logits.shape[-1]) if request.return_logits else None + hidden_dim = None + if embeddings_out is not None and embeddings_out and embeddings_out[0]: + hidden_dim = len(embeddings_out[0][0]) + + return ProteinLanguageResponse( + request_id=request.request_id, + model_name=request.model_name, + logits=logits_out, + embeddings=embeddings_out, + hidden_states=hidden_states_out, + seq_lens=seq_lens, + vocab_size=vocab_size, + hidden_dim=hidden_dim, + ) diff --git a/src/sheaf/backends/esmfold2.py b/src/sheaf/backends/esmfold2.py new file mode 100644 index 0000000..f014b3a --- /dev/null +++ b/src/sheaf/backends/esmfold2.py @@ -0,0 +1,188 @@ +"""ESMFold2 backend for protein structure prediction via Biohub's esm package. + +Requires: pip install "sheaf-serve[protein]" (Python 3.12+) +Library: esm (https://github.com/Biohub/esm), released 2026-05-27 under MIT. + +Supported models (HuggingFace Hub — weight-downloadable): + "biohub/ESMFold2" (default) — note lowercase ``biohub/``, mirroring the + upstream README verbatim. + +Forge / Biohub-Platform-only variants (require an API token; not served by +this backend in v0.11 — see docs/adr/0001-esmc-esmfold2-integration.md): + "esmfold2-fast-2026-05" + +ESMFold2 wraps a 6B-parameter ESMC language model with a diffusion-based +structure-prediction head. The headline capability is *inference-time +scaling*: the looped-transformer recurrence (``num_loops``), the diffusion +sampler (``num_sampling_steps``), and the number of independent samples +(``num_diffusion_samples``) trade compute for accuracy at predict time. +We surface all three plus the random ``seed`` as first-class fields on +:class:`~sheaf.api.structure.StructureRequest`. + +Structure prediction is **single-sample-per-call** at the upstream API +level; ``batch_predict`` runs requests sequentially. Per-request compute +varies hugely with sequence length × num_loops × num_samples; do not +expect uniform latency. +""" + +from __future__ import annotations + +from typing import Any + +from sheaf.api.base import BaseRequest, BaseResponse, ModelType +from sheaf.api.structure import StructureRequest, StructureResponse +from sheaf.backends.base import ModelBackend +from sheaf.registry import register_backend + +_DEFAULT_MODEL = "biohub/ESMFold2" +_FORGE_MODELS = frozenset({"esmfold2-fast-2026-05"}) + + +@register_backend("esmfold2") +class ESMFold2Backend(ModelBackend): + """ModelBackend for ESMFold2 protein structure prediction. + + Args: + model_name: HuggingFace model ID. Default ``"biohub/ESMFold2"``. + device: ``"cpu"``, ``"cuda"``, ``"cuda:N"``. The 6B-backed + structure head effectively requires a GPU with bf16 support + for practical inference latencies. + """ + + def __init__( + self, + model_name: str = _DEFAULT_MODEL, + device: str = "cuda", + ) -> None: + self._model_name = model_name + self._device = device + self._model: Any = None + # Stored at load() for test injectability — same pattern as + # ESM3Backend._ESMProtein, OpenCLIPBackend._Image, etc. + self._ProteinInput: Any = None + self._StructurePredictionInput: Any = None + self._InputBuilder: Any = None + + @property + def model_type(self) -> str: + return ModelType.STRUCTURE + + def load(self) -> None: + if self._model_name in _FORGE_MODELS: + raise NotImplementedError( + f"ESMFold2 variant {self._model_name!r} is API-only via " + "the Biohub Platform (Forge) and is not served by this " + "backend in v0.11. See " + "docs/adr/0001-esmc-esmfold2-integration.md for the " + "rationale and roadmap." + ) + try: + from esm.models.esmfold2 import ( # ty: ignore[unresolved-import] + ESMFold2InputBuilder, + ProteinInput, + StructurePredictionInput, + ) + from transformers.models.esmfold2.modeling_esmfold2 import ( # ty: ignore[unresolved-import] + ESMFold2Model, + ) + except ImportError as e: + raise ImportError( + "esm and transformers are required for the ESMFold2 backend. " + "Install them with: pip install 'sheaf-serve[protein]' " + "(Python 3.12+ required)" + ) from e + + self._model = ESMFold2Model.from_pretrained(self._model_name) + # ``.cuda()`` and ``.eval()`` mirror the upstream README example. + # We call ``.to(device)`` instead so CPU testing works on a stub. + self._model = self._model.to(self._device) + self._model.eval() + self._ProteinInput = ProteinInput + self._StructurePredictionInput = StructurePredictionInput + self._InputBuilder = ESMFold2InputBuilder + + def predict(self, request: BaseRequest) -> BaseResponse: + if not isinstance(request, StructureRequest): + raise TypeError(f"Expected StructureRequest, got {type(request)}") + return self._run(request) + + def batch_predict(self, requests: list[BaseRequest]) -> list[BaseResponse]: + # Structure prediction is single-sample-per-call upstream; sequential. + return [self.predict(r) for r in requests] + + def _run(self, request: StructureRequest) -> StructureResponse: + if self._model is None: + raise RuntimeError("Backend not loaded. Call load() first.") + + protein_inputs = [ + self._ProteinInput(id=c.chain_id, sequence=c.sequence) + for c in request.chains + ] + spi = self._StructurePredictionInput(sequences=protein_inputs) + + result = self._InputBuilder().fold( + self._model, + spi, + num_loops=request.num_loops, + num_sampling_steps=request.num_sampling_steps, + num_diffusion_samples=request.num_samples, + seed=request.seed, + ) + + if request.output_format == "pdb": + structure_str = result.complex.to_pdb() + else: + structure_str = result.complex.to_mmcif() + + # pLDDT — coerce to a flat list[float]. Upstream returns a tensor; + # we go via .cpu().float().tolist() if it's a tensor, or trust the + # value as-is if it's already a list (test stubs). + plddt = _to_float_list(result.plddt) + + ptm = _maybe_float(getattr(result, "ptm", None)) + iptm = _maybe_float(getattr(result, "iptm", None)) + pae = _maybe_2d_list(getattr(result, "pae", None)) + sample_scores = ( + _to_float_list(getattr(result, "sample_scores", None)) + if request.num_samples > 1 + and getattr(result, "sample_scores", None) is not None + else None + ) + + return StructureResponse( + request_id=request.request_id, + model_name=request.model_name, + structure=structure_str, + structure_format=request.output_format, + plddt=plddt, + ptm=ptm, + iptm=iptm, + pae=pae, + sample_scores=sample_scores, + ) + + +def _to_float_list(value: Any) -> list[float]: + if value is None: + return [] + if hasattr(value, "cpu"): + value = value.cpu().float().tolist() + if isinstance(value, list): + return [float(x) for x in value] + return [float(value)] + + +def _maybe_float(value: Any) -> float | None: + if value is None: + return None + if hasattr(value, "item"): + return float(value.item()) + return float(value) + + +def _maybe_2d_list(value: Any) -> list[list[float]] | None: + if value is None: + return None + if hasattr(value, "cpu"): + value = value.cpu().float().tolist() + return [[float(x) for x in row] for row in value] diff --git a/src/sheaf/modal_server.py b/src/sheaf/modal_server.py index adcfb1c..5a05877 100644 --- a/src/sheaf/modal_server.py +++ b/src/sheaf/modal_server.py @@ -65,9 +65,11 @@ from sheaf.api.optical_flow import OpticalFlowRequest from sheaf.api.point_cloud import PointCloudRequest from sheaf.api.pose import PoseRequest +from sheaf.api.protein_language import ProteinLanguageRequest from sheaf.api.satellite import SatelliteRequest from sheaf.api.segmentation import SegmentationRequest from sheaf.api.small_molecule import SmallMoleculeRequest +from sheaf.api.structure import StructureRequest from sheaf.api.tabular import TabularRequest from sheaf.api.time_series import TimeSeriesRequest from sheaf.api.video import VideoRequest @@ -101,7 +103,9 @@ | PoseRequest | OpticalFlowRequest | MultimodalGenerationRequest - | PointCloudRequest, + | PointCloudRequest + | ProteinLanguageRequest + | StructureRequest, Field(discriminator="model_type"), ] @@ -141,6 +145,8 @@ def _build_asgi_app(specs: list[ModelSpec], *, load_backends: bool = True) -> An import sheaf.backends.detr # noqa: F401 import sheaf.backends.dinov2 # noqa: F401 import sheaf.backends.esm3 # noqa: F401 + import sheaf.backends.esmc # noqa: F401 + import sheaf.backends.esmfold2 # noqa: F401 import sheaf.backends.faster_whisper # noqa: F401 import sheaf.backends.flux # noqa: F401 import sheaf.backends.graphcast # noqa: F401 diff --git a/tests/test_esmc_backend.py b/tests/test_esmc_backend.py new file mode 100644 index 0000000..57e2891 --- /dev/null +++ b/tests/test_esmc_backend.py @@ -0,0 +1,458 @@ +"""Tests for ESMCBackend — fully mocked, no transformers or torch required. + +Covers: + - load() raises ImportError when transformers is absent + - load() raises NotImplementedError for Forge-only model IDs + - load() passes model_name through to from_pretrained + - load() moves model to specified device when device_map is None + - load() skips .to(device) when device_map is set + - predict() rejects non-ProteinLanguageRequest inputs + - predict() returns logits when return_logits=True + - predict() omits logits when return_logits=False + - predict() returns per-token embeddings when return_embeddings=True + - predict() returns all-layer hidden_states when output_hidden_states=True + - predict() slices outputs back to ragged per-sequence lengths via attention mask + - predict() reports vocab_size and hidden_dim + - batch_predict() runs each request independently +""" + +from __future__ import annotations + +import builtins +import sys +from types import ModuleType +from typing import Any +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from sheaf.api.protein_language import ( + ProteinLanguageRequest, + ProteinLanguageResponse, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_HIDDEN_DIM = 5120 # ESMC 6B +_VOCAB = 33 # standard ESM amino-acid vocab (residues + special tokens) +_N_SEQ = 2 +_SEQ_LEN = 8 # padded length; per-sequence real lengths = [6, 5] +_REAL_LENS = [6, 5] + + +# --------------------------------------------------------------------------- +# FakeTensor — numpy-backed; supports every op used in ESMCBackend._run() +# --------------------------------------------------------------------------- + + +class FakeTensor: + def __init__(self, data: list | np.ndarray, dtype: Any = np.float32) -> None: + self._data = np.asarray(data, dtype=dtype) + + @property + def shape(self) -> tuple[int, ...]: + return self._data.shape + + def __getitem__(self, key: object) -> FakeTensor: + return FakeTensor(self._data[key]) # type: ignore[index] + + def sum(self, dim: int) -> FakeTensor: + return FakeTensor(self._data.sum(axis=dim)) + + def cpu(self) -> FakeTensor: + return self + + def float(self) -> FakeTensor: + return FakeTensor(self._data.astype(np.float32)) + + def int(self) -> FakeTensor: + return FakeTensor(self._data.astype(np.int64), dtype=np.int64) + + def tolist(self) -> list: + return self._data.tolist() + + def to(self, _device: str) -> FakeTensor: + return self + + +# --------------------------------------------------------------------------- +# Fake torch module +# --------------------------------------------------------------------------- + + +class _InferenceMode: + def __enter__(self) -> _InferenceMode: + return self + + def __exit__(self, *_: object) -> None: + pass + + +def _make_torch_mod() -> ModuleType: + mod = ModuleType("torch") + mod.inference_mode = _InferenceMode # type: ignore[attr-defined] + return mod + + +_torch_mod = _make_torch_mod() + + +# --------------------------------------------------------------------------- +# Fake tokenizer output — dict-like with .items() and key access +# --------------------------------------------------------------------------- + + +def _make_attention_mask(real_lens: list[int], padded_len: int) -> np.ndarray: + mask = np.zeros((len(real_lens), padded_len), dtype=np.int64) + for i, n in enumerate(real_lens): + mask[i, :n] = 1 + return mask + + +class _FakeTokenizerOutput: + def __init__( + self, + real_lens: list[int] = _REAL_LENS, + padded_len: int = _SEQ_LEN, + ) -> None: + self._d = { + "input_ids": FakeTensor( + np.zeros((len(real_lens), padded_len), dtype=np.int64), + dtype=np.int64, + ), + "attention_mask": FakeTensor( + _make_attention_mask(real_lens, padded_len), + dtype=np.int64, + ), + } + + def items(self): # type: ignore[no-untyped-def] + return self._d.items() + + def __getitem__(self, key: str) -> FakeTensor: + return self._d[key] + + +# --------------------------------------------------------------------------- +# Fake model output factory +# --------------------------------------------------------------------------- + + +def _make_model_output( + n: int = _N_SEQ, + seq_len: int = _SEQ_LEN, + hidden_dim: int = _HIDDEN_DIM, + vocab: int = _VOCAB, + with_hidden_states: bool = False, + n_layers: int = 4, +) -> MagicMock: + # Mirrors transformers' MaskedLMOutput: .logits always; .hidden_states only + # when the model was called with output_hidden_states=True (or None / absent + # otherwise — the backend reads it lazily and only when needed). + logits = FakeTensor(np.full((n, seq_len, vocab), 0.5, dtype=np.float32)) + out = MagicMock(spec=["logits", "hidden_states"]) + out.logits = logits + if with_hidden_states: + out.hidden_states = tuple( + FakeTensor(np.full((n, seq_len, hidden_dim), float(i), dtype=np.float32)) + for i in range(n_layers) + ) + else: + # Single-layer tuple — matches what the model would return when called + # with output_hidden_states=True for the embeddings-only path. The + # backend always sets the flag when it needs embeddings. + out.hidden_states = ( + FakeTensor(np.full((n, seq_len, hidden_dim), 1.0, dtype=np.float32)), + ) + return out + + +# --------------------------------------------------------------------------- +# Fake transformers module factory +# --------------------------------------------------------------------------- + + +def _make_transformers_mod( + with_hidden_states: bool = False, +) -> tuple[ModuleType, MagicMock, MagicMock]: + model_output = _make_model_output(with_hidden_states=with_hidden_states) + model = MagicMock() + model.return_value = model_output + model.to.return_value = model + model.eval.return_value = None + model.device = "cpu" + + tokenizer = MagicMock() + tokenizer.return_value = _FakeTokenizerOutput() + + mod = ModuleType("transformers") + mod.AutoModelForMaskedLM = MagicMock() # type: ignore[attr-defined] + mod.AutoModelForMaskedLM.from_pretrained.return_value = model # type: ignore[attr-defined] + mod.AutoTokenizer = MagicMock() # type: ignore[attr-defined] + mod.AutoTokenizer.from_pretrained.return_value = tokenizer # type: ignore[attr-defined] + return mod, model, tokenizer + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_transformers() -> ModuleType: + return _make_transformers_mod()[0] + + +@pytest.fixture +def loaded_backend(mock_transformers: ModuleType): # type: ignore[no-untyped-def] + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(model_name="Biohub/ESMC-6B", device="cpu") + with patch.dict( + sys.modules, {"transformers": mock_transformers, "torch": _torch_mod} + ): + backend.load() + return backend + + +def _wire_output(backend, *, with_hidden_states: bool = False) -> MagicMock: + """Rebind the backend's stored model/tokenizer to fresh mocks.""" + mod, model, tokenizer = _make_transformers_mod( + with_hidden_states=with_hidden_states + ) + backend._model = model + backend._tokenizer = tokenizer + return model + + +def _make_request( + sequences: list[str] | None = None, + return_logits: bool = True, + return_embeddings: bool = False, + output_hidden_states: bool = False, +) -> ProteinLanguageRequest: + if sequences is None: + sequences = ["MKTII", "ACDE"] # 5 and 4 residues + return ProteinLanguageRequest( + model_name="esmc", + sequences=sequences, + return_logits=return_logits, + return_embeddings=return_embeddings, + output_hidden_states=output_hidden_states, + ) + + +# --------------------------------------------------------------------------- +# load() — error cases +# --------------------------------------------------------------------------- + + +def test_load_raises_on_missing_transformers() -> None: + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend() + mods_without = {k: v for k, v in sys.modules.items() if "transformers" not in k} + _real_import = builtins.__import__ + + def _raise(name: str, *a: object, **kw: object) -> object: + if name == "transformers": + raise ModuleNotFoundError("No module named 'transformers'") + return _real_import(name, *a, **kw) + + with ( + patch.dict(sys.modules, mods_without, clear=True), + patch("builtins.__import__", side_effect=_raise), + pytest.raises(ImportError, match="sheaf-serve\\[protein\\]"), + ): + backend.load() + + +def test_load_rejects_forge_only_model() -> None: + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(model_name="esmc-300m-2024-12") + with pytest.raises(NotImplementedError, match="Forge"): + backend.load() + + +def test_load_rejects_other_forge_only_model() -> None: + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(model_name="esmc-600m-2024-12") + with pytest.raises(NotImplementedError, match="Forge"): + backend.load() + + +# --------------------------------------------------------------------------- +# load() — happy path +# --------------------------------------------------------------------------- + + +def test_load_passes_model_name(mock_transformers: ModuleType) -> None: + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(model_name="Biohub/ESMC-6B") + with patch.dict( + sys.modules, {"transformers": mock_transformers, "torch": _torch_mod} + ): + backend.load() + + mock_transformers.AutoTokenizer.from_pretrained.assert_called_once_with( # type: ignore[attr-defined] + "Biohub/ESMC-6B" + ) + mock_transformers.AutoModelForMaskedLM.from_pretrained.assert_called_once_with( # type: ignore[attr-defined] + "Biohub/ESMC-6B" + ) + + +def test_load_moves_model_to_device(mock_transformers: ModuleType) -> None: + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(device="cuda") + with patch.dict( + sys.modules, {"transformers": mock_transformers, "torch": _torch_mod} + ): + backend.load() + + mock_transformers.AutoModelForMaskedLM.from_pretrained.return_value.to.assert_called_once_with( # type: ignore[attr-defined] + "cuda" + ) + + +def test_load_device_map_skips_to_call(mock_transformers: ModuleType) -> None: + """When device_map is set, .to(device) must not be called.""" + from sheaf.backends.esmc import ESMCBackend + + backend = ESMCBackend(device="cuda", device_map="auto") + with patch.dict( + sys.modules, {"transformers": mock_transformers, "torch": _torch_mod} + ): + backend.load() + + model = mock_transformers.AutoModelForMaskedLM.from_pretrained.return_value # type: ignore[attr-defined] + model.to.assert_not_called() + mock_transformers.AutoModelForMaskedLM.from_pretrained.assert_called_once_with( # type: ignore[attr-defined] + "Biohub/ESMC-6B", device_map="auto" + ) + + +# --------------------------------------------------------------------------- +# predict() — input validation +# --------------------------------------------------------------------------- + + +def test_predict_rejects_wrong_type(loaded_backend) -> None: # type: ignore[no-untyped-def] + from sheaf.api.molecular import MolecularRequest + + req = MolecularRequest(model_name="x", sequences=["MKT"]) + with pytest.raises(TypeError, match="ProteinLanguageRequest"): + loaded_backend.predict(req) + + +# --------------------------------------------------------------------------- +# predict() — response structure and slicing +# --------------------------------------------------------------------------- + + +def test_predict_returns_protein_language_response(loaded_backend) -> None: # type: ignore[no-untyped-def] + _wire_output(loaded_backend) + with patch.dict(sys.modules, {"torch": _torch_mod}): + resp = loaded_backend.predict(_make_request()) + + assert isinstance(resp, ProteinLanguageResponse) + assert resp.seq_lens == _REAL_LENS + assert resp.vocab_size == _VOCAB + assert resp.logits is not None + assert resp.embeddings is None + assert resp.hidden_states is None + + +def test_predict_slices_logits_to_real_lens(loaded_backend) -> None: # type: ignore[no-untyped-def] + """logits[i] must have length == seq_lens[i], not the padded length.""" + _wire_output(loaded_backend) + with patch.dict(sys.modules, {"torch": _torch_mod}): + resp = loaded_backend.predict(_make_request()) + + assert resp.logits is not None + assert [len(li) for li in resp.logits] == _REAL_LENS + assert all(len(tok) == _VOCAB for li in resp.logits for tok in li) + + +def test_predict_skips_logits_when_disabled(loaded_backend) -> None: # type: ignore[no-untyped-def] + _wire_output(loaded_backend) + with patch.dict(sys.modules, {"torch": _torch_mod}): + resp = loaded_backend.predict( + _make_request(return_logits=False, return_embeddings=True) + ) + + assert resp.logits is None + assert resp.vocab_size is None + assert resp.embeddings is not None + assert [len(ei) for ei in resp.embeddings] == _REAL_LENS + assert resp.hidden_dim == _HIDDEN_DIM + + +def test_predict_returns_per_token_embeddings(loaded_backend) -> None: # type: ignore[no-untyped-def] + _wire_output(loaded_backend) + with patch.dict(sys.modules, {"torch": _torch_mod}): + resp = loaded_backend.predict(_make_request(return_embeddings=True)) + + assert resp.embeddings is not None + assert [len(ei) for ei in resp.embeddings] == _REAL_LENS + assert resp.hidden_dim == _HIDDEN_DIM + + +def test_predict_output_hidden_states_returns_all_layers(loaded_backend) -> None: # type: ignore[no-untyped-def] + _wire_output(loaded_backend, with_hidden_states=True) + with patch.dict(sys.modules, {"torch": _torch_mod}): + resp = loaded_backend.predict( + _make_request(output_hidden_states=True, return_logits=False) + ) + + assert resp.hidden_states is not None + assert len(resp.hidden_states) == 4 # n_layers from _make_model_output + # Per-layer: list of per-sequence lists sliced to real lengths. + for layer in resp.hidden_states: + assert [len(ei) for ei in layer] == _REAL_LENS + # And embeddings (last layer) is populated. + assert resp.embeddings is not None + assert [len(ei) for ei in resp.embeddings] == _REAL_LENS + + +# --------------------------------------------------------------------------- +# predict() — tokenizer call shape +# --------------------------------------------------------------------------- + + +def test_predict_tokenizes_full_batch_with_padding(loaded_backend) -> None: # type: ignore[no-untyped-def] + model = _wire_output(loaded_backend) + with patch.dict(sys.modules, {"torch": _torch_mod}): + loaded_backend.predict(_make_request(sequences=["MKTII", "ACDE"])) + + # Tokenizer called with the full sequences list, padding=True + tokenizer = loaded_backend._tokenizer + args, kwargs = tokenizer.call_args + assert args[0] == ["MKTII", "ACDE"] + assert kwargs["padding"] is True + assert kwargs["return_tensors"] == "pt" + # One forward pass per batch + assert model.call_count == 1 + + +# --------------------------------------------------------------------------- +# batch_predict() +# --------------------------------------------------------------------------- + + +def test_batch_predict_runs_each_request(loaded_backend) -> None: # type: ignore[no-untyped-def] + model = _wire_output(loaded_backend) + reqs = [_make_request(), _make_request(sequences=["AKDQE", "MKTL"])] + with patch.dict(sys.modules, {"torch": _torch_mod}): + responses = loaded_backend.batch_predict(reqs) + + assert len(responses) == 2 + assert all(isinstance(r, ProteinLanguageResponse) for r in responses) + assert model.call_count == 2 diff --git a/tests/test_esmfold2_backend.py b/tests/test_esmfold2_backend.py new file mode 100644 index 0000000..4870386 --- /dev/null +++ b/tests/test_esmfold2_backend.py @@ -0,0 +1,408 @@ +"""Tests for ESMFold2Backend — fully mocked, no esm or transformers required. + +Covers: + - load() raises ImportError when esm/transformers absent + - load() rejects Forge-only model IDs with NotImplementedError + - load() passes model_name through and moves model to device + - load() stores ProteinInput, StructurePredictionInput, InputBuilder for injection + - predict() rejects non-StructureRequest inputs + - predict() builds a ProteinInput per chain and passes them into spi.sequences + - predict() threads num_loops / num_sampling_steps / num_samples / seed + through to ESMFold2InputBuilder().fold() + - predict() returns StructureResponse with structure string in the requested format + - predict() pdb vs mmcif output format selects the correct serialiser + - predict() copies plddt, ptm, iptm and pae onto the response + - predict() returns sample_scores only when num_samples > 1 +""" + +from __future__ import annotations + +import builtins +import sys +from types import ModuleType +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from sheaf.api.structure import ChainInput, StructureRequest, StructureResponse + +# --------------------------------------------------------------------------- +# Fake esm.models.esmfold2 module +# --------------------------------------------------------------------------- + + +class _FakeProteinInput: + def __init__(self, id: str, sequence: str) -> None: + self.id = id + self.sequence = sequence + + +class _FakeStructurePredictionInput: + def __init__(self, sequences: list[_FakeProteinInput]) -> None: + self.sequences = sequences + + +class _FakeComplex: + def __init__(self, mmcif: str = "CIF-BODY", pdb: str = "PDB-BODY") -> None: + self._mmcif = mmcif + self._pdb = pdb + + def to_mmcif(self) -> str: + return self._mmcif + + def to_pdb(self) -> str: + return self._pdb + + +class _FakeFloat: + """Minimal stand-in for torch scalars with .item().""" + + def __init__(self, value: float) -> None: + self._v = value + + def item(self) -> float: + return self._v + + +class _FakeListTensor: + """Mimics a torch tensor that supports .cpu().float().tolist().""" + + def __init__(self, data: list) -> None: + self._data = data + + def cpu(self) -> _FakeListTensor: + return self + + def float(self) -> _FakeListTensor: + return self + + def tolist(self) -> list: + return self._data + + +def _make_result( + plddt: list[float] | None = None, + ptm: float | None = 0.91, + iptm: float | None = 0.42, + pae: list[list[float]] | None = None, + sample_scores: list[float] | None = None, + mmcif: str = "CIF-BODY", + pdb: str = "PDB-BODY", +) -> Any: + if plddt is None: + plddt = [85.0, 90.0, 75.0] + res = MagicMock() + res.plddt = _FakeListTensor(plddt) + res.ptm = _FakeFloat(ptm) if ptm is not None else None + res.iptm = _FakeFloat(iptm) if iptm is not None else None + res.pae = _FakeListTensor(pae) if pae is not None else None + res.sample_scores = ( + _FakeListTensor(sample_scores) if sample_scores is not None else None + ) + res.complex = _FakeComplex(mmcif=mmcif, pdb=pdb) + return res + + +class _FakeInputBuilder: + """Captures kwargs from .fold() so tests can assert on them.""" + + last_call_kwargs: dict[str, Any] | None = None + result: Any = None + + def fold(self, model: Any, spi: Any, **kwargs: Any) -> Any: # noqa: ARG002 + _FakeInputBuilder.last_call_kwargs = kwargs + return _FakeInputBuilder.result + + +def _make_esm_mod() -> ModuleType: + esm_mod = ModuleType("esm") + models_mod = ModuleType("esm.models") + esmfold2_mod = ModuleType("esm.models.esmfold2") + + esmfold2_mod.ESMFold2InputBuilder = _FakeInputBuilder # type: ignore[attr-defined] + esmfold2_mod.ProteinInput = _FakeProteinInput # type: ignore[attr-defined] + esmfold2_mod.StructurePredictionInput = _FakeStructurePredictionInput # type: ignore[attr-defined] + + esm_mod.models = models_mod # type: ignore[attr-defined] + models_mod.esmfold2 = esmfold2_mod # type: ignore[attr-defined] + + return esm_mod + + +def _esm_sys_modules(esm_mod: ModuleType) -> dict[str, ModuleType]: + return { + "esm": esm_mod, + "esm.models": esm_mod.models, # type: ignore[attr-defined] + "esm.models.esmfold2": esm_mod.models.esmfold2, # type: ignore[attr-defined] + } + + +# --------------------------------------------------------------------------- +# Fake transformers module +# --------------------------------------------------------------------------- + + +def _make_transformers_mod() -> tuple[ModuleType, MagicMock]: + model = MagicMock() + model.to.return_value = model + model.eval.return_value = None + + transformers_mod = ModuleType("transformers") + models_mod = ModuleType("transformers.models") + esmfold2_mod = ModuleType("transformers.models.esmfold2") + modeling_mod = ModuleType("transformers.models.esmfold2.modeling_esmfold2") + + esmfold2_cls = MagicMock() + esmfold2_cls.from_pretrained.return_value = model + modeling_mod.ESMFold2Model = esmfold2_cls # type: ignore[attr-defined] + + transformers_mod.models = models_mod # type: ignore[attr-defined] + models_mod.esmfold2 = esmfold2_mod # type: ignore[attr-defined] + esmfold2_mod.modeling_esmfold2 = modeling_mod # type: ignore[attr-defined] + + return transformers_mod, model + + +def _transformers_sys_modules(mod: ModuleType) -> dict[str, ModuleType]: + models = mod.models # type: ignore[attr-defined] + esmfold2 = models.esmfold2 + modeling = esmfold2.modeling_esmfold2 + return { + "transformers": mod, + "transformers.models": models, + "transformers.models.esmfold2": esmfold2, + "transformers.models.esmfold2.modeling_esmfold2": modeling, + } + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def loaded_backend(): # type: ignore[no-untyped-def] + """Backend with the fake esm + transformers modules wired in.""" + from sheaf.backends.esmfold2 import ESMFold2Backend + + esm_mod = _make_esm_mod() + transformers_mod, _model = _make_transformers_mod() + _FakeInputBuilder.result = _make_result() + + backend = ESMFold2Backend(model_name="biohub/ESMFold2", device="cpu") + mods = { + **_esm_sys_modules(esm_mod), + **_transformers_sys_modules(transformers_mod), + } + with patch.dict(sys.modules, mods): + backend.load() + return backend + + +def _make_request( + chains: list[ChainInput] | None = None, + num_loops: int = 3, + num_sampling_steps: int = 50, + num_samples: int = 1, + seed: int = 0, + output_format: str = "mmcif", +) -> StructureRequest: + if chains is None: + chains = [ChainInput(chain_id="A", sequence="MKTAYIAK")] + return StructureRequest( + model_name="esmfold2", + chains=chains, + num_loops=num_loops, + num_sampling_steps=num_sampling_steps, + num_samples=num_samples, + seed=seed, + output_format=output_format, # type: ignore[arg-type] + ) + + +# --------------------------------------------------------------------------- +# load() — error cases +# --------------------------------------------------------------------------- + + +def test_load_raises_on_missing_esm() -> None: + from sheaf.backends.esmfold2 import ESMFold2Backend + + backend = ESMFold2Backend() + mods_without = {k: v for k, v in sys.modules.items() if not k.startswith("esm")} + _real_import = builtins.__import__ + + def _raise(name: str, *a: object, **kw: object) -> object: + if name == "esm" or name.startswith("esm."): + raise ModuleNotFoundError("No module named 'esm'") + return _real_import(name, *a, **kw) + + with ( + patch.dict(sys.modules, mods_without, clear=True), + patch("builtins.__import__", side_effect=_raise), + pytest.raises(ImportError, match="sheaf-serve\\[protein\\]"), + ): + backend.load() + + +def test_load_rejects_forge_only_model() -> None: + from sheaf.backends.esmfold2 import ESMFold2Backend + + backend = ESMFold2Backend(model_name="esmfold2-fast-2026-05") + with pytest.raises(NotImplementedError, match="Forge"): + backend.load() + + +# --------------------------------------------------------------------------- +# load() — happy path +# --------------------------------------------------------------------------- + + +def test_load_passes_model_name_and_device() -> None: + from sheaf.backends.esmfold2 import ESMFold2Backend + + esm_mod = _make_esm_mod() + transformers_mod, model = _make_transformers_mod() + backend = ESMFold2Backend(model_name="biohub/ESMFold2", device="cuda:1") + mods = { + **_esm_sys_modules(esm_mod), + **_transformers_sys_modules(transformers_mod), + } + with patch.dict(sys.modules, mods): + backend.load() + + modeling = transformers_mod.models.esmfold2.modeling_esmfold2 # type: ignore[attr-defined] + modeling.ESMFold2Model.from_pretrained.assert_called_once_with( # type: ignore[attr-defined] + "biohub/ESMFold2" + ) + model.to.assert_called_once_with("cuda:1") + + +# --------------------------------------------------------------------------- +# predict() — input validation +# --------------------------------------------------------------------------- + + +def test_predict_rejects_wrong_type(loaded_backend) -> None: # type: ignore[no-untyped-def] + from sheaf.api.molecular import MolecularRequest + + req = MolecularRequest(model_name="x", sequences=["MKT"]) + with pytest.raises(TypeError, match="StructureRequest"): + loaded_backend.predict(req) + + +# --------------------------------------------------------------------------- +# predict() — chain assembly & scaling parameter wiring +# --------------------------------------------------------------------------- + + +def test_predict_builds_protein_input_per_chain(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result() + chains = [ + ChainInput(chain_id="A", sequence="MKTAYIAK"), + ChainInput(chain_id="B", sequence="ACDEFG"), + ] + loaded_backend.predict(_make_request(chains=chains)) + + kwargs = _FakeInputBuilder.last_call_kwargs + assert kwargs is not None + assert kwargs["num_loops"] == 3 # default + assert kwargs["num_sampling_steps"] == 50 + assert kwargs["num_diffusion_samples"] == 1 + assert kwargs["seed"] == 0 + + +def test_predict_threads_inference_params(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(sample_scores=[0.71, 0.55, 0.83, 0.61]) + loaded_backend.predict( + _make_request( + num_loops=5, + num_sampling_steps=100, + num_samples=4, + seed=42, + ) + ) + + kwargs = _FakeInputBuilder.last_call_kwargs + assert kwargs == { + "num_loops": 5, + "num_sampling_steps": 100, + "num_diffusion_samples": 4, + "seed": 42, + } + + +# --------------------------------------------------------------------------- +# predict() — response structure +# --------------------------------------------------------------------------- + + +def test_predict_returns_structure_response_mmcif(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(mmcif="MY-CIF") + resp = loaded_backend.predict(_make_request(output_format="mmcif")) + + assert isinstance(resp, StructureResponse) + assert resp.structure_format == "mmcif" + assert resp.structure == "MY-CIF" + + +def test_predict_returns_structure_response_pdb(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(pdb="MY-PDB") + resp = loaded_backend.predict(_make_request(output_format="pdb")) + + assert resp.structure_format == "pdb" + assert resp.structure == "MY-PDB" + + +def test_predict_copies_confidence_metrics(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result( + plddt=[60.0, 70.5, 80.0, 90.5], + ptm=0.88, + iptm=0.55, + pae=[[0.0, 1.2], [1.2, 0.0]], + ) + resp = loaded_backend.predict(_make_request()) + + assert resp.plddt == [60.0, 70.5, 80.0, 90.5] + assert resp.ptm == pytest.approx(0.88) + assert resp.iptm == pytest.approx(0.55) + assert resp.pae == [[0.0, 1.2], [1.2, 0.0]] + + +def test_predict_skips_optional_confidence_when_missing(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(ptm=None, iptm=None, pae=None) + resp = loaded_backend.predict(_make_request()) + + assert resp.ptm is None + assert resp.iptm is None + assert resp.pae is None + + +def test_predict_omits_sample_scores_when_num_samples_is_one(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(sample_scores=[0.9]) + resp = loaded_backend.predict(_make_request(num_samples=1)) + + # num_samples=1 → no sample_scores in response, regardless of upstream. + assert resp.sample_scores is None + + +def test_predict_includes_sample_scores_when_num_samples_gt_one(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result(sample_scores=[0.7, 0.9, 0.6]) + resp = loaded_backend.predict(_make_request(num_samples=3)) + + assert resp.sample_scores == [0.7, 0.9, 0.6] + + +# --------------------------------------------------------------------------- +# batch_predict() +# --------------------------------------------------------------------------- + + +def test_batch_predict_runs_each_request(loaded_backend) -> None: # type: ignore[no-untyped-def] + _FakeInputBuilder.result = _make_result() + reqs = [_make_request(), _make_request(seed=7)] + responses = loaded_backend.batch_predict(reqs) + + assert len(responses) == 2 + assert all(isinstance(r, StructureResponse) for r in responses) diff --git a/uv.lock b/uv.lock index fe6c852..8724832 100644 --- a/uv.lock +++ b/uv.lock @@ -6309,7 +6309,7 @@ all = [ { name = "pymilvus" }, { name = "sam2" }, { name = "tabpfn" }, - { name = "timesfm", extra = ["torch"], marker = "extra == 'extra-11-sheaf-serve-all' or extra != 'extra-11-sheaf-serve-moirai' or (extra == 'extra-11-sheaf-serve-moirai' and extra == 'extra-11-sheaf-serve-vision')" }, + { name = "timesfm", extra = ["torch"], marker = "extra == 'extra-11-sheaf-serve-all' or (extra == 'extra-11-sheaf-serve-moirai' and extra == 'extra-11-sheaf-serve-vision')" }, { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" } }, { name = "transformers" }, ] @@ -6414,6 +6414,11 @@ pose = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-sheaf-serve-all' or extra != 'extra-11-sheaf-serve-moirai' or (extra == 'extra-11-sheaf-serve-moirai' and extra == 'extra-11-sheaf-serve-vision')" }, { name = "transformers" }, ] +protein = [ + { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-sheaf-serve-moirai'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-sheaf-serve-all' or extra != 'extra-11-sheaf-serve-moirai' or (extra == 'extra-11-sheaf-serve-moirai' and extra == 'extra-11-sheaf-serve-vision')" }, + { name = "transformers" }, +] small-molecule = [ { name = "torch", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-sheaf-serve-moirai'" }, { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-11-sheaf-serve-all' or extra != 'extra-11-sheaf-serve-moirai' or (extra == 'extra-11-sheaf-serve-moirai' and extra == 'extra-11-sheaf-serve-vision')" }, @@ -6519,6 +6524,7 @@ requires-dist = [ { name = "torch", marker = "extra == 'multimodal-generation'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'optical-flow'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'pose'", specifier = ">=2.0.0" }, + { name = "torch", marker = "extra == 'protein'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'small-molecule'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'tts'", specifier = ">=2.0.0" }, { name = "torch", marker = "extra == 'video'", specifier = ">=2.0.0" }, @@ -6530,6 +6536,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'earth-observation'", specifier = ">=4.37.0" }, { name = "transformers", marker = "extra == 'genomics'", specifier = ">=4.37.0" }, { name = "transformers", marker = "extra == 'pose'", specifier = ">=4.46.0" }, + { name = "transformers", marker = "extra == 'protein'", specifier = ">=4.40.0" }, { name = "transformers", marker = "extra == 'small-molecule'", specifier = ">=4.37.0" }, { name = "transformers", marker = "extra == 'tts'", specifier = ">=4.31.0" }, { name = "transformers", marker = "extra == 'video'", specifier = ">=4.40.0" }, @@ -6538,7 +6545,7 @@ requires-dist = [ { name = "uni2ts", marker = "extra == 'moirai'", specifier = ">=2.0.0" }, { name = "xarray", marker = "extra == 'weather'", specifier = ">=2024.1.0" }, ] -provides-extras = ["time-series", "tabular", "moirai", "audio", "audio-generation", "tts", "kokoro", "vision", "pose", "optical-flow", "multimodal-generation", "lidar", "molecular", "weather", "earth-observation", "genomics", "materials", "small-molecule", "multimodal", "diffusion", "video", "feast", "metrics", "tracing", "modal", "milvus", "batch", "worker", "docs", "dev", "all"] +provides-extras = ["time-series", "tabular", "moirai", "audio", "audio-generation", "tts", "kokoro", "vision", "pose", "optical-flow", "multimodal-generation", "lidar", "molecular", "protein", "weather", "earth-observation", "genomics", "materials", "small-molecule", "multimodal", "diffusion", "video", "feast", "metrics", "tracing", "modal", "milvus", "batch", "worker", "docs", "dev", "all"] [package.metadata.requires-dev] dev = [{ name = "modal", specifier = ">=1.4.1" }]