diff --git a/src/transformers/models/esmfold2/modeling_esmfold2.py b/src/transformers/models/esmfold2/modeling_esmfold2.py index 8be05fdebb..79024feb02 100644 --- a/src/transformers/models/esmfold2/modeling_esmfold2.py +++ b/src/transformers/models/esmfold2/modeling_esmfold2.py @@ -1040,11 +1040,23 @@ def forward( @torch.no_grad() def infer_protein(self, seq: str, **forward_kwargs) -> dict: - from .protein_utils import prepare_protein_features + from .protein_utils import OUTPUT_TO_PDB_FEATURE_KEYS, prepare_protein_features features = prepare_protein_features(seq) features = {k: v.to(self.device) for k, v in features.items()} - return self(**features, **forward_kwargs) + output = self(**features, **forward_kwargs) + for k in OUTPUT_TO_PDB_FEATURE_KEYS: + output[k] = features[k] + return output + + def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str: + return self.output_to_pdb(self.infer_protein(seq, **forward_kwargs)) + + @staticmethod + def output_to_pdb(output: dict) -> str: + from .protein_utils import output_to_pdb as _output_to_pdb + + return _output_to_pdb(output) class MSAEncoderBlock(nn.Module): diff --git a/src/transformers/models/esmfold2/modeling_esmfold2_common.py b/src/transformers/models/esmfold2/modeling_esmfold2_common.py index bcccfbfa4a..a7789647db 100644 --- a/src/transformers/models/esmfold2/modeling_esmfold2_common.py +++ b/src/transformers/models/esmfold2/modeling_esmfold2_common.py @@ -556,13 +556,24 @@ def forward(self, x: Tensor, attention_params: tuple) -> Tensor: window_size=(self.half_window, self.half_window), ) else: - # Fallback: standard attention (no SWA) - q_t = q.transpose(1, 2) - k_t = k.transpose(1, 2) - v_t = v.transpose(1, 2) - attn = torch.matmul(q_t, k_t.transpose(-2, -1)) * self.scale - attn = F.softmax(attn, dim=-1) - out = torch.matmul(attn, v_t).transpose(1, 2) + if len(attention_params) > 2: + valid = torch.zeros(B * N, dtype=torch.bool, device=q.device) + valid[attention_params[2]] = True + valid = valid.view(B, N) + else: + valid = torch.ones(B, N, dtype=torch.bool, device=q.device) + rank = torch.cumsum(valid, dim=1) - 1 + within = (rank.unsqueeze(2) - rank.unsqueeze(1)).abs() <= self.half_window + allowed = within & valid.unsqueeze(1) & valid.unsqueeze(2) + allowed |= torch.eye(N, dtype=torch.bool, device=q.device) + out = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + attn_mask=allowed.unsqueeze(1), + scale=self.scale, + ).transpose(1, 2) + out = out * valid.unsqueeze(-1).unsqueeze(-1) out = out.to(input_dtype).reshape(B, N, -1) # type: ignore[union-attr] out = out * torch.sigmoid(self.gate_proj(x_input)) diff --git a/src/transformers/models/esmfold2/modeling_esmfold2_experimental.py b/src/transformers/models/esmfold2/modeling_esmfold2_experimental.py index 78f471b61b..255229afbf 100644 --- a/src/transformers/models/esmfold2/modeling_esmfold2_experimental.py +++ b/src/transformers/models/esmfold2/modeling_esmfold2_experimental.py @@ -612,19 +612,12 @@ def from_pretrained( @torch.no_grad() def infer_protein(self, seq: str, **forward_kwargs) -> dict: - from .protein_utils import prepare_protein_features + from .protein_utils import OUTPUT_TO_PDB_FEATURE_KEYS, prepare_protein_features features = prepare_protein_features(seq) features = {k: v.to(self.device) for k, v in features.items()} output = self(**features, **forward_kwargs) - for k in ( - "res_type", - "atom_to_token", - "ref_atom_name_chars", - "atom_attention_mask", - "token_attention_mask", - "residue_index", - ): + for k in OUTPUT_TO_PDB_FEATURE_KEYS: output[k] = features[k] return output @@ -746,98 +739,9 @@ def _output_to_molecular_complex(output: dict, features: dict, chain_infos: list @staticmethod def output_to_pdb(output: dict) -> str: - from transformers.models.esm.openfold_utils import OFProtein, to_pdb - from transformers.models.esm.openfold_utils import residue_constants as rc - - # 0-32 res_type → 3-letter name (only protein indices 2-22 are populated) - res_type_to_3letter = { - 2: "ALA", - 3: "ARG", - 4: "ASN", - 5: "ASP", - 6: "CYS", - 7: "GLN", - 8: "GLU", - 9: "GLY", - 10: "HIS", - 11: "ILE", - 12: "LEU", - 13: "LYS", - 14: "MET", - 15: "PHE", - 16: "PRO", - 17: "SER", - 18: "THR", - 19: "TRP", - 20: "TYR", - 21: "VAL", - 22: "UNK", - } - - coords = output["sample_atom_coords"] - if coords.dim() == 4: - coords = coords[:, 0] - coords = coords.detach().cpu().numpy()[0] - - plddt = output["plddt"].detach().cpu().numpy()[0] - atom_to_token = output["atom_to_token"].cpu().numpy() - ref_chars = output["ref_atom_name_chars"].cpu().numpy() - res_type = output["res_type"].cpu().numpy() - token_mask = output["token_attention_mask"].cpu().numpy().astype(bool) - atom_mask_in = output["atom_attention_mask"].cpu().numpy().astype(bool) - residue_index_arr = output["residue_index"].cpu().numpy() - - if atom_to_token.ndim == 2: - atom_to_token = atom_to_token[0] - ref_chars = ref_chars[0] - res_type = res_type[0] - token_mask = token_mask[0] - atom_mask_in = atom_mask_in[0] - residue_index_arr = residue_index_arr[0] - - valid_tok = np.where(token_mask)[0] - n_res = valid_tok.shape[0] - - aatype = np.full(n_res, rc.restype_order_with_x["X"], dtype=np.int64) - for new_i, t in enumerate(valid_tok): - rt = int(res_type[t]) - three = res_type_to_3letter.get(rt) - if three is None or three == "UNK": - aatype[new_i] = rc.restype_order_with_x["X"] - else: - one = rc.restype_3to1.get(three, "X") - aatype[new_i] = rc.restype_order_with_x[one] - - atom_positions = np.zeros((n_res, 37, 3), dtype=np.float32) - atom_mask = np.zeros((n_res, 37), dtype=np.float32) - b_factors = np.zeros((n_res, 37), dtype=np.float32) - tok_to_new = {int(t): i for i, t in enumerate(valid_tok)} + from .protein_utils import output_to_pdb as _output_to_pdb - for a in range(atom_to_token.shape[0]): - if not atom_mask_in[a]: - continue - tok = int(atom_to_token[a]) - if tok not in tok_to_new: - continue - new_i = tok_to_new[tok] - name = "".join( - chr(int(c) + 32) if int(c) != 0 else " " for c in ref_chars[a] - ).strip() - idx37 = rc.atom_order.get(name) - if idx37 is None: - continue - atom_positions[new_i, idx37] = coords[a] - atom_mask[new_i, idx37] = 1.0 - b_factors[new_i, idx37] = float(plddt[tok]) - - pred = OFProtein( - aatype=aatype, - atom_positions=atom_positions, - atom_mask=atom_mask, - residue_index=residue_index_arr[valid_tok].astype(np.int32) + 1, - b_factors=b_factors, - ) - return to_pdb(pred) + return _output_to_pdb(output) def _compute_lm_hidden_states( self, diff --git a/src/transformers/models/esmfold2/protein_utils.py b/src/transformers/models/esmfold2/protein_utils.py index 75b785d9e9..a94c05c9c5 100644 --- a/src/transformers/models/esmfold2/protein_utils.py +++ b/src/transformers/models/esmfold2/protein_utils.py @@ -12,18 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Self-contained protein featurization for ESMFold2 inference. - -Lets ``ESMFold2ExperimentalModel.infer_protein_as_pdb`` fold a protein sequence -ESMFold-style without the ``esm`` companion package. The featurization -mirrors ``ESMFold2InputBuilder.prepare_input`` for the protein-only path — -``test_prepare_protein_features.py`` enforces tensor-exact parity. -""" +"""Self-contained protein featurization for ESMFold2 inference.""" from __future__ import annotations import math +import numpy as np import torch from torch import Tensor @@ -486,3 +481,102 @@ def prepare_protein_features(sequence: str) -> dict[str, Tensor]: "deletion_mean": deletion_mean, } return {k: v.unsqueeze(0) for k, v in features.items()} + + +# 0-32 res_type → 3-letter name (only protein indices 2-22 are populated). +_RES_TYPE_TO_3LETTER: dict[int, str] = { + rt: three for three, rt in PROTEIN_RESIDUE_TO_RES_TYPE.items() +} +_RES_TYPE_TO_3LETTER[PROTEIN_UNK_RES_TYPE] = "UNK" + +# Featurization keys that ``output_to_pdb`` reads off the forward output. +# ``infer_protein`` re-attaches them because ``forward`` does not echo them +# back; both ESMFold2 model classes share this list. +OUTPUT_TO_PDB_FEATURE_KEYS: tuple[str, ...] = ( + "res_type", + "atom_to_token", + "ref_atom_name_chars", + "atom_attention_mask", + "token_attention_mask", + "residue_index", +) + + +def output_to_pdb(output: dict) -> str: + """Convert an ESMFold2 protein forward output into a PDB string. + + Expects ``output`` to carry the featurization keys re-attached by + ``infer_protein`` (``res_type``, ``atom_to_token``, + ``ref_atom_name_chars``, ``atom_attention_mask``, + ``token_attention_mask``, ``residue_index``) alongside the predicted + ``sample_atom_coords`` and ``plddt``. Builds a 37-atom + ``OFProtein`` (per-atom pLDDT in the b-factor column) and renders it + with the OpenFold utilities shipped in ``transformers.models.esm``. + """ + from transformers.models.esm.openfold_utils import OFProtein, to_pdb + from transformers.models.esm.openfold_utils import residue_constants as rc + + coords = output["sample_atom_coords"] + if coords.dim() == 4: + coords = coords[:, 0] + coords = coords.detach().cpu().numpy()[0] + + plddt = output["plddt"].detach().cpu().numpy()[0] + atom_to_token = output["atom_to_token"].cpu().numpy() + ref_chars = output["ref_atom_name_chars"].cpu().numpy() + res_type = output["res_type"].cpu().numpy() + token_mask = output["token_attention_mask"].cpu().numpy().astype(bool) + atom_mask_in = output["atom_attention_mask"].cpu().numpy().astype(bool) + residue_index_arr = output["residue_index"].cpu().numpy() + + if atom_to_token.ndim == 2: + atom_to_token = atom_to_token[0] + ref_chars = ref_chars[0] + res_type = res_type[0] + token_mask = token_mask[0] + atom_mask_in = atom_mask_in[0] + residue_index_arr = residue_index_arr[0] + + valid_tok = np.where(token_mask)[0] + n_res = valid_tok.shape[0] + + aatype = np.full(n_res, rc.restype_order_with_x["X"], dtype=np.int64) + for new_i, t in enumerate(valid_tok): + rt = int(res_type[t]) + three = _RES_TYPE_TO_3LETTER.get(rt) + if three is None or three == "UNK": + aatype[new_i] = rc.restype_order_with_x["X"] + else: + one = rc.restype_3to1.get(three, "X") + aatype[new_i] = rc.restype_order_with_x[one] + + atom_positions = np.zeros((n_res, 37, 3), dtype=np.float32) + atom_mask = np.zeros((n_res, 37), dtype=np.float32) + b_factors = np.zeros((n_res, 37), dtype=np.float32) + tok_to_new = {int(t): i for i, t in enumerate(valid_tok)} + + for a in range(atom_to_token.shape[0]): + if not atom_mask_in[a]: + continue + tok = int(atom_to_token[a]) + if tok not in tok_to_new: + continue + new_i = tok_to_new[tok] + name = "".join( + chr(int(c) + 32) if int(c) != 0 else " " for c in ref_chars[a] + ).strip() + idx37 = rc.atom_order.get(name) + if idx37 is None: + continue + atom_positions[new_i, idx37] = coords[a] + atom_mask[new_i, idx37] = 1.0 + b_factors[new_i, idx37] = float(plddt[tok]) + + pred = OFProtein( + aatype=aatype, + atom_positions=atom_positions, + atom_mask=atom_mask, + residue_index=residue_index_arr[valid_tok].astype(np.int32) + 1, + b_factors=b_factors, + ) + return to_pdb(pred)