[Dma] Don't check Dma unit dims during subsumption and fold them during controlCodeLowering #1070

Merged · 5 commits · Jan 31, 2025
@@ -36,8 +36,8 @@ struct HalfDmaCpyNdToNpuConverter final
       AMDAIE::LogicalObjectFifoFromMemrefOp logicalObjFifo,
       AMDAIE::BdIdOp bdIdOp, AMDAIE::ChannelOp channelOp, int64_t bufferLength,
       int64_t bufferOffset, int32_t enablePacket, int32_t packetId,
-      int32_t packetType, ArrayRef<OpFoldResult> sizes,
-      ArrayRef<OpFoldResult> strides) const {
+      int32_t packetType, SmallVector<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> strides) const {
     uint8_t numIntraAddrDim = deviceModel.getDmaProp<uint8_t>(
         tileType, AMDAIE::AMDAIEDmaProp::NumAddrDim);
     uint8_t numAddrDim =
@@ -66,6 +66,21 @@ struct HalfDmaCpyNdToNpuConverter final
     int32_t bdId = getConstantIndexOrAssert(bdIdOp.getValue());
     int32_t outOfOrderId{0};

+    SmallVector<OpFoldResult> offsets(
+        strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
+    (void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides);
+
+    uint8_t memSpace = logicalObjFifo.getMemorySpaceAsUInt();
+    DmaDimConfig dmaDimConfig(deviceModel, memSpace);
+    SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(offsets.size());
+    SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
+    (void)foldLinearDims(
+        rewriter.getContext(), offsets, sizes, strides, linearOffsets,
+        linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
+          return idxFromEnd < maxSizes.size() &&
+                 size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
+        });
+
     SmallVector<int32_t, 4> staticSizes;
     SmallVector<int32_t, 4> staticStrides;
     // Padding is unused for now.
@@ -75,14 +90,15 @@ struct HalfDmaCpyNdToNpuConverter final
     int32_t iterationSize{0};
     int32_t iterationStride{0};
     int32_t repeatCount{1};
-    for (auto iter : llvm::enumerate(llvm::zip(sizes, strides))) {
+    for (auto iter : llvm::enumerate(llvm::zip(linearSizes, linearStrides))) {
       int64_t size = getConstantIndexOrAssert(std::get<0>(iter.value()));
       int64_t stride = getConstantIndexOrAssert(std::get<1>(iter.value()));

       /// Map the outer dimension to the iteration dimension if intra dimensions
       /// are all used already or if the first stride == 0 as only the iteration
       /// dimension supports stride == 0.
-      if (iter.index() == 0 && (sizes.size() == numAddrDim || stride == 0)) {
+      if (iter.index() == 0 &&
+          (linearSizes.size() == numAddrDim || stride == 0)) {
         if (stride == 0) {
           repeatCount = size;
         } else {
@@ -96,7 +112,7 @@ struct HalfDmaCpyNdToNpuConverter final
       staticStrides.push_back(
           std::max(stride * elemWidthInBits / minStrideBitWidth, (int64_t)1));
       // Innermost size needs to account for addressing granularity.
-      if (iter.index() == (sizes.size() - 1)) {
+      if (iter.index() == (linearSizes.size() - 1)) {
         staticSizes.push_back(size * elemWidthInBits / minStrideBitWidth);
       } else {
         staticSizes.push_back(size);
@@ -105,6 +121,12 @@
     }
     // Make sure sizes/strides have the correct size based on the number of
    // intra addressing dimensions.
+    assert(staticSizes.size() <= numIntraAddrDim &&
+           "The number of dimensions in DMA sizes should not be more than the "
+           "number of `intra` addressing dimensions");
+    assert(staticStrides.size() <= numIntraAddrDim &&
+           "The number of dimensions in DMA strides should not be more than "
+           "the number of `intra` addressing dimensions");
     staticSizes.insert(staticSizes.begin(),
                        numIntraAddrDim - staticSizes.size(), 0);
     staticStrides.insert(staticStrides.begin(),
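The block added in the lowering above first folds unit dimensions out of the access pattern, then merges adjacent dimensions that together address one contiguous range, so fewer of the limited hardware addressing dimensions are consumed. Below is a standalone sketch of the two folds on plain integers; `dropUnitDims` and `foldLinear` are hypothetical simplifications of `foldUnitDims` and `foldLinearDims`, which really operate on `OpFoldResult` values, take an `MLIRContext`, and also fold a unit dimension's offset into a surviving dimension via `mergeOffset` (here unit dims are assumed to carry a zero offset).

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct Pattern {
  std::vector<int64_t> offsets, sizes, strides;
};

// A size-1 dimension performs a single step, so it adds no addressing
// freedom and only wastes one hardware dimension. Assumes its offset is 0.
Pattern dropUnitDims(const Pattern &p) {
  Pattern out;
  for (size_t i = 0; i < p.sizes.size(); ++i) {
    if (p.sizes[i] == 1) continue;
    out.offsets.push_back(p.offsets[i]);
    out.sizes.push_back(p.sizes[i]);
    out.strides.push_back(p.strides[i]);
  }
  return out;
}

// Merge a dimension into its inner neighbour when its stride equals
// innerStride * innerSize, i.e. the pair walks one contiguous (linear)
// range. `maxSize` stands in for the per-dimension hardware bound that the
// real foldLinearDims queries through a callback into DmaDimConfig.
Pattern foldLinear(Pattern p, int64_t maxSize) {
  for (size_t i = 0; i + 1 < p.sizes.size();) {
    int64_t merged = p.sizes[i] * p.sizes[i + 1];
    if (p.strides[i] == p.strides[i + 1] * p.sizes[i + 1] &&
        merged <= maxSize) {
      p.offsets[i + 1] += p.offsets[i] * p.sizes[i + 1];
      p.sizes[i + 1] = merged;
      p.offsets.erase(p.offsets.begin() + i);
      p.sizes.erase(p.sizes.begin() + i);
      p.strides.erase(p.strides.begin() + i);
      // Stay at the same index: the grown neighbour may fold again.
    } else {
      ++i;
    }
  }
  return p;
}

int main() {
  // sizes [2, 1, 32] with strides [32, 1024, 1]: the unit dim vanishes and
  // [2, 32] / [32, 1] collapses into one contiguous dim of 64 elements.
  Pattern p{{0, 0, 0}, {2, 1, 32}, {32, 1024, 1}};
  Pattern q = foldLinear(dropUnitDims(p), /*maxSize=*/1023);
  for (size_t i = 0; i < q.sizes.size(); ++i)
    std::cout << "size " << q.sizes[i] << " stride " << q.strides[i] << "\n";
  // Prints: size 64 stride 1
}
```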
@@ -193,12 +193,23 @@ struct SubsumeLoopIntoDMA
       if (nbIterations == 0) return failure();
       if (nbIterations > 1) nbNonUnitIterations++;
     }
-    if (sourceDmaDimConfig.exceedsNbDims(newSourceOffsets.size() +
-                                         nbNonUnitIterations)) {
+
+    std::optional<SmallVector<int64_t>> staticSourceSizes =
+        getConstantIntValues(newSourceSizes);
+    if (!staticSourceSizes) return failure();
+    size_t nbUnitDimsSource = std::count(staticSourceSizes.value().begin(),
+                                         staticSourceSizes.value().end(), 1);
+    if (sourceDmaDimConfig.exceedsNbDims(
+            newSourceOffsets.size() - nbUnitDimsSource + nbNonUnitIterations)) {
       return failure();
     }
-    if (targetDmaDimConfig.exceedsNbDims(newTargetOffsets.size() +
-                                         nbNonUnitIterations)) {
+    std::optional<SmallVector<int64_t>> staticTargetSizes =
+        getConstantIntValues(newTargetSizes);
+    if (!staticTargetSizes) return failure();
+    size_t nbUnitDimsTarget = std::count(staticTargetSizes.value().begin(),
+                                         staticTargetSizes.value().end(), 1);
+    if (targetDmaDimConfig.exceedsNbDims(
+            newTargetOffsets.size() - nbUnitDimsTarget + nbNonUnitIterations)) {
       return failure();
     }

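Per the PR title, the subsumption check above now discounts unit dimensions when deciding whether the loop-subsumed access pattern still fits in the available DMA dimensions, since those dims are folded away later during controlCodeLowering. A minimal sketch of the relaxed check, assuming all sizes are static and with a made-up `maxNbDims` bound standing in for `DmaDimConfig::exceedsNbDims`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Returns true when the pattern, extended by the subsumed loop iterations,
// still fits. Unit dims are not counted: controlCodeLowering folds them.
bool fitsAfterSubsumption(const std::vector<int64_t> &sizes,
                          size_t nbNonUnitIterations, size_t maxNbDims) {
  size_t nbUnitDims = std::count(sizes.begin(), sizes.end(), int64_t{1});
  return sizes.size() - nbUnitDims + nbNonUnitIterations <= maxNbDims;
}

int main() {
  // A 4-D pattern plus one non-unit loop would previously have been
  // rejected on a 4-dim DMA; discounting its unit dim makes it fit:
  // 4 - 1 + 1 = 4 <= 4.
  std::cout << std::boolalpha
            << fitsAfterSubsumption({2, 1, 32, 8}, 1, 4) << "\n";  // true
}
```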
@@ -496,7 +496,7 @@ bool mergeOffset(MLIRContext *ctx, int64_t offsetToMerge,
   if (cOffset.has_value() && cStride.has_value()) {
     int64_t offset = cOffset.value();
     int64_t stride = cStride.value();
-    if (offsetToMerge % stride == 0) {
+    if (stride != 0 && offsetToMerge % stride == 0) {
       offset += offsetToMerge / stride;
       offsets[i] = getAsIndexOpFoldResult(ctx, offset);
       return true;
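The new `stride != 0` guard matters because a stride-0 (repetition) dimension can now reach `mergeOffset`, and `offsetToMerge % 0` is undefined behaviour in C++. A hypothetical minimal version of the fixed condition:

```cpp
#include <cstdint>
#include <iostream>

// Can `offsetToMerge` be expressed as a whole number of steps of `stride`?
// A stride-0 repetition dimension can never absorb an offset, and testing
// it with % would divide by zero.
bool canMergeInto(int64_t offsetToMerge, int64_t stride) {
  return stride != 0 && offsetToMerge % stride == 0;
}

int main() {
  std::cout << std::boolalpha << canMergeInto(1024, 32) << "\n";  // true
  std::cout << std::boolalpha << canMergeInto(1024, 0) << "\n";   // false
}
```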
@@ -583,13 +583,22 @@ bool DmaDimConfig::isValidAccessPattern(ArrayRef<int64_t> sizes,
                                         ArrayRef<int64_t> strides) const {
   assert(sizes.size() == strides.size() &&
          "`sizes` and `strides` should have the same size");
-  SmallVector<int64_t> maxSizes = getMaxSizes(sizes.size());
-  assert(maxSizes.size() >= sizes.size() &&
+  // No need to check the unit dimensions.
+  SmallVector<int64_t> nonUnitSizes;
+  SmallVector<int64_t> nonUnitStrides;
+  for (size_t i = 0; i < sizes.size(); ++i) {
+    if (sizes[i] != 1) {
+      nonUnitSizes.push_back(sizes[i]);
+      nonUnitStrides.push_back(strides[i]);
+    }
+  }
+  SmallVector<int64_t> maxSizes = getMaxSizes(nonUnitSizes.size());
+  assert(maxSizes.size() >= nonUnitSizes.size() &&
          "Max number of dimensions exceeded");
-  size_t frontToDrop = maxSizes.size() - sizes.size();
-  if (anyOutOfRange(sizes, maxSizes, frontToDrop)) return false;
-  SmallVector<int64_t> maxStrides = getMaxStrides(sizes.size());
-  if (anyOutOfRange(strides, maxStrides, frontToDrop)) return false;
+  size_t frontToDrop = maxSizes.size() - nonUnitSizes.size();
+  if (anyOutOfRange(nonUnitSizes, maxSizes, frontToDrop)) return false;
+  SmallVector<int64_t> maxStrides = getMaxStrides(nonUnitSizes.size());
+  if (anyOutOfRange(nonUnitStrides, maxStrides, frontToDrop)) return false;
   return true;
 }

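`isValidAccessPattern` now validates only the non-unit dimensions against the hardware bounds: a unit dimension is never actually stepped, so even an out-of-range stride on it is harmless, and it is folded away before lowering anyway. A simplified sketch with flat scalar bounds standing in for `getMaxSizes`/`getMaxStrides` (the real bounds come from the device model and vary per dimension and tile type):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

bool isValidAccessPattern(const std::vector<int64_t> &sizes,
                          const std::vector<int64_t> &strides,
                          int64_t maxSize, int64_t maxStride) {
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 1) continue;  // unit dims are never stepped; skip them
    if (sizes[i] > maxSize || strides[i] > maxStride) return false;
  }
  return true;
}

int main() {
  // The huge stride sits on a unit dim, so the pattern is still valid.
  std::cout << std::boolalpha
            << isValidAccessPattern({2, 1, 8}, {32, 1 << 20, 1},
                                    /*maxSize=*/1023, /*maxStride=*/65536)
            << "\n";  // true
}
```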
@@ -592,6 +592,8 @@ TEST_F(FoldTest, UnitDimsMerge) {
   EXPECT_TRUE(checkFoldUnitDims({2, 0, 1, 0}, {1, 32, 1, 8},
                                 {1024, 32, 1024, 1}, {96, 0}, {32, 8},
                                 {32, 1}));
+  EXPECT_TRUE(checkFoldUnitDims({0, 0, 1, 0}, {2, 32, 1, 8}, {0, 32, 1024, 1},
+                                {0, 32, 0}, {2, 32, 8}, {0, 32, 1}));
   EXPECT_TRUE(
       checkFoldUnitDims({2, 2, 15}, {1, 1, 10}, {4, 6, 10}, {17}, {10}, {10}));
   EXPECT_TRUE(checkFoldUnitDims({3, 1, 15}, {1, 1, 10}, {4, 6, 10}, {1, 15},
@@ -607,6 +609,8 @@ TEST_F(FoldTest, UnitDimsFoldAndMerge) {
                                 {1}, {1}, {96}));
   EXPECT_TRUE(checkFoldUnitDims({1, 0, 1, 0}, {1, 1, 1, 8}, {1024, 32, 1024, 1},
                                 {2048}, {8}, {1}));
+  EXPECT_TRUE(checkFoldUnitDims({0, 0, 1, 0}, {1, 32, 1, 8}, {0, 32, 1024, 1},
+                                {32, 0}, {32, 8}, {32, 1}));
 }

 TEST_F(FoldTest, FoldRepetitionCount) {
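Reading the new `UnitDimsMerge` case as arithmetic (with `checkFoldUnitDims` taking offsets, sizes, and strides followed by their expected folded counterparts): the input dims are (offset, size, stride) = (0, 2, 0), (0, 32, 32), (1, 1, 1024), (0, 8, 1). The size-1 dim contributes a fixed element offset of 1 * 1024 = 1024, which folds into the stride-32 dim as 1024 / 32 = 32 added to that dim's offset, while the stride-0 repetition dim survives. That leaves (0, 2, 0), (32, 32, 32), (0, 8, 1), i.e. exactly the expected {0, 32, 0}, {2, 32, 8}, {0, 32, 1}.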