
Conversation

@phantom5125 commented Oct 12, 2025

PR Category

Operator

Type of Change

New Feature

Description

Add an index_fill op according to a feature request: #892

Issue

Closes #892

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

[Screenshot 2025-10-12 200054] [Screenshot 2025-10-12 202934]

gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @phantom5125, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the index_fill operator, a new feature that allows filling elements of a tensor with a specified value along a given dimension using provided indices. The implementation includes both an out-of-place and an in-place version, leveraging a dynamically generated Triton kernel to ensure high performance. The changes also incorporate dedicated performance benchmarks and extensive accuracy tests to validate the operator's functionality and efficiency.
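
For readers unfamiliar with the operator, here is a minimal illustration of the semantics this PR targets, written against the stock PyTorch API (the FlagGems implementation is expected to match this behavior):

    import torch

    x = torch.zeros(3, 4)
    idx = torch.tensor([0, 2])

    # Out-of-place: rows 0 and 2 of the result are filled with 5.0; x is unchanged.
    y = x.index_fill(0, idx, 5.0)

    # In-place variant: mutates x directly.
    x.index_fill_(0, idx, 5.0)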

Highlights

  • New index_fill Operator: Introduces both the out-of-place torch.index_fill and in-place torch.index_fill_ operations.
  • Triton Kernel Implementation: The core logic for index_fill is implemented using a dynamically generated Triton JIT kernel for optimized performance.
  • Performance Benchmarks: Adds a new performance benchmark (test_index_fill_perf) to evaluate the operator's efficiency across different data types.
  • Comprehensive Accuracy Tests: Includes new accuracy tests (test_accuracy_index_fill and test_accuracy_index_fill_multidim) to ensure correctness for various tensor shapes, data types, and dimensions.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in sharing feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


gemini-code-assist bot left a comment


Code Review

This pull request introduces a new index_fill operation, along with its in-place version index_fill_, including the core implementation, registration, and corresponding accuracy and performance tests. The implementation of the Triton kernel for index_fill has a critical correctness issue that needs to be addressed. Additionally, there are opportunities to improve code readability in the operator implementation and to correct the bandwidth calculation in the performance benchmark. The changes are well-structured, but the kernel logic requires a significant revision.

Comment on lines +80 to +94
    code.writeline(
        "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
    )
    code.writeline(
        'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
    )

    # Calculate the final index in the output tensor
    code.writeline(
        "output_idx = (inp_offset + (src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
    )
    code.writeline("output_mask = output_idx < inp_numel")

    # Fill the value at the calculated positions
    code.writeline("tl.store(out + output_idx, value, mask=output_mask)")

critical

The logic within the generated Triton kernel for index_fill appears to be incorrect and has a critical bug.

  1. Out-of-bounds memory access: The line src_dim_idx = (tl.load(index + dim_idx, ...)) uses dim_idx (the coordinate of an element along dim) as an offset into the index tensor. dim_idx can range from 0 to inp.shape[dim] - 1, while the index tensor can be much smaller. This will cause reads beyond the bounds of the index tensor if inp.shape[dim] > index.numel().
  2. Incorrect logic: The kernel performs a scatter-like operation, not index_fill. For an element at (..., dim_idx, ...), it writes value to (..., index[dim_idx], ...). The correct index_fill behavior is to fill entire slices: for each k in the index tensor, all elements in the slice at position k along dim should be set to value.
  3. Inefficiency: The kernel iterates over all elements of the output tensor, which is inefficient for a sparse operation like index_fill. A better approach would be to launch threads based on the number of elements to be filled.

This is a critical correctness issue. The passing tests might be due to specific test data or environment behavior masking the bug.
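
For reference, a minimal eager-mode sketch of the expected index_fill semantics (purely illustrative; index_fill_reference is a hypothetical helper, not part of the PR):

    import torch

    def index_fill_reference(inp, dim, index, value):
        # Out-of-place reference: every slice selected by `index` along `dim`
        # is filled with `value`; all other elements are copied unchanged.
        out = inp.clone()
        indexer = [slice(None)] * inp.dim()
        indexer[dim] = index.to(torch.long)
        out[tuple(indexer)] = value
        return out

A kernel matching this behavior would launch over the index.numel() * (inp.numel() // inp.shape[dim]) positions that are actually filled, rather than over every element of the output.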

Comment on lines +513 to +516
    elemetns_filled = (
        index.numel() * (inp.numel() // inp.shape[dim]) if index.numel() > 0 else 0
    )
    io_amount = shape_utils.size_in_bytes(inp) + elemetns_filled * inp.element_size()

medium

There's a typo in elemetns_filled. More importantly, the io_amount calculation appears to model an in-place operation. Since torch.index_fill is out-of-place, a more accurate bandwidth calculation should account for reading the input tensor, reading the index tensor, and writing the entire output tensor.

    # For an out-of-place op, we read the input, read the index, and write the full output.
    io_amount = shape_utils.size_in_bytes(inp) * 2 + shape_utils.size_in_bytes(index)
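
For context, a sketch of how such an io_amount typically converts into an effective-bandwidth figure (bandwidth_gbps is an illustrative helper, not the PR's benchmark harness):

    def bandwidth_gbps(io_amount_bytes, latency_ms):
        # Effective bandwidth in GB/s: bytes moved divided by elapsed seconds.
        return io_amount_bytes / (latency_ms * 1e-3) / 1e9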

Comment on lines +215 to +217
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"

medium

This assert statement can be simplified for better readability. Using torch.all() is more idiomatic and clearer than creating a tensor of ones and using .equal().

    assert torch.all((0 <= index) & (index < inp.size(dim))), "0 <= index < self.size(dim)"

Comment on lines +242 to +244
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"

medium

This assert statement can be simplified for better readability. Using torch.all() is more idiomatic and clearer than creating a tensor of ones and using .equal().

    assert torch.all((0 <= index) & (index < inp.size(dim))), "0 <= index < self.size(dim)"


Development

Successfully merging this pull request may close these issues.

Code Contribution: 【Hard】【Operator Development】index_fill
