Skip to content

Commit

Permalink
[LowerToAIE] Add support for DMA chains (#1000)
Browse files Browse the repository at this point in the history
Adds support for lowering `amdaie.npu.circular_dma_cpy_nd` to a chain of
`aie.dma_bd` operations. This is needed becaused low-level DMA BD
configurations currently don't support a zero stride.

For example:

```
amdaie.npu.circular_dma_cpy_nd %some_connection([] [] [], [0, 0, 0] [2, 32, 32] [0, 64, 1])
```
contains an outer dimension with zero stride, which describes a
repetition, and if this is not the common repetition count of all
connections operating on a logical objectFifo, this is lowered to a
chain of `dma_bd` operations to implement the zero stride:

```
^bb3:  // 2 preds: ^bb2, ^bb4
  aie.use_lock(%lock_0_1_0, AcquireGreaterEqual, 1)
  aie.dma_bd(%buffer_0_1 : memref<4096xi32, 1 : i32>) {dimensions = #aie<bd_dim_layout_array[<size = 32, stride = 64>, <size = 32, stride = 1>]>, len = 1024 : i32}
  aie.next_bd ^bb4
^bb4:  // pred: ^bb3
  aie.dma_bd(%buffer_0_1 : memref<4096xi32, 1 : i32>) {dimensions = #aie<bd_dim_layout_array[<size = 32, stride = 64>, <size = 32, stride = 1>]>, len = 1024 : i32}
  aie.use_lock(%lock_0_1, Release, 1)
  aie.next_bd ^bb3
```

Note how the first block contains a lock acquire operation, but no lock
release operation as it needs to hand off to the second block before
releasing a lock to create the DMA chain.
  • Loading branch information
jtuyls authored Dec 24, 2024
1 parent 1266a9f commit 6fae427
Show file tree
Hide file tree
Showing 3 changed files with 320 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,79 @@ using namespace xilinx;

namespace mlir::iree_compiler::AMDAIE {

/// Utility to update the current common repetition count based on the new size
/// and stride access pattern. If this new access pattern has a smaller
/// repetition count, the common repetition count will be decreased.
static FailureOr<size_t> getUpdatedRepetitionCount(
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides,
size_t curRepetitionCount) {
if (!strides.empty() && !sizes.empty() && isConstantIntValue(strides[0], 0)) {
int i = 0;
size_t repetitionCount{1};
while (i < strides.size() && isConstantIntValue(strides[i], 0)) {
std::optional<int64_t> maybeRepetitionCount =
getConstantIntValue(sizes[i]);
if (!maybeRepetitionCount) return failure();
assert(maybeRepetitionCount.value() >= 0 &&
"sizes should always be larger or equal to zero");
repetitionCount *= maybeRepetitionCount.value();
i += 1;
}
if (curRepetitionCount == 1) return repetitionCount;
size_t newRepetitionCount = std::min(repetitionCount, curRepetitionCount);
if (repetitionCount % newRepetitionCount != 0) return failure();
if (curRepetitionCount % newRepetitionCount != 0) return failure();
return newRepetitionCount;
}
return curRepetitionCount;
}

/// Utility to retrieve the common repetition count from all producers and
/// consumers of a logical objectFifo.
static FailureOr<size_t> getRepetitionCount(LogicalObjFifoOpInterface op) {
size_t repetitionCount = 1;
for (Operation *userOp : op->getUsers()) {
if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp);
connectionOp.getTarget() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getTarget().getDefiningOp()) == op) {
// Handle producer connection operations.
FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
connectionOp.getNpuCircularDmaCpyNdUser();
if (failed(maybeNpuDmaUserOp)) {
return connectionOp.emitOpError()
<< "does not have a circular DMA op user";
}
FailureOr<size_t> maybeNewRepetitionCount = getUpdatedRepetitionCount(
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(), repetitionCount);
if (failed(maybeNewRepetitionCount)) {
return maybeNpuDmaUserOp->emitOpError() << "no repetition count found";
}
repetitionCount = maybeNewRepetitionCount.value();
} else if (auto connectionOp = dyn_cast<AMDAIE::ConnectionOp>(userOp);
connectionOp.getSource() &&
dyn_cast_if_present<LogicalObjFifoOpInterface>(
connectionOp.getSource().getDefiningOp()) == op) {
// Handle consumer connection operations.
FailureOr<AMDAIE::NpuCircularDmaCpyNdOp> maybeNpuDmaUserOp =
connectionOp.getNpuCircularDmaCpyNdUser();
if (failed(maybeNpuDmaUserOp)) {
return connectionOp.emitOpError()
<< "does not have a circular DMA op user";
}
FailureOr<size_t> maybeNewRepetitionCount = getUpdatedRepetitionCount(
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(), repetitionCount);
if (failed(maybeNewRepetitionCount)) {
return maybeNpuDmaUserOp->emitOpError() << "no repetition count found";
}
repetitionCount = maybeNewRepetitionCount.value();
}
}
return repetitionCount;
}

//===----------------------------------------------------------------------===//
// AIEDeviceBuilder utilities
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -95,8 +168,6 @@ LogicalResult AIEDeviceBuilder::createDMABlocks(
std::optional<uint8_t> pktId) {
OpBuilder::InsertionGuard g(rewriter);

auto [dims, len] = convertSizeStrideToBDDimLayoutArrayAttr(sizes, strides);

Block &endBlock = memOp->getRegion(0).getBlocks().back();
assert(!endBlock.getOps<AIE::EndOp>().empty() &&
"expected last block to have aie.end");
Expand All @@ -111,43 +182,89 @@ LogicalResult AIEDeviceBuilder::createDMABlocks(
&endBlock);
if (lastDmaBlock) lastDmaBlock->getTerminator()->setSuccessor(dmaBlock, 1);

auto createBdBlockOps = [&](AIE::BufferOp buff, Block *succ,
AIE::BDDimLayoutArrayAttr dimsAttr,
int64_t transferLength) {
auto createDMAOps = [&](Block *succ, AIE::BufferOp buff,
AIE::BDDimLayoutArrayAttr dims, bool shouldAcqLock,
bool shouldRelLock, int64_t transferLength,
int64_t offset) {
AIE::LockOp acqLock = locks.first, relLock = locks.second;
rewriter.create<AIE::UseLockOp>(rewriter.getUnknownLoc(), acqLock,
AIE::LockAction::AcquireGreaterEqual,
acqNum);
// Insert a packet op for MM2S DMAs if part of a packet flow. Only do this
// for MM2S DMA ports as only those can insert packet headers.
if (channelDir == AIE::DMAChannelDir::MM2S && pktId) {
if (shouldAcqLock) {
rewriter.create<AIE::UseLockOp>(rewriter.getUnknownLoc(), acqLock,
AIE::LockAction::AcquireGreaterEqual,
acqNum);
}
// Insert a packet op for MM2S DMAs if part of a packet flow. Only do
// this for MM2S DMA ports as only those can insert packet headers.
if (channelDir == AIE::DMAChannelDir::MM2S && pktId.has_value()) {
rewriter.create<AIE::DMABDPACKETOp>(rewriter.getUnknownLoc(),
/*pkt_type*/ 0,
/*pkt_id*/ pktId.value());
}
if (!dimsAttr.getValue().empty()) {
if (!dims.getValue().empty()) {
rewriter.create<AIE::DMABDOp>(rewriter.getUnknownLoc(), buff, offset,
transferLength, dimsAttr);
transferLength, dims);
} else {
rewriter.create<AIE::DMABDOp>(rewriter.getUnknownLoc(), buff, offset,
transferLength);
}
rewriter.create<AIE::UseLockOp>(rewriter.getUnknownLoc(), relLock,
AIE::LockAction::Release, relNum);
if (shouldRelLock) {
rewriter.create<AIE::UseLockOp>(rewriter.getUnknownLoc(), relLock,
AIE::LockAction::Release, relNum);
}
rewriter.create<AIE::NextBDOp>(rewriter.getUnknownLoc(), succ);
};

// Create Bd blocks.
// Find the last index with a zero stride. All dimensions before and including
// this one will be converted into separate DMA ops, while the dimensions
// after this will be included in the access pattern within a DMA op. This is
// needed becaused low-level DMA BD configurations currently don't support
// zero stride and/or because more dimensions are needed than available.
int64_t lastZeroStrideIndex{-1};
for (size_t i = 0; i < strides.size(); i++)
if (strides[i] == 0) lastZeroStrideIndex = i;

// Convert all dimensions after the last index with zero stride to a
// `BDDimLayoutArrayAttr` as these are the inner/intra DMA dimensions.
auto [dims, transferLength] = convertSizeStrideToBDDimLayoutArrayAttr(
ArrayRef<int64_t>(sizes).drop_front(lastZeroStrideIndex + 1),
ArrayRef<int64_t>(strides).drop_front(lastZeroStrideIndex + 1));

SmallVector<size_t> indexRange(lastZeroStrideIndex + 1);
std::iota(indexRange.begin(), indexRange.end(), 0);
// Compute the total number of iterations of all dimensions up till
// `lastZeroStrideIndex`.
int64_t numIters = std::accumulate(
sizes.begin(), sizes.begin() + indexRange.size(), 1, std::multiplies<>());
// Compute the divisors to be used to get the indices for every dimension from
// the total number of iterations (as if all dimensions are coalesced).
SmallVector<int64_t> cartesianDivisors(indexRange.size(), 1);
for (int64_t i = indexRange.size() - 2; i >= 0; i--)
cartesianDivisors[i] = cartesianDivisors[i + 1] * sizes[i + 1];

// Create blocks with DMA ops.
Block *succ = nullptr, *curr = bdBlock;
for (size_t blockIndex = 0; blockIndex < bufferOps.size(); ++blockIndex) {
if (blockIndex == bufferOps.size() - 1) {
succ = bdBlock;
} else {
succ = rewriter.createBlock(&endBlock);
// Iterate through the cartesian product of all dimension up to the last
// dimension with zero strides to create a DMA chain of `dma_bd` ops.
for (int64_t index = 0; index < numIters; index++) {
SmallVector<int64_t> indices = llvm::map_to_vector(
indexRange,
[&](size_t i) { return (index / cartesianDivisors[i]) % sizes[i]; });
bool isFirst = llvm::all_of(indices, [](int64_t v) { return v == 0; });
bool isLast = llvm::all_of(
indexRange, [&](size_t i) { return indices[i] == (sizes[i] - 1); });
if (blockIndex == bufferOps.size() - 1 && isLast) {
succ = bdBlock;
} else {
succ = rewriter.createBlock(&endBlock);
}
rewriter.setInsertionPointToStart(curr);
int64_t addOffset = 0;
for (size_t i = 0; i < indexRange.size(); i++)
addOffset += (indices[i] * strides[i]);
createDMAOps(succ, bufferOps[blockIndex], dims, isFirst, isLast,
transferLength, offset + addOffset);
curr = succ;
}
rewriter.setInsertionPointToStart(curr);
createBdBlockOps(bufferOps[blockIndex], succ, dims, len);
curr = succ;
}
return success();
}
Expand Down Expand Up @@ -218,8 +335,10 @@ void AIEDeviceBuilder::eraseOp(Operation *op) {
LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError) {
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) {
size_t repetitionCount, uint8_t memSpace,
function_ref<InFlightDiagnostic()> emitError) {
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides,
repetitionCount))) {
return emitError() << "could not fold repetition counts";
}
SmallVector<OpFoldResult> offsets(
Expand Down Expand Up @@ -505,6 +624,12 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
<< "expected source to be an "
"`amdaie.logicalobjectfifo.from_buffers` op";
}
FailureOr<size_t> repetitionCount = getRepetitionCount(
cast<LogicalObjFifoOpInterface>(sourceObjFifo.getOperation()));
if (failed(repetitionCount)) {
return sourceObjFifo->emitOpError()
<< "could not retrieve the repetition count";
}
std::optional<size_t> maybeOffset =
maybeNpuDmaUserOp->getSourceStaticBaseOffset();
if (!maybeOffset) {
Expand Down Expand Up @@ -558,21 +683,22 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
}
std::pair<AIE::LockOp, AIE::LockOp> lockPair =
std::make_pair(consumerLocks[0], producerLocks[0]);
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeSourceMemSpace.value(),
canonicalizedStrides, repetitionCount.value(),
maybeSourceMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
if (failed(createDMABlocks(
memOp, AIE::DMAChannelDir::MM2S, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
return sourceObjFifo.emitOpError() << "could not create DMA operations";
}
}
}
Expand Down Expand Up @@ -603,6 +729,12 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
<< "expected target to be an "
"`amdaie.logicalobjectfifo.from_buffers` op";
}
FailureOr<size_t> repetitionCount = getRepetitionCount(
cast<LogicalObjFifoOpInterface>(targetObjFifo.getOperation()));
if (failed(repetitionCount)) {
return targetObjFifo->emitOpError()
<< "could not retrieve the repetition count";
}
std::optional<size_t> maybeOffset =
maybeNpuDmaUserOp->getTargetStaticBaseOffset();
if (!maybeOffset) {
Expand Down Expand Up @@ -660,7 +792,8 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeTargetMemSpace.value(),
canonicalizedStrides, repetitionCount.value(),
maybeTargetMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
Expand All @@ -670,7 +803,7 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
memOp, AIE::DMAChannelDir::S2MM, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
return targetObjFifo.emitOpError() << "could not create DMA operations";
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class AIEDeviceBuilder {
LogicalResult foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError);
size_t repetitionCount, uint8_t memSpace,
function_ref<InFlightDiagnostic()> emitError);

/// Utility to remap the provided operation's operands.
void remapOperands(Operation *op);
Expand Down
Loading

0 comments on commit 6fae427

Please sign in to comment.