diff --git a/protenix/model/modules/frames.py b/protenix/model/modules/frames.py index e390b8b..d71282b 100644 --- a/protenix/model/modules/frames.py +++ b/protenix/model/modules/frames.py @@ -62,47 +62,25 @@ def gather_frame_atom_by_indices( Args: coordinate (torch.Tensor): the input coordinate - [..., N_atom, 3] + [..., N_atom, 3[three coordinates]] frame_atom_index (torch.Tensor): indices of three atoms in each frame - [..., N_frame, 3] or [N_frame, 3] + [..., N_frame, 3[three atoms per frame]] or [N_frame, 3[three atoms per frame]] dim (int): along which dimension to select the frame atoms Returns: torch.Tensor: the constructed frames - [..., N_frame, 3[three atom], 3[three coordinate]] + [..., N_frame, 3[three atoms per frame], 3[three coordinates]] """ if len(frame_atom_index.shape) == 2: - # the navie case - x1 = torch.index_select( - coordinate, dim=dim, index=frame_atom_index[:, 0] - ) # [..., N_frame, 3] - x2 = torch.index_select( - coordinate, dim=dim, index=frame_atom_index[:, 1] - ) # [..., N_frame, 3] - x3 = torch.index_select( - coordinate, dim=dim, index=frame_atom_index[:, 2] - ) # [..., N_frame, 3] - return torch.stack([x1, x2, x3], dim=dim) + # the naive case + return coordinate[..., frame_atom_index, :] else: assert ( frame_atom_index.shape[:dim] == coordinate.shape[:dim] - ), "batch size dims should match" + ), f"the size of each batch dim should match, got {frame_atom_index.shape[:dim]} and {coordinate.shape[:dim]}" - x1 = batched_gather( - data=coordinate, - inds=frame_atom_index[..., 0], - dim=dim, - no_batch_dims=len(coordinate.shape[:dim]), - ) # [..., N_frame, 3] - x2 = batched_gather( - data=coordinate, - inds=frame_atom_index[..., 1], - dim=dim, - no_batch_dims=len(coordinate.shape[:dim]), - ) # [..., N_frame, 3] - x3 = batched_gather( + reshaped_frame_atom_index = frame_atom_index.reshape(*frame_atom_index.shape[:-2], -1) # [..., N_frame*3] + batched_frame_atom_coordinates = batched_gather( data=coordinate, - inds=frame_atom_index[..., 2], - dim=dim, - no_batch_dims=len(coordinate.shape[:dim]), - ) # [..., N_frame, 3] - return torch.stack([x1, x2, x3], dim=dim) + inds=reshaped_frame_atom_index + ) # [..., N_frame*3, 3[three coordinates]] + return batched_frame_atom_coordinates.reshape(*batched_frame_atom_coordinates.shape[:-2], frame_atom_index.shape[-2], frame_atom_index.shape[-1], coordinate.shape[-1]) # [..., N_frame, 3, 3] diff --git a/protenix/model/modules/primitives.py b/protenix/model/modules/primitives.py index 1191280..4bc4344 100644 --- a/protenix/model/modules/primitives.py +++ b/protenix/model/modules/primitives.py @@ -364,7 +364,7 @@ def basic_checks(x, dim_x): ] pad_left = (n_keys - n_queries) // 2 - pad_right = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2) + pad_right = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n) k_new = [ pad_at_dim(k[i], dim=dim_k[i], pad_length=(pad_left, pad_right)) @@ -907,6 +907,7 @@ def gather_pair_embedding_in_dense_trunk( idx_k_expanded = idx_k.unsqueeze(1).expand(-1, N_q, -1) # Use advanced indexing to gather the desired elements + # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing y = x[..., idx_q_expanded, idx_k_expanded, :] return y diff --git a/protenix/model/utils.py b/protenix/model/utils.py index 9a923d8..13140d6 100644 --- a/protenix/model/utils.py +++ b/protenix/model/utils.py @@ -193,39 +193,33 @@ def one_hot( return dgram -# this is mostly from openfold.utils.torch_utils import batched_gather def batched_gather( - data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0 + data: torch.Tensor, inds: torch.Tensor ) -> torch.Tensor: - """Gather data according to indices specify by inds + """Gather data according to indices specified by inds along the dim = len(inds.shape) - 1 Args: data (torch.Tensor): the input data [..., K, ...] inds (torch.Tensor): the indices for gathering data [..., N] - dim (int, optional): along which dimension to gather data by inds (the dim of "K" "N"). Defaults to 0. - no_batch_dims (int, optional): length of dimensions before the "dim" dimension. Defaults to 0. Returns: - torch.Tensor: gathered data + torch.Tensor: gathered data, have the same number of dimensions as data, + only the size of dimension len(inds.shape) - 1 is changed to N [..., N, ...] """ + assert len(inds.shape) <= len(data.shape), "inds must have less or equal dimensions than data" + assert inds.shape[:len(inds.shape)-1] == data.shape[:len(inds.shape)-1], "Batch dimensions must match between data and inds" + + if len(inds.shape) == len(data.shape): + return torch.gather(data, dim=-1, index=inds) - # for the naive case - if len(inds.shape) == 1 and no_batch_dims == 0 and dim == 0: - return data[inds] - - ranges = [] - for i, s in enumerate(data.shape[:no_batch_dims]): - r = torch.arange(s) - r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) - ranges.append(r) - - remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] - remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds - ranges.extend(remaining_dims) - return data[ranges] + append_shape = (1,) * (len(data.shape) - len(inds.shape)) + append_shape_broadcasted = data.shape[len(inds.shape) - len(data.shape):] + inds_broadcasted = inds.reshape(inds.shape + append_shape) + inds_broadcasted = inds_broadcasted.expand(inds.shape + append_shape_broadcasted) + return torch.gather(data, dim=len(inds.shape) - 1, index=inds_broadcasted) def broadcast_token_to_atom( @@ -252,9 +246,7 @@ def broadcast_token_to_atom( return batched_gather( data=x_token, - inds=atom_to_token_idx, - dim=-2, - no_batch_dims=len(x_token.shape[:-2]), + inds=atom_to_token_idx )