-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…1064) From control_packet operations, we extract two types of information: - **reconfiguration sequence**: includes the header and data necessary for reconfiguration. - **DMA operations:** covers memcpy and synchronization to send the content into the device through the shim. The DMA operations can be further lowered to transactions using existing controlcode passes. References: Xilinx/mlir-aie#1709; Xilinx/mlir-aie#1753 --------- Co-authored-by: Jorn Tuyls <[email protected]>
- Loading branch information
Showing
10 changed files
with
561 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
216 changes: 216 additions & 0 deletions
216
...iler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEControlPacketToHalfDmaCpyNd.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree-amd-aie/IR/AMDAIEDialect.h" | ||
#include "iree-amd-aie/IR/AMDAIEOps.h" | ||
#include "iree-amd-aie/Transforms/Passes.h" | ||
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" | ||
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h" | ||
#include "mlir/IR/AsmState.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-amdaie-control-packet-to-half-dma-cpy-nd" | ||
|
||
namespace mlir::iree_compiler::AMDAIE { | ||
|
||
namespace { | ||
|
||
struct ControlPacketDmaBuilder { | ||
AMDAIE::AMDAIEDeviceModel deviceModel; | ||
ControlPacketDmaBuilder(AMDAIE::AMDAIEDeviceModel deviceModel) | ||
: deviceModel(std::move(deviceModel)) {} | ||
|
||
std::vector<uint32_t> ctrlPktSequence; | ||
|
||
llvm::MutableArrayRef<uint32_t> reserveAndGetTail(size_t tailSize) { | ||
size_t oldSize = ctrlPktSequence.size(); | ||
size_t newSize = oldSize + tailSize; | ||
ctrlPktSequence.resize(newSize, 0); | ||
return llvm::MutableArrayRef<uint32_t>(ctrlPktSequence.data() + oldSize, | ||
tailSize); | ||
} | ||
|
||
void dumpSequenceAsHex() const { | ||
llvm::outs() << "Control Packet Sequence: \n"; | ||
// Write hex as 0xXXXXXXXX | ||
for (uint32_t word : ctrlPktSequence) | ||
llvm::outs() << utohexstr(word, 8) << "\n"; | ||
} | ||
|
||
LogicalResult convert(IRRewriter &rewriter, AMDAIE::WorkgroupOp workgroupOp) { | ||
ctrlPktSequence.clear(); | ||
|
||
// Get all the `ConnectionOp` whose target is a `CTRL` port. | ||
DenseMap<TileLoc, AMDAIE::ConnectionOp> tileLocToCtrlConnect; | ||
DenseMap<TileLoc, AMDAIE::TileOp> tileLocToTileOp; | ||
auto res = workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) { | ||
if (connectionOp.getTargetChannels().size() != 1) { | ||
connectionOp.emitOpError() << "expected a single target channel"; | ||
return WalkResult::interrupt(); | ||
} | ||
|
||
auto targetChannelOp = dyn_cast<AMDAIE::ChannelOp>( | ||
connectionOp.getTargetChannels()[0].getDefiningOp()); | ||
if (targetChannelOp.getPortType() == StrmSwPortType::CTRL) { | ||
TileOp tileOp = targetChannelOp.getTileOp(); | ||
TileLoc tileLoc = { | ||
static_cast<int>(getConstantIndexOrAssert(tileOp.getCol())), | ||
static_cast<int>(getConstantIndexOrAssert(tileOp.getRow()))}; | ||
tileLocToCtrlConnect[tileLoc] = connectionOp; | ||
tileLocToTileOp[tileLoc] = tileOp; | ||
} | ||
return WalkResult::advance(); | ||
}); | ||
if (res.wasInterrupted()) return failure(); | ||
|
||
std::vector<AMDAIE::NpuControlPacketOp> ctrlPktOps; | ||
// Convert `NpuControlPacketOp` to `NpuHalfDmaCpyNdOp` + `NpuDmaWaitOp`. | ||
res = workgroupOp->walk([&](AMDAIE::NpuControlPacketOp ctrlPktOp) { | ||
ctrlPktOps.push_back(ctrlPktOp); | ||
// Get `ConnectionOp` for the `CTRL` port. | ||
uint32_t address = ctrlPktOp.getAddress(); | ||
uint32_t addrOffset = deviceModel.getOffsetFromAddress(address); | ||
int32_t col = deviceModel.getColumnFromAddress(address); | ||
int32_t row = deviceModel.getRowFromAddress(address); | ||
if (!tileLocToCtrlConnect.count({col, row})) { | ||
ctrlPktOp.emitOpError() | ||
<< "tries to write to tile (col=" << col << ", row=" << row | ||
<< "), but it's `CTRL` port is not routed."; | ||
return WalkResult::interrupt(); | ||
} | ||
AMDAIE::ConnectionOp connectionOp = tileLocToCtrlConnect[{col, row}]; | ||
|
||
// Get `sourceChannelOp`. | ||
if (connectionOp.getSourceChannels().size() != 1) { | ||
connectionOp.emitOpError() << "expected a single source channel"; | ||
return WalkResult::interrupt(); | ||
} | ||
auto sourceChannelOp = dyn_cast<AMDAIE::ChannelOp>( | ||
connectionOp.getSourceChannels()[0].getDefiningOp()); | ||
|
||
// Get `offsets`, `sizes`, and `strides`. | ||
uint32_t dataLength = ctrlPktOp.getLength(); | ||
int64_t headerAndDataLength = dataLength + 1; | ||
SmallVector<int64_t> offsets{0, 0, 0, | ||
static_cast<long>(ctrlPktSequence.size())}; | ||
SmallVector<int64_t> sizes{1, 1, 1, headerAndDataLength}; | ||
SmallVector<int64_t> strides{0, 0, 0, 1}; | ||
|
||
// Store the control packet header. | ||
llvm::MutableArrayRef<uint32_t> words = | ||
reserveAndGetTail(headerAndDataLength); | ||
FailureOr<uint32_t> header = deviceModel.getCtrlPktHeader( | ||
addrOffset, dataLength, static_cast<uint32_t>(ctrlPktOp.getOpcode()), | ||
ctrlPktOp.getStreamId()); | ||
if (failed(header)) { | ||
ctrlPktOp.emitOpError() << "failed to get control packet header."; | ||
return WalkResult::interrupt(); | ||
} | ||
|
||
words[0] = *header; | ||
// Store the control packet data. | ||
std::optional<ArrayRef<int32_t>> maybeData = | ||
ctrlPktOp.getDataFromArrayOrResource(); | ||
if (maybeData.has_value()) { | ||
for (uint32_t i = 0; i < dataLength; ++i) { | ||
int32_t data = maybeData.value()[i]; | ||
words[i + 1] = reinterpret_cast<uint32_t &>(data); | ||
} | ||
} | ||
|
||
rewriter.setInsertionPoint(ctrlPktOp); | ||
// Create token. | ||
SmallVector<Type> resultTypes = { | ||
rewriter.getType<AMDAIE::AsyncTokenType>()}; | ||
TypeRange sourceResultTypes = TypeRange{resultTypes}; | ||
|
||
// Get `bdId`, use `0` for now. | ||
// TODO (zhewen): let `AMDAIEAssignNpuDmaBdIdsPass` decide? | ||
auto constant = rewriter.create<arith::ConstantOp>( | ||
rewriter.getUnknownLoc(), rewriter.getIndexAttr(0)); | ||
auto bdIdOp = rewriter.create<AMDAIE::BdIdOp>(rewriter.getUnknownLoc(), | ||
sourceChannelOp.getTileOp(), | ||
constant.getResult()); | ||
|
||
// Create `NpuHalfDmaCpyNdOp` and `NpuDmaWaitOp`. | ||
auto dmaOp = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>( | ||
rewriter.getUnknownLoc(), sourceResultTypes, connectionOp, | ||
connectionOp.getSource(), offsets, sizes, strides, bdIdOp, | ||
sourceChannelOp); | ||
rewriter.create<AMDAIE::NpuDmaWaitOp>(rewriter.getUnknownLoc(), | ||
dmaOp.getResult(0)); | ||
|
||
return WalkResult::advance(); | ||
}); | ||
if (res.wasInterrupted()) return failure(); | ||
|
||
// Erase all the `NpuControlPacketOp`. | ||
for (AMDAIE::NpuControlPacketOp ctrlPktOp : ctrlPktOps) | ||
rewriter.eraseOp(ctrlPktOp); | ||
|
||
// Store the control packet sequence in the `WorkgroupOp`. | ||
workgroupOp.setCtrlpktSequenceAttr(DenseUI32ResourceElementsAttr::get( | ||
RankedTensorType::get( | ||
ctrlPktSequence.size(), | ||
IntegerType::get(rewriter.getContext(), 32, IntegerType::Unsigned)), | ||
"ctrlpkt_sequence", | ||
HeapAsmResourceBlob::allocateAndCopyInferAlign( | ||
ArrayRef<uint32_t>(ctrlPktSequence)))); | ||
return success(); | ||
} | ||
}; | ||
|
||
class AMDAIEControlPacketToHalfDmaCpyNdPass | ||
: public impl::AMDAIEControlPacketToHalfDmaCpyNdBase< | ||
AMDAIEControlPacketToHalfDmaCpyNdPass> { | ||
public: | ||
AMDAIEControlPacketToHalfDmaCpyNdPass( | ||
const AMDAIEControlPacketToHalfDmaCpyNdOptions &options) | ||
: AMDAIEControlPacketToHalfDmaCpyNdBase(options) {} | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<AMDAIEDialect>(); | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
void AMDAIEControlPacketToHalfDmaCpyNdPass::runOnOperation() { | ||
Operation *parentOp = getOperation(); | ||
IRRewriter rewriter(parentOp->getContext()); | ||
|
||
// Get `AMDAIEDeviceModel`. | ||
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp); | ||
std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr); | ||
if (!maybeDevice) { | ||
parentOp->emitOpError() << "has no AMDAIEDevice in the target " | ||
"attribute configuration."; | ||
return signalPassFailure(); | ||
} | ||
AMDAIE::AMDAIEDeviceModel deviceModel = | ||
AMDAIE::getDeviceModel(maybeDevice.value()); | ||
ControlPacketDmaBuilder ctrlPktDmaBuilder(std::move(deviceModel)); | ||
|
||
SmallVector<AMDAIE::WorkgroupOp> workgroupOps; | ||
|
||
WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) { | ||
if (failed(ctrlPktDmaBuilder.convert(rewriter, workgroupOp))) | ||
return WalkResult::interrupt(); | ||
|
||
if (dumpSequence) ctrlPktDmaBuilder.dumpSequenceAsHex(); | ||
|
||
return WalkResult::advance(); | ||
}); | ||
if (res.wasInterrupted()) return signalPassFailure(); | ||
} | ||
|
||
} // namespace | ||
|
||
/// Factory for the control-packet-to-half-DMA conversion pass, configured with
/// the given `options`.
std::unique_ptr<Pass> createAMDAIEControlPacketToHalfDmaCpyNdPass(
    AMDAIEControlPacketToHalfDmaCpyNdOptions options) {
  using PassT = AMDAIEControlPacketToHalfDmaCpyNdPass;
  std::unique_ptr<Pass> pass = std::make_unique<PassT>(options);
  return pass;
}
|
||
} // namespace mlir::iree_compiler::AMDAIE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.