From 70e844fa927eb1dfdd78eb4e13c5ba50ceeaa1d3 Mon Sep 17 00:00:00 2001 From: taivu1998 <46636857+taivu1998@users.noreply.github.com> Date: Sun, 10 May 2026 04:39:17 -0700 Subject: [PATCH] Fix batched multimer template embeddings --- openfold/model/embedders.py | 8 +-- tests/test_template.py | 97 +++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/openfold/model/embedders.py b/openfold/model/embedders.py index 8915ab190..dcd48edd2 100644 --- a/openfold/model/embedders.py +++ b/openfold/model/embedders.py @@ -911,7 +911,7 @@ def forward(self, for i in range(n_templ): idx = batch["template_aatype"].new_tensor(i) single_template_feats = tensor_tree_map( - lambda t: torch.index_select(t, templ_dim, idx), + lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim), batch, ) @@ -944,7 +944,9 @@ def forward(self, single_template_feats["template_aatype"], ) points = rigid.translation - rigid_vec = rigid[..., None].inverse().apply_to_point(points) + rigid_vec = rigid[..., None].inverse().apply_to_point( + points[..., None, :] + ) unit_vector = rigid_vec.normalized() pair_act = self.template_pair_embedder( @@ -968,7 +970,7 @@ def forward(self, template_embeds.append(single_template_embeds) template_embeds = dict_multimap( - partial(torch.cat, dim=templ_dim), + partial(torch.stack, dim=templ_dim), template_embeds, ) diff --git a/tests/test_template.py b/tests/test_template.py index ae65b7cb2..df2009b30 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -16,10 +16,13 @@ import torch import numpy as np import unittest +from openfold.config import model_config +from openfold.model.embedders import TemplateEmbedderMultimer from openfold.model.template import ( TemplatePointwiseAttention, TemplatePairStack, ) +from openfold.np import residue_constants import tests.compare_utils as compare_utils from tests.config import consts from tests.data_utils import random_template_feats @@ -194,6 +197,100 @@ def template_iteration_fn(x): compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) +class TestTemplateEmbedderMultimer(unittest.TestCase): + def test_batched_template_embedding_shape(self): + torch.manual_seed(0) + batch_size = 2 + n_templ = 2 + n_res = 8 + + config = model_config("model_1_multimer_v3") + config.model.template.template_pair_stack.no_blocks = 1 + config.model.template.template_pair_stack.blocks_per_ckpt = None + + embedder = TemplateEmbedderMultimer(config.model.template) + embedder.eval() + + aatype = torch.full( + (batch_size, n_templ, n_res), + residue_constants.restype_order["A"], + dtype=torch.long, + ) + atom_positions = torch.zeros( + batch_size, + n_templ, + n_res, + residue_constants.atom_type_num, + 3, + ) + atom_mask = torch.zeros( + batch_size, + n_templ, + n_res, + residue_constants.atom_type_num, + ) + + n_idx = residue_constants.atom_order["N"] + ca_idx = residue_constants.atom_order["CA"] + c_idx = residue_constants.atom_order["C"] + cb_idx = residue_constants.atom_order["CB"] + residue_positions = torch.arange(n_res, dtype=torch.float32) + + atom_positions[..., n_idx, 0] = residue_positions + atom_positions[..., ca_idx, 0] = residue_positions + atom_positions[..., ca_idx, 1] = 1.0 + atom_positions[..., c_idx, 0] = residue_positions + atom_positions[..., c_idx, 2] = 1.0 + atom_positions[..., cb_idx, 0] = residue_positions + atom_positions[..., cb_idx, 1] = 1.0 + atom_positions[..., cb_idx, 2] = 1.0 + atom_mask[..., [n_idx, ca_idx, c_idx, cb_idx]] = 1.0 + + batch = { + "template_aatype": aatype, + "template_all_atom_positions": atom_positions, + "template_all_atom_mask": atom_mask, + } + z = torch.randn(batch_size, n_res, n_res, config.globals.c_z) + padding_mask_2d = torch.ones(batch_size, n_res, n_res) + asym_id = torch.tensor( + [[1] * (n_res // 2) + [2] * (n_res - n_res // 2)] + * batch_size + ) + multichain_mask_2d = asym_id[..., None] == asym_id[..., None, :] + + with torch.no_grad(): + out = embedder( + batch, + z, + padding_mask_2d, + templ_dim=1, + chunk_size=None, + multichain_mask_2d=multichain_mask_2d, + use_lma=False, + inplace_safe=False, + ) + + self.assertEqual( + out["template_pair_embedding"].shape, + (batch_size, n_res, n_res, config.globals.c_z), + ) + self.assertEqual( + out["template_single_embedding"].shape, + (batch_size, n_templ, n_res, config.globals.c_m), + ) + self.assertEqual( + out["template_mask"].shape, + (batch_size, n_templ, n_res), + ) + self.assertTrue( + torch.isfinite(out["template_pair_embedding"]).all().item() + ) + self.assertTrue( + torch.isfinite(out["template_single_embedding"]).all().item() + ) + + class Template(unittest.TestCase): @classmethod def setUpClass(cls):