Row-gather problem statements and issues #334

@nhat-nguyen

I'm sharing some high-level thoughts on the current experimental implementation of row-gather.

1) Row-gather sub-problems

At a high level, the current row-gather implementation reuses the PtrAnalysis
visitor functions to detect when a dimension is not structured. If a visitor
function fails, we assume that the original Triton SSA value that caused the
failure is the tensor representing the offsets for that unstructured dimension
(we still need to verify that this is always true).
Some problems with this approach:

  • It relies on the existing PtrAnalysis visitor functions, which handle
    certain ops differently (e.g. remsi).
  • It makes the PtrAnalysis code more complicated, since it now handles both
    the structured and the unstructured cases.

I think another way to break this problem down is to:

a) Have a separate "dimension" analysis that, given a tensor-of-pointers
expression, knows how to generate the tensor of offsets for each dimension.

b) Refactor PtrAnalysis to allow "partially" structured tensors. Right now,
if any dimension is unstructured, the analysis fails.

Combining a) and b) separates the two problems and should make the analysis
more robust. I have not yet thought about how a) and b) will interact with
loop support.
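To make a) and b) a bit more concrete, here is a minimal Python sketch over a hypothetical toy IR. The classes `Range`, `Gathered`, `Scalar`, `ExpandTo`, `Add`, `Mul` and the functions `split`/`classify` are made up for illustration; they are not triton-shared or MLIR APIs. `split` plays the role of the "dimension" analysis in a): it breaks an additive offset expression into 1D terms per dimension. `classify` plays the role of b): it labels each dimension as structured or unstructured instead of failing the whole tensor. The sketch assumes every 1D value enters the 2D expression through an expand_dims-style broadcast.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

# Hypothetical toy IR for the offset part of a 2D tensor-of-pointers expression.
# None of these classes correspond to real triton-shared / MLIR types.
@dataclass(frozen=True)
class Range:        # tl.arange(0, n): structured 1D values
    n: int

@dataclass(frozen=True)
class Gathered:     # 1D offsets loaded from memory: unstructured
    name: str

@dataclass(frozen=True)
class Scalar:       # runtime scalar, e.g. a stride argument
    name: str

@dataclass(frozen=True)
class ExpandTo:     # x[:, None] -> dim 0, x[None, :] -> dim 1
    x: "Expr"
    dim: int

@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"

@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"

Expr = Union[Range, Gathered, Scalar, ExpandTo, Add, Mul]

def dims(e: Expr) -> frozenset:
    """Dimensions along which e varies (scalars vary along none)."""
    if isinstance(e, ExpandTo):
        return frozenset({e.dim})
    if isinstance(e, (Add, Mul)):
        return dims(e.a) | dims(e.b)
    return frozenset()

def strip(e: Expr) -> Expr:
    """Drop the expand_dims wrappers, leaving the underlying 1D expression."""
    if isinstance(e, ExpandTo):
        return e.x
    if isinstance(e, Add):
        return Add(strip(e.a), strip(e.b))
    if isinstance(e, Mul):
        return Mul(strip(e.a), strip(e.b))
    return e

def split(e: Expr) -> Dict[Optional[int], List[Expr]]:
    """a) Per-dimension offsets: one list of 1D terms per dim (None = scalar base)."""
    if isinstance(e, Add):
        out = split(e.a)
        for d, terms in split(e.b).items():
            out.setdefault(d, []).extend(terms)
        return out
    ds = dims(e)
    if len(ds) > 1:
        # This toy does not distribute Mul over Add, so e.g. (r + c) * s fails here.
        raise ValueError(f"term {e!r} mixes dimensions")
    return {next(iter(ds), None): [strip(e)]}

def has_gather(e: Expr) -> bool:
    if isinstance(e, Gathered):
        return True
    if isinstance(e, ExpandTo):
        return has_gather(e.x)
    if isinstance(e, (Add, Mul)):
        return has_gather(e.a) or has_gather(e.b)
    return False

def classify(e: Expr) -> Dict[int, str]:
    """b) Partial structure: label each dim instead of failing the whole tensor."""
    return {d: "unstructured" if any(has_gather(t) for t in terms) else "structured"
            for d, terms in split(e).items() if d is not None}

# offsets = idx[:, None] * stride_m + tl.arange(0, 32)[None, :]
offs = Add(Mul(ExpandTo(Gathered("idx"), 0), Scalar("stride_m")),
           ExpandTo(Range(32), 1))
print(split(offs))     # {0: [Mul(a=Gathered(name='idx'), b=Scalar(name='stride_m'))], 1: [Range(n=32)]}
print(classify(offs))  # {0: 'unstructured', 1: 'structured'}
```

The point of the split is that the structured-dimension logic (here, just `Range`) never has to know about gathered values, and vice versa.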

2) Tracking stride

Tracking the stride in the gather (unstructured) dimension is complicated;
most of the code is in PtrState::addState. The main problem is: given a
tensor-of-pointers expression, what is the stride of each dimension? The
current code examines many special cases, which may be error-prone. How can
we solve this in a more general way?
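One possibly more general direction (a sketch, not what PtrState::addState does today) is to normalize the offset expression into a linear form, i.e. a sum of coefficient × index-tensor terms plus a scalar base, and then read each dimension's stride off as the coefficient of the index tensor along that dimension. The Python below uses a hypothetical toy IR (`Idx`, `Sym`, `Const`, `Add`, `Mul`, `linearize`, `strides` are made-up names). Coefficients are integer polynomials over runtime scalar symbols, so symbolic strides such as `stride_m` are kept. Distributing multiplication over addition replaces per-case reasoning, and anything that is not linear in the index tensors is rejected.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

Poly = Dict[Tuple[str, ...], int]  # monomial (sorted symbol names) -> integer coefficient

@dataclass(frozen=True)
class Idx:      # a 1D index tensor broadcast along `dim` (rows[:, None] -> dim 0)
    name: str
    dim: int

@dataclass(frozen=True)
class Sym:      # runtime scalar such as a stride argument
    name: str

@dataclass(frozen=True)
class Const:
    value: int

@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"

@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"

Expr = Union[Idx, Sym, Const, Add, Mul]
Linear = Dict[Optional[Idx], Poly]  # index tensor (None = scalar base) -> coefficient

def poly_add(p: Poly, q: Poly) -> Poly:
    out = defaultdict(int, p)
    for m, c in q.items():
        out[m] += c
    return {m: c for m, c in out.items() if c != 0}

def poly_mul(p: Poly, q: Poly) -> Poly:
    out: Poly = {}
    for m1, c1 in p.items():
        for m2, c2 in q.items():
            m = tuple(sorted(m1 + m2))
            out[m] = out.get(m, 0) + c1 * c2
    return {m: c for m, c in out.items() if c != 0}

def linearize(e: Expr) -> Linear:
    """Rewrite e as sum(coeff * index) + base; fail if e is not linear in the indices."""
    if isinstance(e, Idx):
        return {e: {(): 1}}
    if isinstance(e, Sym):
        return {None: {(e.name,): 1}}
    if isinstance(e, Const):
        return {None: {(): e.value}} if e.value else {}
    if isinstance(e, Add):
        out = dict(linearize(e.a))
        for k, p in linearize(e.b).items():
            out[k] = poly_add(out.get(k, {}), p)
        return {k: p for k, p in out.items() if p}
    if isinstance(e, Mul):
        la, lb = linearize(e.a), linearize(e.b)
        if any(k is not None for k in la) and any(k is not None for k in lb):
            raise ValueError("product of two index tensors is not linear")
        if any(k is not None for k in lb):  # make `la` the side carrying indices
            la, lb = lb, la
        scale = lb.get(None, {})            # lb is a pure scalar here
        out = {k: poly_mul(p, scale) for k, p in la.items()}
        return {k: p for k, p in out.items() if p}
    raise TypeError(f"unsupported node {e!r}")

def strides(e: Expr) -> Dict[int, Dict[str, Poly]]:
    """Per dimension, the stride (coefficient) of each index tensor along it."""
    out: Dict[int, Dict[str, Poly]] = {}
    for k, p in linearize(e).items():
        if k is not None:
            out.setdefault(k.dim, {})[k.name] = p
    return out

# (idx[:, None] + off_m) * stride_m + cols[None, :]
e = Add(Mul(Add(Idx("idx", 0), Sym("off_m")), Sym("stride_m")), Idx("cols", 1))
print(strides(e))          # {0: {'idx': {('stride_m',): 1}}, 1: {'cols': {(): 1}}}
print(linearize(e)[None])  # {('off_m', 'stride_m'): 1}
```

In this form the gather dimension is not special: its stride is simply the coefficient attached to the gathered index tensor.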

3) Loop

To make row-gather work in loops, we have to split the loop iter-arg holding
the tensor of pointers into a separate offset tensor for each dimension.
Each iteration then increments the appropriate tensor. The current row-gather
approach may work with some changes (this needs further investigation), but
since we already plan to refactor this code, it may be easier to implement a
new approach from scratch that decouples row-gather from the original
PtrAnalysis implementation.
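The loop transformation can be sanity-checked on concrete values. The Python sketch below compares the current shape (the iter-arg is the full 2D offset tensor, advanced by `BLOCK_K` each iteration) with the split shape (one offset vector per dimension, only the structured one advanced). The function names, the `idx * stride_m + col` layout, and the choice to attribute the scalar increment to the structured dimension are assumptions for illustration; a loop that advances the gather dimension would need its own handling.

```python
from typing import List

Offsets2D = List[List[int]]

def combined_loop(idx: List[int], stride_m: int, block_k: int, n_iters: int) -> List[Offsets2D]:
    """Current shape: the iter-arg is the full 2D offset tensor (ptrs += BLOCK_K)."""
    offs = [[i * stride_m + c for c in range(block_k)] for i in idx]
    seen = []
    for _ in range(n_iters):
        seen.append(offs)                                    # value used by the load
        offs = [[o + block_k for o in row] for row in offs]  # scalar increment
    return seen

def split_loop(idx: List[int], stride_m: int, block_k: int, n_iters: int) -> List[Offsets2D]:
    """Split shape: one iter-arg per dimension; only the structured one advances."""
    row_offs = [i * stride_m for i in idx]  # gather dim: loop-invariant here
    col_offs = list(range(block_k))         # structured dim
    seen = []
    for _ in range(n_iters):
        # Re-materialize the 2D offsets only at the use site.
        seen.append([[r + c for c in col_offs] for r in row_offs])
        col_offs = [c + block_k for c in col_offs]
    return seen

idx = [5, 0, 3]
assert combined_loop(idx, 100, 4, 3) == split_loop(idx, 100, 4, 3)
print(split_loop(idx, 100, 4, 3)[1][0])  # [504, 505, 506, 507]
```

Because a scalar increment commutes with the broadcast add, it can be folded into either per-dimension tensor; putting it on the structured side keeps the gathered offsets loop-invariant.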

4) Mask

The current mask analysis processes the full mask, but for row-gather we
only want to compute the mask for the structured dimension. How do we split
the combined mask value into its structured and unstructured dimensions?
For example, we want to split a 2D boolean tensor into two 1D tensors: one
for the gather (unstructured) dimension and one for the structured
dimension.
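If the mask is a conjunction of comparisons that each involve a single broadcast 1D index tensor, the split can be done syntactically by partitioning the `&` terms by dimension. Below is a minimal Python sketch of that idea over a hypothetical toy mask IR (`Lt`, `And`, `split_mask`, `eval_1d`, `eval_2d` are made-up names; only `<` against a scalar bound is modeled). Masks that are not of this form, e.g. a comparison on a 2D value or an `|`, are rejected rather than split. The example checks that the outer AND of the two 1D masks reproduces the original 2D mask.

```python
from dataclasses import dataclass
from typing import Dict, List, Union

# Hypothetical toy mask IR: only `<` against a scalar bound, combined with `&`.
@dataclass(frozen=True)
class Lt:
    name: str   # 1D index tensor being compared
    dim: int    # 0 for x[:, None], 1 for x[None, :]
    bound: int

@dataclass(frozen=True)
class And:
    a: "Mask"
    b: "Mask"

Mask = Union[Lt, And]

def split_mask(m: Mask) -> Dict[int, List[Lt]]:
    """Partition a conjunction of per-dim comparisons into one conjunction per dim."""
    if isinstance(m, And):
        out = split_mask(m.a)
        for d, preds in split_mask(m.b).items():
            out.setdefault(d, []).extend(preds)
        return out
    if isinstance(m, Lt):
        return {m.dim: [m]}
    # Anything else (e.g. a comparison on a 2D value, or `|`) is not separable here.
    raise ValueError(f"cannot split mask node {m!r}")

def eval_1d(preds: List[Lt], env: Dict[str, List[int]], n: int) -> List[bool]:
    """Evaluate one dimension's conjunction on concrete index values."""
    return [all(env[p.name][k] < p.bound for p in preds) for k in range(n)]

def eval_2d(m: Mask, env: Dict[str, List[int]], rows: int, cols: int) -> List[List[bool]]:
    """Reference: evaluate the original combined 2D mask element by element."""
    def at(node: Mask, i: int, j: int) -> bool:
        if isinstance(node, And):
            return at(node.a, i, j) and at(node.b, i, j)
        return env[node.name][i if node.dim == 0 else j] < node.bound
    return [[at(m, i, j) for j in range(cols)] for i in range(rows)]

# mask = (rows[:, None] < 3) & (cols[None, :] < 2)
mask = And(Lt("rows", 0, 3), Lt("cols", 1, 2))
env = {"rows": [0, 1, 2, 3], "cols": [0, 1, 2]}
parts = split_mask(mask)
row_mask = eval_1d(parts[0], env, 4)  # gather dim
col_mask = eval_1d(parts[1], env, 3)  # structured dim
print(row_mask, col_mask)  # [True, True, True, False] [True, True, False]
# The outer AND of the two 1D masks reproduces the original 2D mask.
assert eval_2d(mask, env, 4, 3) == [[r and c for c in col_mask] for r in row_mask]
```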
