Feature/Add index_fill op #998
base: master
Conversation
Summary of Changes
Hello @phantom5125, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces the `index_fill` operator and its in-place variant.
Code Review
This pull request introduces a new `index_fill` operation, along with its in-place version `index_fill_`, including the core implementation, registration, and corresponding accuracy and performance tests. The implementation of the Triton kernel for `index_fill` has a critical correctness issue that needs to be addressed. Additionally, there are opportunities to improve code readability in the operator implementation and to correct the bandwidth calculation in the performance benchmark. The changes are well-structured, but the kernel logic requires a significant revision.
```python
code.writeline(
    "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
)
code.writeline(
    'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
)

# Calculate the final index in the output tensor
code.writeline(
    "output_idx = (inp_offset + (src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
)
code.writeline("output_mask = output_idx < inp_numel")

# Fill the value at the calculated positions
code.writeline("tl.store(out + output_idx, value, mask=output_mask)")
```
The logic within the generated Triton kernel for `index_fill` appears to be incorrect and has a critical bug.

- Out-of-bounds memory access: The line `src_dim_idx = (tl.load(index + dim_idx, ...))` uses `dim_idx` (the coordinate of an element along `dim`) as an offset into the `index` tensor. `dim_idx` can range from `0` to `inp.shape[dim] - 1`, while the `index` tensor can be much smaller. This will cause reads beyond the bounds of the `index` tensor if `inp.shape[dim] > index.numel()`.
- Incorrect logic: The kernel performs a scatter-like operation, not `index_fill`. For an element at `(..., dim_idx, ...)`, it writes `value` to `(..., index[dim_idx], ...)`. The correct `index_fill` behavior is to fill entire slices: for each `k` in the `index` tensor, all elements in the slice `dim=k` should be set to `value`.
- Inefficiency: The kernel iterates over all elements of the output tensor, which is wasteful for a sparse operation like `index_fill`. A better approach would be to launch threads based on the number of elements to be filled (see the sketch after this comment).

This is a critical correctness issue. The passing tests might be due to specific test data or environment behavior masking the bug.
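To make the suggestion concrete, here is a minimal sketch of the slice-filling approach, written as a standalone Triton kernel rather than through the repository's code-generation machinery. All names (`index_fill_kernel`, the `index_fill` wrapper) are hypothetical, and it assumes a contiguous input:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def index_fill_kernel(
    out_ptr, index_ptr, value,
    index_numel, inner, dim_size, total_fills,
    BLOCK: tl.constexpr,
):
    # Each program fills BLOCK of the `outer * index_numel * inner` target elements.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK + tl.arange(0, BLOCK).to(tl.int64)
    mask = offs < total_fills
    # Decompose the linear fill id into (outer_idx, idx_pos, inner_idx).
    inner_idx = offs % inner
    tmp = offs // inner
    idx_pos = tmp % index_numel
    outer_idx = tmp // index_numel
    # idx_pos < index_numel by construction, so this load stays in bounds.
    k = tl.load(index_ptr + idx_pos, mask=mask, other=0).to(tl.int64)
    # Linear offset in a contiguous (outer, dim_size, inner) layout.
    dst = outer_idx * dim_size * inner + k * inner + inner_idx
    tl.store(out_ptr + dst, value, mask=mask)


def index_fill(inp, dim, index, value):
    # Out-of-place: copy the input, then overwrite only the selected slices.
    out = inp.contiguous().clone()
    dim = dim % inp.ndim
    dim_size = inp.shape[dim]
    inner = out.stride(dim)  # product of trailing dims for a contiguous tensor
    total = (out.numel() // dim_size) * index.numel() if dim_size else 0
    if total == 0:
        return out
    grid = (triton.cdiv(total, 1024),)
    index_fill_kernel[grid](
        out, index, value, index.numel(), inner, dim_size, total, BLOCK=1024
    )
    return out
```

Launching over `index.numel() * (inp.numel() // inp.shape[dim])` elements touches only the memory that is actually written and never reads `index` out of bounds.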
```python
elemetns_filled = (
    index.numel() * (inp.numel() // inp.shape[dim]) if index.numel() > 0 else 0
)
io_amount = shape_utils.size_in_bytes(inp) + elemetns_filled * inp.element_size()
```
There's a typo in `elemetns_filled`. More importantly, the `io_amount` calculation appears to model an in-place operation. Since `torch.index_fill` is out-of-place, a more accurate bandwidth calculation should account for reading the input tensor, reading the index tensor, and writing the entire output tensor.
```python
# For an out-of-place op, we read the input, read the index, and write the full output.
io_amount = shape_utils.size_in_bytes(inp) * 2 + shape_utils.size_in_bytes(index)
```
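For a concrete sense of the difference between the two models (sizes here are hypothetical, purely for illustration):

```python
# float32 input of shape (1024, 1024), 128 int64 indices, dim = 0.
inp_bytes = 1024 * 1024 * 4                     # 4,194,304 B read
index_bytes = 128 * 8                           # 1,024 B read
filled_bytes = 128 * (1024 * 1024 // 1024) * 4  # 524,288 B written

in_place_model = inp_bytes + filled_bytes            # ~4.7 MB
out_of_place_model = inp_bytes * 2 + index_bytes     # ~8.4 MB
```

The in-place model under-reports traffic by nearly a factor of two in this case.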
```python
assert ((0 <= index) * (index < inp.size(dim))).equal(
    torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
), "0 <= index < self.size(dim)"
```
PR Category
Operator
Type of Change
New Feature
Description
Add an `index_fill` op according to a feature request: #892
Issue
Closes #892
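For reference, the expected semantics match `torch.Tensor.index_fill`; the example below is illustrative:

```python
import torch

x = torch.zeros(3, 4)
idx = torch.tensor([0, 2])

# Fill columns 0 and 2 with 5.0; index_fill is out-of-place.
y = x.index_fill(1, idx, 5.0)
# tensor([[5., 0., 5., 0.],
#         [5., 0., 5., 0.],
#         [5., 0., 5., 0.]])
```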
Progress
Performance