Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/transformers/models/esmfold2/modeling_esmfold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,11 +1040,23 @@ def forward(

@torch.no_grad()
def infer_protein(self, seq: str, **forward_kwargs) -> dict:
    """Fold a single protein sequence and return the forward output.

    The sequence is featurized with ``prepare_protein_features``, every
    feature tensor is moved to ``self.device``, and the model is called with
    those features plus any extra ``forward_kwargs``. The feature tensors
    named in ``OUTPUT_TO_PDB_FEATURE_KEYS`` are then copied onto the output
    so it can be handed straight to ``output_to_pdb``.

    Args:
        seq: Protein sequence to fold.
        **forward_kwargs: Extra keyword arguments passed through to the
            model's forward call.

    Returns:
        The forward output with the PDB-related feature tensors attached.
    """
    from .protein_utils import OUTPUT_TO_PDB_FEATURE_KEYS, prepare_protein_features

    features = prepare_protein_features(seq)
    features = {k: v.to(self.device) for k, v in features.items()}
    output = self(**features, **forward_kwargs)
    # The forward pass does not echo its inputs back, so re-attach the
    # featurization tensors that the PDB writer reads.
    for k in OUTPUT_TO_PDB_FEATURE_KEYS:
        output[k] = features[k]
    return output

def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str:
    """Fold ``seq`` and return the prediction rendered as a PDB string.

    Extra keyword arguments are forwarded unchanged to ``infer_protein``;
    its result is then passed to ``output_to_pdb``.
    """
    prediction = self.infer_protein(seq, **forward_kwargs)
    return self.output_to_pdb(prediction)

@staticmethod
def output_to_pdb(output: dict) -> str:
    """Render a forward output as a PDB string.

    Thin wrapper: the conversion itself lives in
    ``protein_utils.output_to_pdb`` and is imported lazily on first use.
    """
    from . import protein_utils

    return protein_utils.output_to_pdb(output)


class MSAEncoderBlock(nn.Module):
Expand Down
25 changes: 18 additions & 7 deletions src/transformers/models/esmfold2/modeling_esmfold2_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,24 @@ def forward(self, x: Tensor, attention_params: tuple) -> Tensor:
window_size=(self.half_window, self.half_window),
)
else:
# Fallback: standard attention (no SWA)
q_t = q.transpose(1, 2)
k_t = k.transpose(1, 2)
v_t = v.transpose(1, 2)
attn = torch.matmul(q_t, k_t.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v_t).transpose(1, 2)
if len(attention_params) > 2:
valid = torch.zeros(B * N, dtype=torch.bool, device=q.device)
valid[attention_params[2]] = True
valid = valid.view(B, N)
else:
valid = torch.ones(B, N, dtype=torch.bool, device=q.device)
rank = torch.cumsum(valid, dim=1) - 1
within = (rank.unsqueeze(2) - rank.unsqueeze(1)).abs() <= self.half_window
allowed = within & valid.unsqueeze(1) & valid.unsqueeze(2)
allowed |= torch.eye(N, dtype=torch.bool, device=q.device)
out = F.scaled_dot_product_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
attn_mask=allowed.unsqueeze(1),
scale=self.scale,
).transpose(1, 2)
out = out * valid.unsqueeze(-1).unsqueeze(-1)

out = out.to(input_dtype).reshape(B, N, -1) # type: ignore[union-attr]
out = out * torch.sigmoid(self.gate_proj(x_input))
Expand Down
104 changes: 4 additions & 100 deletions src/transformers/models/esmfold2/modeling_esmfold2_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,19 +612,12 @@ def from_pretrained(

@torch.no_grad()
def infer_protein(self, seq: str, **forward_kwargs) -> dict:
    """Fold a single protein sequence and return the forward output.

    Featurizes ``seq`` with ``prepare_protein_features``, moves each feature
    tensor to ``self.device``, runs the model with those features plus
    ``forward_kwargs``, and copies the feature tensors listed in
    ``OUTPUT_TO_PDB_FEATURE_KEYS`` onto the output so it can be passed to
    ``output_to_pdb``.

    Args:
        seq: Protein sequence to fold.
        **forward_kwargs: Extra keyword arguments passed through to the
            model's forward call.

    Returns:
        The forward output with the PDB-related feature tensors attached.
    """
    from .protein_utils import OUTPUT_TO_PDB_FEATURE_KEYS, prepare_protein_features

    features = prepare_protein_features(seq)
    features = {k: v.to(self.device) for k, v in features.items()}
    output = self(**features, **forward_kwargs)
    # Use the shared key list rather than a private copy so this class and
    # the PDB writer cannot drift apart.
    for k in OUTPUT_TO_PDB_FEATURE_KEYS:
        output[k] = features[k]
    return output

Expand Down Expand Up @@ -746,98 +739,9 @@ def _output_to_molecular_complex(output: dict, features: dict, chain_infos: list

@staticmethod
def output_to_pdb(output: dict) -> str:
    """Convert a protein forward output into a PDB string.

    Delegates to ``protein_utils.output_to_pdb`` so that the residue-type
    table and the atom-placement loop exist in one place only, instead of
    being duplicated inside this class.

    Args:
        output: Forward output carrying the predicted coordinates, pLDDT and
            the featurization keys re-attached by ``infer_protein``.

    Returns:
        The PDB-formatted text produced by the shared helper.
    """
    from .protein_utils import output_to_pdb as _output_to_pdb

    return _output_to_pdb(output)

def _compute_lm_hidden_states(
self,
Expand Down
108 changes: 101 additions & 7 deletions src/transformers/models/esmfold2/protein_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Self-contained protein featurization for ESMFold2 inference.

Lets ``ESMFold2ExperimentalModel.infer_protein_as_pdb`` fold a protein sequence
ESMFold-style without the ``esm`` companion package. The featurization
mirrors ``ESMFold2InputBuilder.prepare_input`` for the protein-only path —
``test_prepare_protein_features.py`` enforces tensor-exact parity.
"""
"""Self-contained protein featurization for ESMFold2 inference."""

from __future__ import annotations

import math

import numpy as np
import torch
from torch import Tensor

Expand Down Expand Up @@ -486,3 +481,102 @@ def prepare_protein_features(sequence: str) -> dict[str, Tensor]:
"deletion_mean": deletion_mean,
}
return {k: v.unsqueeze(0) for k, v in features.items()}


# Inverse of ``PROTEIN_RESIDUE_TO_RES_TYPE``: res_type index → 3-letter residue
# name, with the unknown-residue index mapped to "UNK". Of the 0-32 res_type
# range only the protein indices 2-22 are populated.
_RES_TYPE_TO_3LETTER: dict[int, str] = {
    **{index: name for name, index in PROTEIN_RESIDUE_TO_RES_TYPE.items()},
    PROTEIN_UNK_RES_TYPE: "UNK",
}

# Feature tensors that ``output_to_pdb`` looks up on the forward output.
# ``forward`` does not return them, so ``infer_protein`` copies them back on;
# the list is shared by both ESMFold2 model classes.
OUTPUT_TO_PDB_FEATURE_KEYS: tuple[str, ...] = (
    "res_type",
    "atom_to_token",
    "ref_atom_name_chars",
    "atom_attention_mask",
    "token_attention_mask",
    "residue_index",
)


def output_to_pdb(output: dict) -> str:
    """Convert an ESMFold2 protein forward output into a PDB string.

    Expects ``output`` to carry the featurization keys re-attached by
    ``infer_protein`` (``res_type``, ``atom_to_token``,
    ``ref_atom_name_chars``, ``atom_attention_mask``,
    ``token_attention_mask``, ``residue_index``) alongside the predicted
    ``sample_atom_coords`` and ``plddt``. Builds a 37-atom
    ``OFProtein`` (per-atom pLDDT in the b-factor column) and renders it
    with the OpenFold utilities shipped in ``transformers.models.esm``.

    Only the first batch element is written, and when ``sample_atom_coords``
    is 4-D only index 0 along its second axis is used.

    Args:
        output: Mapping holding the predictions and feature tensors above.

    Returns:
        The PDB text produced by ``to_pdb``.

    Raises:
        KeyError: If any required key is absent from ``output``.
    """
    from transformers.models.esm.openfold_utils import OFProtein, to_pdb
    from transformers.models.esm.openfold_utils import residue_constants as rc

    # Fail once, with every missing key named, instead of on the first lookup.
    required = ("sample_atom_coords", "plddt", *OUTPUT_TO_PDB_FEATURE_KEYS)
    missing = [key for key in required if key not in output]
    if missing:
        raise KeyError(
            f"output_to_pdb: output is missing required key(s) {missing}; "
            "the featurization keys are attached by `infer_protein`."
        )

    coords = output["sample_atom_coords"]
    if coords.dim() == 4:
        coords = coords[:, 0]
    # Upcast to float32 before leaving torch: ``Tensor.numpy()`` rejects
    # bfloat16, and the destination arrays below are float32 anyway.
    coords = coords.detach().cpu().float().numpy()[0]
    plddt = output["plddt"].detach().cpu().float().numpy()[0]

    atom_to_token = output["atom_to_token"].cpu().numpy()
    ref_chars = output["ref_atom_name_chars"].cpu().numpy()
    res_type = output["res_type"].cpu().numpy()
    token_mask = output["token_attention_mask"].cpu().numpy().astype(bool)
    atom_mask_in = output["atom_attention_mask"].cpu().numpy().astype(bool)
    residue_index_arr = output["residue_index"].cpu().numpy()

    # Drop the batch axis from the feature arrays when one is present.
    if atom_to_token.ndim == 2:
        atom_to_token = atom_to_token[0]
        ref_chars = ref_chars[0]
        res_type = res_type[0]
        token_mask = token_mask[0]
        atom_mask_in = atom_mask_in[0]
        residue_index_arr = residue_index_arr[0]

    valid_tok = np.where(token_mask)[0]
    n_res = valid_tok.shape[0]

    # Every residue starts as unknown ("X"); only recognised residue types
    # overwrite that default.
    unknown_aatype = rc.restype_order_with_x["X"]
    aatype = np.full(n_res, unknown_aatype, dtype=np.int64)
    for new_i, t in enumerate(valid_tok):
        three = _RES_TYPE_TO_3LETTER.get(int(res_type[t]))
        if three is not None and three != "UNK":
            one = rc.restype_3to1.get(three, "X")
            aatype[new_i] = rc.restype_order_with_x[one]

    atom_positions = np.zeros((n_res, 37, 3), dtype=np.float32)
    atom_mask = np.zeros((n_res, 37), dtype=np.float32)
    b_factors = np.zeros((n_res, 37), dtype=np.float32)
    # Original token index -> row in the compacted (valid-only) residue arrays.
    tok_to_new = {int(t): i for i, t in enumerate(valid_tok)}

    for a in range(atom_to_token.shape[0]):
        if not atom_mask_in[a]:
            continue
        tok = int(atom_to_token[a])
        new_i = tok_to_new.get(tok)
        if new_i is None:
            continue
        # Atom-name codes are stored offset by 32, with 0 standing for padding.
        name = "".join(
            " " if int(c) == 0 else chr(int(c) + 32) for c in ref_chars[a]
        ).strip()
        idx37 = rc.atom_order.get(name)
        if idx37 is None:
            # Atom name has no slot in the 37-atom layout; leave it out.
            continue
        atom_positions[new_i, idx37] = coords[a]
        atom_mask[new_i, idx37] = 1.0
        b_factors[new_i, idx37] = float(plddt[tok])

    pred = OFProtein(
        aatype=aatype,
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        # The stored residue index is shifted by one for the PDB record.
        residue_index=residue_index_arr[valid_tok].astype(np.int32) + 1,
        b_factors=b_factors,
    )
    return to_pdb(pred)
Loading