Skip to content

Commit

Permalink
[FoldLinearDims] Support folding with static non-zero offsets (nod-ai…
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls authored Jan 31, 2025
1 parent 2b89dc2 commit 6417b4a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -332,30 +332,52 @@ LogicalResult foldLinearDims(
newStrides.push_back(strides[strides.size() - 1]);
newSizes.push_back(sizes[sizes.size() - 1]);

for (int i = offsets.size() - 2; i >= 0; i--) {
for (int i = static_cast<int>(offsets.size()) - 2; i >= 0; i--) {
// Conditions for folding a dim.
// 1. Offsets[i] == 0. This is required because we are dropping the offset
// of the i dimension and keep newOffsets[-1]
// 1. Either, offsets[i] == 0 and then we can fold with any `newOffsets[-1]`
// (even dynamic ones), OR offsets[i] multiplied by the respective stride,
// is a multiple of the previous stride.
// 2. newSizes[-1] x newStrides[-1] == strides[i]. With this we can have
// newSizes[-1] = sizes[i] * newSizes[-1] , and then fold away the i
// dimension
// 3. checkValidSize(sizes[i] * newSizes[-1]). This allows hardware
// constraints to be checked.
size_t vecSize = newOffsets.size();
std::optional<int64_t> maybeNewOffset = getConstantIntValue(offsets[i]);
int64_t newStride = staticStrideVals[i];
int64_t newSize = staticSizeVals[i];
std::optional<int64_t> maybePrevOffset =
getConstantIntValue(newOffsets[vecSize - 1]);
int64_t prevStride = getConstantIndexOrAssert(newStrides[vecSize - 1]);
int64_t prevSize = getConstantIndexOrAssert(newSizes[vecSize - 1]);
int64_t dimExtent = prevStride * prevSize;
// Fail if max constraints are provided, but the newly created
// offsets/sizes/strides start exceeding the number of provided max
// constraints as this will result in undefined behaviour.
bool fitsMaxConstraint = checkValidSize(vecSize - 1, newSize * prevSize);
if (fitsMaxConstraint && isConstantIntValue(offsets[i], 0) &&
dimExtent == newStride) {
foldableLinearDimsFound = true;
newSizes[vecSize - 1] = getAsIndexOpFoldResult(ctx, newSize * prevSize);
continue;
if (fitsMaxConstraint && dimExtent == newStride) {
// There are currently two cases supported for folding a dimension:
// 1. If the offset is 0, we can fold the dimension, no matter what the
// value of `newPrevOffset` is (it can be dynamic).
// 2. If the offset, multiplied by the respective stride, is a multiple of
// the previous stride, we can fold the dimension if we update the new
// offset as well. However, in this case we need to add to the new offset and
// this is currently only supported for constant offsets.
if (isConstantIntValue(offsets[i], 0)) {
foldableLinearDimsFound = true;
newSizes[vecSize - 1] = getAsIndexOpFoldResult(ctx, newSize * prevSize);
continue;
} else if (maybeNewOffset.has_value() && maybePrevOffset.has_value()) {
// NOTE: It's guaranteed that
// `(maybeNewOffset.value() * newStride) % prevStride == 0`
// as `newStride == prevStride * prevSize`
foldableLinearDimsFound = true;
newSizes[vecSize - 1] = getAsIndexOpFoldResult(ctx, newSize * prevSize);
int64_t newPrevOffset = maybePrevOffset.value() +
maybeNewOffset.value() * newStride / prevStride;
newOffsets[vecSize - 1] = getAsIndexOpFoldResult(ctx, newPrevOffset);
continue;
}
}
newOffsets.push_back(offsets[i]);
newStrides.push_back(strides[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,6 @@ TEST_F(FoldTest, NoLinearDimsFold) {
checkFoldLinearDims({0}, {8}, {1}, {}, {0}, {8}, {1}, false);
checkFoldLinearDims({0, 0}, {16, 8}, {16, 1}, {}, {0, 0}, {16, 8}, {16, 1},
false);
checkFoldLinearDims({8, 0}, {16, 8}, {8, 1}, {}, {8, 0}, {16, 8}, {8, 1},
false);
}

TEST_F(FoldTest, FoldLinearDims) {
Expand All @@ -546,8 +544,8 @@ TEST_F(FoldTest, FoldLinearDims) {
true);
checkFoldLinearDims({0, 0, 0, 0}, {4, 8, 16, 8}, {1024, 128, 8, 1}, {}, {0},
{4096}, {1}, true);
checkFoldLinearDims({0, 0, 8, 0}, {4, 8, 16, 8}, {1024, 128, 8, 1}, {},
{8, 0}, {512, 8}, {8, 1}, true);
checkFoldLinearDims({5, 3, 8, 1}, {4, 8, 16, 8}, {1024, 128, 8, 1}, {},
{5569}, {4096}, {1}, true);
}

TEST_F(FoldTest, FoldLinearDimsWithMax) {
Expand All @@ -561,9 +559,9 @@ TEST_F(FoldTest, FoldLinearDimsWithMax) {
checkFoldLinearDims({0, 0, 0, 0}, {4, 8, 16, 8}, {1024, 128, 8, 1},
{1024, 1024, 1024, 1024}, {0, 0}, {4, 1024}, {1024, 1},
true);
checkFoldLinearDims({0, 0, 8, 0}, {4, 8, 16, 8}, {1024, 128, 8, 1},
{511, 511, 511, 511}, {0, 8, 0}, {4, 128, 8},
{1024, 8, 1}, true);
checkFoldLinearDims({4, 0, 8, 0}, {4, 8, 16, 8}, {1024, 128, 8, 1},
{511, 511, 511, 511}, {32, 64}, {32, 128}, {128, 1},
true);
}

TEST_F(FoldTest, NoUnitDimsFold) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func.func @circular_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectf
// -----

// CHECK-LABEL: func.func @circular_dma_cpy_nd_non_zero_offset
// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1])
// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[401] [128] [1], %{{.+}}[89] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[401] [128] [1], %{{.+}}[89] [64] [1])
func.func @circular_dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>) {
%0 = amdaie.circular_dma_cpy_nd(%arg0[2, 1, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
"iree.keep"(%0) : (index) -> ()
Expand All @@ -87,6 +87,17 @@ func.func @circular_dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<

// -----

// CHECK-LABEL: func.func @circular_dma_cpy_nd_non_zero_dynamic_offset
// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// All offsets are the dynamic value %arg2, so none is known to be zero and
// they cannot be combined into a single constant offset; the checks above
// therefore expect the original 4-D offsets/sizes/strides to be preserved
// (no dimensions folded).
func.func @circular_dma_cpy_nd_non_zero_dynamic_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>, %arg2: index) {
%0 = amdaie.circular_dma_cpy_nd(%arg0[%arg2, %arg2, %arg2, %arg2] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[%arg2, %arg2, %arg2, %arg2] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
"iree.keep"(%0) : (index) -> ()
return
}

// -----

// CHECK-LABEL: func.func @circular_dma_cpy_nd_partial_non_zero_offset
// CHECK: amdaie.circular_dma_cpy_nd(%{{.+}}[1] [128] [1], %{{.+}}[1] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.circular_dma_cpy_nd(%{{.+}}[1] [128] [1], %{{.+}}[1] [64] [1])
Expand Down Expand Up @@ -174,8 +185,8 @@ func.func @dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo<memre
// -----

// CHECK-LABEL: func.func @dma_cpy_nd_non_zero_offset
// CHECK: amdaie.dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.dma_cpy_nd(%{{.+}}[25, 1] [8, 16] [16, 1], %{{.+}}[5, 1, 1] [4, 2, 8] [16, 8, 1])
// CHECK: amdaie.dma_cpy_nd(%{{.+}}[401] [128] [1], %{{.+}}[89] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.dma_cpy_nd(%{{.+}}[401] [128] [1], %{{.+}}[89] [64] [1])
func.func @dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>) {
%0 = amdaie.dma_cpy_nd(%arg0[1, 2, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
"iree.keep"(%0) : (index) -> ()
Expand All @@ -184,6 +195,17 @@ func.func @dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x

// -----

// CHECK-LABEL: func.func @dma_cpy_nd_non_zero_dynamic_offset
// CHECK: amdaie.dma_cpy_nd(%{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.dma_cpy_nd(%{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], %{{.+}}[%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// All offsets are the dynamic value %arg2, so none is known to be zero and
// they cannot be combined into a single constant offset; the checks above
// therefore expect the original 4-D offsets/sizes/strides to be preserved
// (no dimensions folded).
func.func @dma_cpy_nd_non_zero_dynamic_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>, %arg2: index) {
%0 = amdaie.dma_cpy_nd(%arg0[%arg2, %arg2, %arg2, %arg2] [1, 1, 8, 16] [128, 128, 16, 1], %arg1[%arg2, %arg2, %arg2, %arg2] [1, 4, 2, 8] [64, 16, 8, 1]) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
"iree.keep"(%0) : (index) -> ()
return
}

// -----

// CHECK-LABEL: func.func @dma_cpy_nd_partial_non_zero_offset
// CHECK: amdaie.dma_cpy_nd(%{{.+}}[1] [128] [1], %{{.+}}[1] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.dma_cpy_nd(%{{.+}}[1] [128] [1], %{{.+}}[1] [64] [1])
Expand Down Expand Up @@ -273,8 +295,8 @@ func.func @npu_dma_cpy_nd_unit_between_linear(%arg0: !amdaie.logicalobjectfifo<m
// -----

// CHECK-LABEL: func.func @npu_dma_cpy_nd_non_zero_offset
// CHECK: amdaie.npu.dma_cpy_nd %{{.+}}([25, 1] [8, 16] [16, 1], [5, 1, 1] [4, 2, 8] [16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.npu.dma_cpy_nd %{{.+}}([25, 1] [8, 16] [16, 1], [5, 1, 1] [4, 2, 8] [16, 8, 1])
// CHECK: amdaie.npu.dma_cpy_nd %{{.+}}([401] [128] [1], [89] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.npu.dma_cpy_nd %{{.+}}([401] [128] [1], [89] [64] [1])
func.func @npu_dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>) {
%0 = amdaie.circular_dma_cpy_nd(%arg0[] [] [], %arg1[] [] []) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
amdaie.npu.dma_cpy_nd %0([1, 2, 1, 1] [1, 1, 8, 16] [128, 128, 16, 1], [1, 1, 1, 1] [1, 4, 2, 8] [64, 16, 8, 1])
Expand All @@ -283,6 +305,17 @@ func.func @npu_dma_cpy_nd_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memre

// -----

// CHECK-LABEL: func.func @npu_dma_cpy_nd_dynamic_non_zero_offset
// CHECK: amdaie.npu.dma_cpy_nd %{{.+}}([%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// FOLD-SINGLE-DIMS: amdaie.npu.dma_cpy_nd %{{.+}}([%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 1, 8, 16] [128, 128, 16, 1], [%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] [1, 4, 2, 8] [64, 16, 8, 1])
// All offsets are the dynamic value %arg2, so none is known to be zero and
// they cannot be combined into a single constant offset; the checks above
// therefore expect the original 4-D offsets/sizes/strides to be preserved
// (no dimensions folded).
func.func @npu_dma_cpy_nd_dynamic_non_zero_offset(%arg0: !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, %arg1: !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>, %arg2: index) {
%0 = amdaie.circular_dma_cpy_nd(%arg0[] [] [], %arg1[] [] []) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32, 1>>)
amdaie.npu.dma_cpy_nd %0([%arg2, %arg2, %arg2, %arg2] [1, 1, 8, 16] [128, 128, 16, 1], [%arg2, %arg2, %arg2, %arg2] [1, 4, 2, 8] [64, 16, 8, 1])
return
}

// -----

// CHECK-LABEL: func.func @npu_dma_cpy_nd_partial_non_zero_offset
// CHECK: amdaie.npu.dma_cpy_nd %{{.+}}([1] [128] [1], [1] [64] [1])
// FOLD-SINGLE-DIMS: amdaie.npu.dma_cpy_nd %{{.+}}([1] [128] [1], [1] [64] [1])
Expand Down

0 comments on commit 6417b4a

Please sign in to comment.