Skip to content

Commit

Permalink
[Dma] Don't check Dma unit dims during subsumption and fold them duri…
Browse files Browse the repository at this point in the history
…ng controlCodeLowering (nod-ai#1070)

Before this PR, we needed to check the number of sizes/strides before dma
loop subsumption to make sure the number of dims does not exceed the
maximum after subsumption. However, this blocks some opportunities for
loop subsumption when there are unit dimensions which are not
canonicalized because the offsets of these dimensions are non-zero. For
example the following loop cannot be subsumed because there are already
4 dimensions on L3 source side.

```
scf.for %arg2 = %c0 to %c6 step %c1 {
  %1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%arg2)
  amdaie.npu.dma_cpy_nd %0([] [] [], [1, %1, 0, 0] [1, 1, 32, 32] [8192, 1024, 32, 1])
}
```

This PR relaxes the constraint by only checking the non-unit dimensions,
so the above loop can be subsumed into dma as

```
amdaie.npu.dma_cpy_nd %0([] [] [], [0, 1, 1, 0, 0] [6, 1, 1, 32, 32] [1024, 8192, 1024, 32, 1])
```

And this dma can be further canonicalized.
  • Loading branch information
yzhang93 authored Jan 31, 2025
1 parent 6417b4a commit e2b0e9d
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ struct HalfDmaCpyNdToNpuConverter final
AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjFifo,
AMDAIE::BdIdOp bdIdOp, AMDAIE::ChannelOp channelOp, int64_t bufferLength,
int64_t bufferOffset, int32_t enablePacket, int32_t packetId,
int32_t packetType, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) const {
int32_t packetType, SmallVector<OpFoldResult> sizes,
SmallVector<OpFoldResult> strides) const {
uint8_t numIntraAddrDim = deviceModel.getDmaProp<uint8_t>(
tileType, AMDAIE::AMDAIEDmaProp::NumAddrDim);
uint8_t numAddrDim =
Expand Down Expand Up @@ -66,6 +66,21 @@ struct HalfDmaCpyNdToNpuConverter final
int32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
int32_t outOfOrderId{0};

SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides);

uint8_t memSpace = logicalObjFifo.getMemorySpaceAsUInt();
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(offsets.size());
SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
(void)foldLinearDims(
rewriter.getContext(), offsets, sizes, strides, linearOffsets,
linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
return idxFromEnd < maxSizes.size() &&
size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
});

SmallVector<int32_t, 4> staticSizes;
SmallVector<int32_t, 4> staticStrides;
// Padding is unused for now.
Expand All @@ -75,14 +90,15 @@ struct HalfDmaCpyNdToNpuConverter final
int32_t iterationSize{0};
int32_t iterationStride{0};
int32_t repeatCount{1};
for (auto iter : llvm::enumerate(llvm::zip(sizes, strides))) {
for (auto iter : llvm::enumerate(llvm::zip(linearSizes, linearStrides))) {
int64_t size = getConstantIndexOrAssert(std::get<0>(iter.value()));
int64_t stride = getConstantIndexOrAssert(std::get<1>(iter.value()));

/// Map the outer dimension to the iteration dimension if intra dimensions
/// are all used already or if the first stride == 0 as only the iteration
/// dimension supports stride == 0.
if (iter.index() == 0 && (sizes.size() == numAddrDim || stride == 0)) {
if (iter.index() == 0 &&
(linearSizes.size() == numAddrDim || stride == 0)) {
if (stride == 0) {
repeatCount = size;
} else {
Expand All @@ -96,7 +112,7 @@ struct HalfDmaCpyNdToNpuConverter final
staticStrides.push_back(
std::max(stride * elemWidthInBits / minStrideBitWidth, (int64_t)1));
// Innermost size needs to account for addressing granularity.
if (iter.index() == (sizes.size() - 1)) {
if (iter.index() == (linearSizes.size() - 1)) {
staticSizes.push_back(size * elemWidthInBits / minStrideBitWidth);
} else {
staticSizes.push_back(size);
Expand All @@ -105,6 +121,12 @@ struct HalfDmaCpyNdToNpuConverter final
}
// Make sure sizes/strides have the correct size based on the number from
// intra addressing dimensions.
assert(staticSizes.size() <= numIntraAddrDim &&
"The number of dimensions in DMA sizes should not more than the "
"number of `intra` addressing dimensions");
assert(staticStrides.size() <= numIntraAddrDim &&
"The number of dimensions in DMA strides should not more than the "
"number of `intra` addressing dimensions");
staticSizes.insert(staticSizes.begin(),
numIntraAddrDim - staticSizes.size(), 0);
staticStrides.insert(staticStrides.begin(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,23 @@ struct SubsumeLoopIntoDMA
if (nbIterations == 0) return failure();
if (nbIterations > 1) nbNonUnitIterations++;
}
if (sourceDmaDimConfig.exceedsNbDims(newSourceOffsets.size() +
nbNonUnitIterations)) {

std::optional<SmallVector<int64_t>> staticSourceSizes =
getConstantIntValues(newSourceSizes);
if (!staticSourceSizes) return failure();
size_t nbUnitDimsSource = std::count(staticSourceSizes.value().begin(),
staticSourceSizes.value().end(), 1);
if (sourceDmaDimConfig.exceedsNbDims(
newSourceOffsets.size() - nbUnitDimsSource + nbNonUnitIterations)) {
return failure();
}
if (targetDmaDimConfig.exceedsNbDims(newTargetOffsets.size() +
nbNonUnitIterations)) {
std::optional<SmallVector<int64_t>> staticTargetSizes =
getConstantIntValues(newTargetSizes);
if (!staticTargetSizes) return failure();
size_t nbUnitDimsTarget = std::count(staticTargetSizes.value().begin(),
staticTargetSizes.value().end(), 1);
if (targetDmaDimConfig.exceedsNbDims(
newTargetOffsets.size() - nbUnitDimsTarget + nbNonUnitIterations)) {
return failure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ bool mergeOffset(MLIRContext *ctx, int64_t offsetToMerge,
if (cOffset.has_value() && cStride.has_value()) {
int64_t offset = cOffset.value();
int64_t stride = cStride.value();
if (offsetToMerge % stride == 0) {
if (stride != 0 && offsetToMerge % stride == 0) {
offset += offsetToMerge / stride;
offsets[i] = getAsIndexOpFoldResult(ctx, offset);
return true;
Expand Down Expand Up @@ -583,13 +583,22 @@ bool DmaDimConfig::isValidAccessPattern(ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) const {
assert(sizes.size() == strides.size() &&
"`sizes` and `strides` should have the same size");
SmallVector<int64_t> maxSizes = getMaxSizes(sizes.size());
assert(maxSizes.size() >= sizes.size() &&
// No need to check the unit dimensions.
SmallVector<int64_t> nonUnitSizes;
SmallVector<int64_t> nonUnitStrides;
for (size_t i = 0; i < sizes.size(); ++i) {
if (sizes[i] != 1) {
nonUnitSizes.push_back(sizes[i]);
nonUnitStrides.push_back(strides[i]);
}
}
SmallVector<int64_t> maxSizes = getMaxSizes(nonUnitSizes.size());
assert(maxSizes.size() >= nonUnitSizes.size() &&
"Max number of dimensions exceeded");
size_t frontToDrop = maxSizes.size() - sizes.size();
if (anyOutOfRange(sizes, maxSizes, frontToDrop)) return false;
SmallVector<int64_t> maxStrides = getMaxStrides(sizes.size());
if (anyOutOfRange(strides, maxStrides, frontToDrop)) return false;
size_t frontToDrop = maxSizes.size() - nonUnitSizes.size();
if (anyOutOfRange(nonUnitSizes, maxSizes, frontToDrop)) return false;
SmallVector<int64_t> maxStrides = getMaxStrides(nonUnitSizes.size());
if (anyOutOfRange(nonUnitStrides, maxStrides, frontToDrop)) return false;
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ TEST_F(FoldTest, UnitDimsMerge) {
EXPECT_TRUE(checkFoldUnitDims({2, 0, 1, 0}, {1, 32, 1, 8},
{1024, 32, 1024, 1}, {96, 0}, {32, 8},
{32, 1}));
EXPECT_TRUE(checkFoldUnitDims({0, 0, 1, 0}, {2, 32, 1, 8}, {0, 32, 1024, 1},
{0, 32, 0}, {2, 32, 8}, {0, 32, 1}));
EXPECT_TRUE(
checkFoldUnitDims({2, 2, 15}, {1, 1, 10}, {4, 6, 10}, {17}, {10}, {10}));
EXPECT_TRUE(checkFoldUnitDims({3, 1, 15}, {1, 1, 10}, {4, 6, 10}, {1, 15},
Expand All @@ -607,6 +609,8 @@ TEST_F(FoldTest, UnitDimsFoldAndMerge) {
{1}, {1}, {96}));
EXPECT_TRUE(checkFoldUnitDims({1, 0, 1, 0}, {1, 1, 1, 8}, {1024, 32, 1024, 1},
{2048}, {8}, {1}));
EXPECT_TRUE(checkFoldUnitDims({0, 0, 1, 0}, {1, 32, 1, 8}, {0, 32, 1024, 1},
{32, 0}, {32, 8}, {32, 1}));
}

TEST_F(FoldTest, FoldRepetitionCount) {
Expand Down
Loading

0 comments on commit e2b0e9d

Please sign in to comment.