44 changes: 11 additions & 33 deletions protenix/model/modules/frames.py
@@ -62,47 +62,25 @@ def gather_frame_atom_by_indices(

     Args:
         coordinate (torch.Tensor): the input coordinate
-            [..., N_atom, 3]
+            [..., N_atom, 3[three coordinates]]
         frame_atom_index (torch.Tensor): indices of three atoms in each frame
-            [..., N_frame, 3] or [N_frame, 3]
+            [..., N_frame, 3[three atoms per frame]] or [N_frame, 3[three atoms per frame]]
         dim (int): along which dimension to select the frame atoms
     Returns:
         torch.Tensor: the constructed frames
-            [..., N_frame, 3[three atom], 3[three coordinate]]
+            [..., N_frame, 3[three atoms per frame], 3[three coordinates]]
     """
     if len(frame_atom_index.shape) == 2:
-        # the navie case
-        x1 = torch.index_select(
-            coordinate, dim=dim, index=frame_atom_index[:, 0]
-        )  # [..., N_frame, 3]
-        x2 = torch.index_select(
-            coordinate, dim=dim, index=frame_atom_index[:, 1]
-        )  # [..., N_frame, 3]
-        x3 = torch.index_select(
-            coordinate, dim=dim, index=frame_atom_index[:, 2]
-        )  # [..., N_frame, 3]
-        return torch.stack([x1, x2, x3], dim=dim)
+        # the naive case
+        return coordinate[..., frame_atom_index, :]
Collaborator:

This will trigger the bug below. The error occurs at line 894 in loss.py, within the call to gather_frame_atom_by_indices using dim=-1:

    return coordinate[..., frame_atom_index, :]
           ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
    IndexError: too many indices for tensor of dimension 1
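
For context, a minimal sketch that reproduces the same error (the shapes are guessed from the traceback, not taken from loss.py):

    import torch

    coordinate = torch.randn(10)                     # 1-D tensor, as in the traceback
    frame_atom_index = torch.randint(0, 10, (5, 3))  # [N_frame, 3]

    # [..., idx, :] supplies two indices after the ellipsis, but a 1-D tensor
    # has only one dimension to index, hence the IndexError
    try:
        coordinate[..., frame_atom_index, :]
    except IndexError as e:
        print(e)  # too many indices for tensor of dimension 1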

Contributor Author:

Thank you for your review. I hope I will have time to study the details (maybe before 2026/4/20).

     else:
         assert (
             frame_atom_index.shape[:dim] == coordinate.shape[:dim]
-        ), "batch size dims should match"
+        ), f"the size of each batch dim should match, got {frame_atom_index.shape[:dim]} and {coordinate.shape[:dim]}"
 
-        x1 = batched_gather(
-            data=coordinate,
-            inds=frame_atom_index[..., 0],
-            dim=dim,
-            no_batch_dims=len(coordinate.shape[:dim]),
-        )  # [..., N_frame, 3]
-        x2 = batched_gather(
-            data=coordinate,
-            inds=frame_atom_index[..., 1],
-            dim=dim,
-            no_batch_dims=len(coordinate.shape[:dim]),
-        )  # [..., N_frame, 3]
-        x3 = batched_gather(
+        reshaped_frame_atom_index = frame_atom_index.reshape(*frame_atom_index.shape[:-2], -1)  # [..., N_frame*3]
+        batched_frame_atom_coordinates = batched_gather(
             data=coordinate,
-            inds=frame_atom_index[..., 2],
-            dim=dim,
-            no_batch_dims=len(coordinate.shape[:dim]),
-        )  # [..., N_frame, 3]
-        return torch.stack([x1, x2, x3], dim=dim)
+            inds=reshaped_frame_atom_index
+        )  # [..., N_frame*3, 3[three coordinates]]
+        return batched_frame_atom_coordinates.reshape(*batched_frame_atom_coordinates.shape[:-2], frame_atom_index.shape[-2], frame_atom_index.shape[-1], coordinate.shape[-1])  # [..., N_frame, 3, 3]
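
As a sanity check on the new batched path, here is a small sketch (shapes assumed for illustration) of the reshape, gather, reshape round trip, compared against straightforward per-batch indexing:

    import torch

    coordinate = torch.randn(2, 10, 3)                  # [batch, N_atom, 3]
    frame_atom_index = torch.randint(0, 10, (2, 5, 3))  # [batch, N_frame, 3]

    # flatten the three frame atoms, gather once, then restore the frame axis
    flat = frame_atom_index.reshape(2, -1)              # [batch, N_frame*3]
    idx = flat.unsqueeze(-1).expand(-1, -1, 3)          # broadcast over xyz
    gathered = torch.gather(coordinate, dim=1, index=idx)
    frames = gathered.reshape(2, 5, 3, 3)               # [batch, N_frame, 3[atoms], 3[xyz]]

    # reference: index each batch element directly
    ref = torch.stack([coordinate[i][frame_atom_index[i]] for i in range(2)])
    assert torch.equal(frames, ref)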
3 changes: 2 additions & 1 deletion protenix/model/modules/primitives.py
@@ -364,7 +364,7 @@ def basic_checks(x, dim_x):
     ]
 
     pad_left = (n_keys - n_queries) // 2
-    pad_right = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2)
+    pad_right = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n)
Contributor Author:

This line is unrelated to my other modifications; I just cannot understand the original calculation method for pad_right.

Note: both n_keys and n_queries are even integers; you have already used an assert statement to confirm this.

https://github.com/OccupyMars2025/Protenix/wiki/explain-how-to-calculate-pad_right-for-the-key-tensor
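
For what it's worth, a quick numeric check (test values chosen arbitrarily, and assuming n_trunks = math.ceil(n / n_queries) as in the surrounding code) that the new closed form agrees with the old expression whenever n_queries and n_keys are even:

    import math

    for n_queries, n_keys, n in [(4, 8, 10), (32, 128, 1000), (16, 48, 17)]:
        n_trunks = math.ceil(n / n_queries)
        old = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2)
        new = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n)
        assert old == new, (old, new)
        # the padded key length exactly covers the last trunk's key window
        pad_left = (n_keys - n_queries) // 2
        assert pad_left + n + new == (n_trunks - 1) * n_queries + n_keys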


     k_new = [
         pad_at_dim(k[i], dim=dim_k[i], pad_length=(pad_left, pad_right))

@@ -907,6 +907,7 @@ def gather_pair_embedding_in_dense_trunk(
     idx_k_expanded = idx_k.unsqueeze(1).expand(-1, N_q, -1)
 
     # Use advanced indexing to gather the desired elements
+    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
     y = x[..., idx_q_expanded, idx_k_expanded, :]
 
     return y
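
The linked indexing rule in a nutshell (illustrative shapes, not the real trunk sizes): the two advanced index tensors broadcast together, their common shape replaces the two indexed dimensions, and the trailing basic ":" keeps the channel dimension:

    import torch

    N_token, c = 6, 4
    x = torch.randn(N_token, N_token, c)       # pair embedding [N_token, N_token, c]
    idx_q = torch.randint(0, N_token, (3, 2))  # [n_trunks, N_q]
    idx_k = torch.randint(0, N_token, (3, 5))  # [n_trunks, N_k]

    idx_q_expanded = idx_q.unsqueeze(-1).expand(-1, -1, 5)  # [n_trunks, N_q, N_k]
    idx_k_expanded = idx_k.unsqueeze(1).expand(-1, 2, -1)   # [n_trunks, N_q, N_k]
    y = x[idx_q_expanded, idx_k_expanded, :]                # [n_trunks, N_q, N_k, c]
    assert y.shape == (3, 2, 5, c)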
38 changes: 15 additions & 23 deletions protenix/model/utils.py
@@ -193,39 +193,33 @@ def one_hot(
     return dgram
 
 
-# this is mostly from openfold.utils.torch_utils import batched_gather
 def batched_gather(
-    data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0
+    data: torch.Tensor, inds: torch.Tensor
 ) -> torch.Tensor:
-    """Gather data according to indices specify by inds
+    """Gather data according to indices specified by inds along the dim = len(inds.shape) - 1
 
     Args:
         data (torch.Tensor): the input data
             [..., K, ...]
         inds (torch.Tensor): the indices for gathering data
             [..., N]
-        dim (int, optional): along which dimension to gather data by inds (the dim of "K" "N"). Defaults to 0.
-        no_batch_dims (int, optional): length of dimensions before the "dim" dimension. Defaults to 0.
 
     Returns:
-        torch.Tensor: gathered data
+        torch.Tensor: gathered data, have the same number of dimensions as data,
+            only the size of dimension len(inds.shape) - 1 is changed to N
             [..., N, ...]
     """
+    assert len(inds.shape) <= len(data.shape), "inds must have less or equal dimensions than data"
+    assert inds.shape[:len(inds.shape)-1] == data.shape[:len(inds.shape)-1], "Batch dimensions must match between data and inds"
+
+    if len(inds.shape) == len(data.shape):
+        return torch.gather(data, dim=-1, index=inds)
 
-    # for the naive case
-    if len(inds.shape) == 1 and no_batch_dims == 0 and dim == 0:
-        return data[inds]
-
-    ranges = []
-    for i, s in enumerate(data.shape[:no_batch_dims]):
-        r = torch.arange(s)
-        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
-        ranges.append(r)
-
-    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
-    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
-    ranges.extend(remaining_dims)
-    return data[ranges]
Contributor Author (@OccupyMars2030, Mar 14, 2026):

Since we are gathering data along only one dimension, there is no need to construct an index tensor for each dim like "ranges"; just use torch.gather(), which is more logically concise.
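
To illustrate (shapes assumed), the broadcast-then-gather pattern below is what the new implementation boils down to, checked against explicit per-batch indexing:

    import torch

    data = torch.randn(2, 7, 3)         # [batch, K, 3]
    inds = torch.randint(0, 7, (2, 5))  # [batch, N]

    # broadcast inds over the trailing dims of data, then gather along
    # dim = len(inds.shape) - 1, i.e. the "K"/"N" dimension
    idx = inds.reshape(inds.shape + (1,)).expand(inds.shape + data.shape[-1:])
    out = torch.gather(data, dim=1, index=idx)  # [2, 5, 3]

    ref = torch.stack([data[b, inds[b]] for b in range(2)])
    assert torch.equal(out, ref)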

+    append_shape = (1,) * (len(data.shape) - len(inds.shape))
+    append_shape_broadcasted = data.shape[len(inds.shape) - len(data.shape):]
+    inds_broadcasted = inds.reshape(inds.shape + append_shape)
+    inds_broadcasted = inds_broadcasted.expand(inds.shape + append_shape_broadcasted)
+    return torch.gather(data, dim=len(inds.shape) - 1, index=inds_broadcasted)


@@ -252,9 +246,7 @@ def broadcast_token_to_atom(

     return batched_gather(
         data=x_token,
-        inds=atom_to_token_idx,
-        dim=-2,
-        no_batch_dims=len(x_token.shape[:-2]),
+        inds=atom_to_token_idx
     )
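
With the new signature, the gather dimension is implied by the index tensor itself, so dim=-2 and no_batch_dims no longer need to be passed. A sketch of what this call now does under the hood (shapes assumed):

    import torch

    x_token = torch.randn(2, 4, 8)                    # [batch, N_token, c]
    atom_to_token_idx = torch.randint(0, 4, (2, 11))  # [batch, N_atom]

    # gather along dim = len(atom_to_token_idx.shape) - 1, the token dim
    idx = atom_to_token_idx.unsqueeze(-1).expand(-1, -1, 8)
    x_atom = torch.gather(x_token, dim=1, index=idx)  # [batch, N_atom, c]
    assert x_atom.shape == (2, 11, 8)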

