-
Notifications
You must be signed in to change notification settings - Fork 273
use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding and gathering the frame atom coordinates #269
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
9732d80
e torch.gather() to simplify the process of broadcasting a token embe…
OccupyMars2030 74185db
Update utils.py
OccupyMars2030 7646cb0
modify how to calculate pad_right
OccupyMars2030 326b7f9
add comments to explain the advanced tensor indexing
OccupyMars2030 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -364,7 +364,7 @@ def basic_checks(x, dim_x): | |
| ] | ||
|
|
||
| pad_left = (n_keys - n_queries) // 2 | ||
| pad_right = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2) | ||
| pad_right = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line has nothing to do with my other modification. I just cannot understand the original calculation method for
|
||
|
|
||
| k_new = [ | ||
| pad_at_dim(k[i], dim=dim_k[i], pad_length=(pad_left, pad_right)) | ||
|
|
@@ -907,6 +907,7 @@ def gather_pair_embedding_in_dense_trunk( | |
| idx_k_expanded = idx_k.unsqueeze(1).expand(-1, N_q, -1) | ||
|
|
||
| # Use advanced indexing to gather the desired elements | ||
| # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing | ||
| y = x[..., idx_q_expanded, idx_k_expanded, :] | ||
|
|
||
| return y | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -193,39 +193,33 @@ def one_hot( | |
| return dgram | ||
|
|
||
|
|
||
| # this is mostly from openfold.utils.torch_utils import batched_gather | ||
| def batched_gather( | ||
| data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0 | ||
| data: torch.Tensor, inds: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| """Gather data according to indices specify by inds | ||
| """Gather data according to indices specified by inds along the dim = len(inds.shape) - 1 | ||
|
|
||
| Args: | ||
| data (torch.Tensor): the input data | ||
| [..., K, ...] | ||
| inds (torch.Tensor): the indices for gathering data | ||
| [..., N] | ||
| dim (int, optional): along which dimension to gather data by inds (the dim of "K" "N"). Defaults to 0. | ||
| no_batch_dims (int, optional): length of dimensions before the "dim" dimension. Defaults to 0. | ||
|
|
||
| Returns: | ||
| torch.Tensor: gathered data | ||
| torch.Tensor: gathered data, have the same number of dimensions as data, | ||
| only the size of dimension len(inds.shape) - 1 is changed to N | ||
| [..., N, ...] | ||
| """ | ||
| assert len(inds.shape) <= len(data.shape), "inds must have less or equal dimensions than data" | ||
| assert inds.shape[:len(inds.shape)-1] == data.shape[:len(inds.shape)-1], "Batch dimensions must match between data and inds" | ||
|
|
||
| if len(inds.shape) == len(data.shape): | ||
| return torch.gather(data, dim=-1, index=inds) | ||
|
|
||
| # for the naive case | ||
| if len(inds.shape) == 1 and no_batch_dims == 0 and dim == 0: | ||
| return data[inds] | ||
|
|
||
| ranges = [] | ||
| for i, s in enumerate(data.shape[:no_batch_dims]): | ||
| r = torch.arange(s) | ||
| r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) | ||
| ranges.append(r) | ||
|
|
||
| remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] | ||
| remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds | ||
| ranges.extend(remaining_dims) | ||
| return data[ranges] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are gathering data along only "one" dimension, there is no need to construct an index tensor for each dim like "ranges". Just use torch.gather() because it is more logically concise |
||
| append_shape = (1,) * (len(data.shape) - len(inds.shape)) | ||
| append_shape_broadcasted = data.shape[len(inds.shape) - len(data.shape):] | ||
| inds_broadcasted = inds.reshape(inds.shape + append_shape) | ||
| inds_broadcasted = inds_broadcasted.expand(inds.shape + append_shape_broadcasted) | ||
| return torch.gather(data, dim=len(inds.shape) - 1, index=inds_broadcasted) | ||
|
|
||
|
|
||
| def broadcast_token_to_atom( | ||
|
|
@@ -252,9 +246,7 @@ def broadcast_token_to_atom( | |
|
|
||
| return batched_gather( | ||
| data=x_token, | ||
| inds=atom_to_token_idx, | ||
| dim=-2, | ||
| no_batch_dims=len(x_token.shape[:-2]), | ||
| inds=atom_to_token_idx | ||
| ) | ||
|
|
||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will trigger the bug below. The error occurs at line 894 in loss.py within the call to gather_frame_atom_by_indices using dim=-1.
return coordinate[..., frame_atom_index, :]
~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: too many indices for tensor of dimension 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your review. I hope I will have time to study the details (maybe before 2026/4/20)