Skip to content

Commit

Permalink
[CtrlPkt] Convert control_packet to half_dma_cpy_nd operations (#…
Browse files Browse the repository at this point in the history
…1064)

From control_packet operations, we extract two types of information:
- **reconfiguration sequence**: includes the header and data necessary
for reconfiguration.
- **DMA operations:** covers memcpy and synchronization to send the
content into the device through the shim.

The DMA operations can be further lowered to transactions using existing
controlcode passes.

References: Xilinx/mlir-aie#1709;
Xilinx/mlir-aie#1753

---------

Co-authored-by: Jorn Tuyls <[email protected]>
  • Loading branch information
Yu-Zhewen and jtuyls authored Jan 29, 2025
1 parent 9438dd2 commit cc25a25
Show file tree
Hide file tree
Showing 10 changed files with 561 additions and 7 deletions.
3 changes: 2 additions & 1 deletion compiler/plugins/target/AMD-AIE/iree-amd-aie/IR/AMDAIEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def AMDAIE_WorkgroupOp : AMDAIE_Op<"workgroup",

let regions = (region SizedRegion<1>:$region);
let arguments = (
ins OptionalAttr<Builtin_DenseResourceElementsAttr>:$npu_instructions
ins OptionalAttr<Builtin_DenseResourceElementsAttr>:$npu_instructions,
OptionalAttr<Builtin_DenseResourceElementsAttr>:$ctrlpkt_sequence
);

let assemblyFormat = [{ regions attr-dict }];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/AsmState.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-control-packet-to-half-dma-cpy-nd"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Helper that lowers the `NpuControlPacketOp`s of a workgroup to
/// `NpuHalfDmaCpyNdOp` + `NpuDmaWaitOp` pairs, while accumulating the raw
/// control packet words (one header word followed by the data words of each
/// packet) into `ctrlPktSequence`.
struct ControlPacketDmaBuilder {
  AMDAIE::AMDAIEDeviceModel deviceModel;
  explicit ControlPacketDmaBuilder(AMDAIE::AMDAIEDeviceModel deviceModel)
      : deviceModel(std::move(deviceModel)) {}

  /// Flat sequence of 32-bit control packet words (headers and data) built by
  /// the most recent call to `convert`.
  std::vector<uint32_t> ctrlPktSequence;

  /// Grows `ctrlPktSequence` by `tailSize` zero-initialized words and returns
  /// a mutable view on the newly appended words. The view is invalidated by
  /// any later resize of `ctrlPktSequence`.
  llvm::MutableArrayRef<uint32_t> reserveAndGetTail(size_t tailSize) {
    size_t oldSize = ctrlPktSequence.size();
    size_t newSize = oldSize + tailSize;
    ctrlPktSequence.resize(newSize, 0);
    return llvm::MutableArrayRef<uint32_t>(ctrlPktSequence.data() + oldSize,
                                           tailSize);
  }

  /// Prints the accumulated sequence to `llvm::outs()`, one word per line.
  void dumpSequenceAsHex() const {
    llvm::outs() << "Control Packet Sequence: \n";
    // Write hex as 0xXXXXXXXX
    for (uint32_t word : ctrlPktSequence)
      llvm::outs() << utohexstr(word, 8) << "\n";
  }

  /// Converts every `NpuControlPacketOp` nested in `workgroupOp`.
  ///
  /// For each control packet, the header and data words are appended to
  /// `ctrlPktSequence`, and a `NpuHalfDmaCpyNdOp` (reading the appended range
  /// of the sequence) followed by a `NpuDmaWaitOp` is created in front of the
  /// control packet op. Afterwards all control packet ops are erased and the
  /// full sequence is attached to `workgroupOp` as the `ctrlpkt_sequence`
  /// attribute.
  ///
  /// Returns failure (after emitting an op error) if a connection does not
  /// have exactly one source/target channel, if a channel is not defined by a
  /// `ChannelOp`, if the addressed tile has no routed `CTRL` port, if the
  /// packet header cannot be built, or if a packet carries fewer data words
  /// than its declared length.
  LogicalResult convert(IRRewriter &rewriter, AMDAIE::WorkgroupOp workgroupOp) {
    ctrlPktSequence.clear();

    // Get all the `ConnectionOp` whose target is a `CTRL` port.
    DenseMap<TileLoc, AMDAIE::ConnectionOp> tileLocToCtrlConnect;
    WalkResult res = workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
      if (connectionOp.getTargetChannels().size() != 1) {
        connectionOp.emitOpError() << "expected a single target channel";
        return WalkResult::interrupt();
      }

      // `getDefiningOp<T>()` yields a null op when the value is a block
      // argument or is defined by another kind of op, so check before use.
      auto targetChannelOp = connectionOp.getTargetChannels()[0]
                                 .getDefiningOp<AMDAIE::ChannelOp>();
      if (!targetChannelOp) {
        connectionOp.emitOpError()
            << "expected the target channel to be defined by a channel op";
        return WalkResult::interrupt();
      }
      if (targetChannelOp.getPortType() == StrmSwPortType::CTRL) {
        TileOp tileOp = targetChannelOp.getTileOp();
        TileLoc tileLoc = {
            static_cast<int>(getConstantIndexOrAssert(tileOp.getCol())),
            static_cast<int>(getConstantIndexOrAssert(tileOp.getRow()))};
        tileLocToCtrlConnect[tileLoc] = connectionOp;
      }
      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();

    std::vector<AMDAIE::NpuControlPacketOp> ctrlPktOps;
    // Convert `NpuControlPacketOp` to `NpuHalfDmaCpyNdOp` + `NpuDmaWaitOp`.
    res = workgroupOp->walk([&](AMDAIE::NpuControlPacketOp ctrlPktOp) {
      ctrlPktOps.push_back(ctrlPktOp);
      // Get `ConnectionOp` for the `CTRL` port.
      uint32_t address = ctrlPktOp.getAddress();
      uint32_t addrOffset = deviceModel.getOffsetFromAddress(address);
      int32_t col = deviceModel.getColumnFromAddress(address);
      int32_t row = deviceModel.getRowFromAddress(address);
      const TileLoc tileLoc = {col, row};
      auto connectIt = tileLocToCtrlConnect.find(tileLoc);
      if (connectIt == tileLocToCtrlConnect.end()) {
        ctrlPktOp.emitOpError()
            << "tries to write to tile (col=" << col << ", row=" << row
            << "), but it's `CTRL` port is not routed.";
        return WalkResult::interrupt();
      }
      AMDAIE::ConnectionOp connectionOp = connectIt->second;

      // Get `sourceChannelOp`.
      if (connectionOp.getSourceChannels().size() != 1) {
        connectionOp.emitOpError() << "expected a single source channel";
        return WalkResult::interrupt();
      }
      auto sourceChannelOp = connectionOp.getSourceChannels()[0]
                                 .getDefiningOp<AMDAIE::ChannelOp>();
      if (!sourceChannelOp) {
        connectionOp.emitOpError()
            << "expected the source channel to be defined by a channel op";
        return WalkResult::interrupt();
      }

      // Get `offsets`, `sizes`, and `strides`. The offset is the position in
      // the sequence at which this packet's header is about to be appended,
      // so it must be read before the sequence is grown below.
      uint32_t dataLength = ctrlPktOp.getLength();
      // Widen before adding the header word so the sum cannot wrap in 32 bits.
      int64_t headerAndDataLength = static_cast<int64_t>(dataLength) + 1;
      SmallVector<int64_t> offsets{
          0, 0, 0, static_cast<int64_t>(ctrlPktSequence.size())};
      SmallVector<int64_t> sizes{1, 1, 1, headerAndDataLength};
      SmallVector<int64_t> strides{0, 0, 0, 1};

      // Store the control packet header.
      llvm::MutableArrayRef<uint32_t> words =
          reserveAndGetTail(headerAndDataLength);
      FailureOr<uint32_t> header = deviceModel.getCtrlPktHeader(
          addrOffset, dataLength, static_cast<uint32_t>(ctrlPktOp.getOpcode()),
          ctrlPktOp.getStreamId());
      if (failed(header)) {
        ctrlPktOp.emitOpError() << "failed to get control packet header.";
        return WalkResult::interrupt();
      }

      words[0] = *header;
      // Store the control packet data. When the op carries no data, the data
      // words keep the zero value they were reserved with.
      std::optional<ArrayRef<int32_t>> maybeData =
          ctrlPktOp.getDataFromArrayOrResource();
      if (maybeData.has_value()) {
        ArrayRef<int32_t> data = maybeData.value();
        // Defensive bound check: never index past the provided data.
        if (data.size() < dataLength) {
          ctrlPktOp.emitOpError()
              << "has fewer data words (" << data.size()
              << ") than its length (" << dataLength << ").";
          return WalkResult::interrupt();
        }
        for (uint32_t i = 0; i < dataLength; ++i)
          words[i + 1] = static_cast<uint32_t>(data[i]);
      }

      rewriter.setInsertionPoint(ctrlPktOp);
      // Create token.
      SmallVector<Type> resultTypes = {
          rewriter.getType<AMDAIE::AsyncTokenType>()};
      TypeRange sourceResultTypes = TypeRange{resultTypes};

      // Get `bdId`, use `0` for now.
      // TODO (zhewen): let `AMDAIEAssignNpuDmaBdIdsPass` decide?
      auto constant = rewriter.create<arith::ConstantOp>(
          rewriter.getUnknownLoc(), rewriter.getIndexAttr(0));
      auto bdIdOp = rewriter.create<AMDAIE::BdIdOp>(rewriter.getUnknownLoc(),
                                                    sourceChannelOp.getTileOp(),
                                                    constant.getResult());

      // Create `NpuHalfDmaCpyNdOp` and `NpuDmaWaitOp`.
      auto dmaOp = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
          rewriter.getUnknownLoc(), sourceResultTypes, connectionOp,
          connectionOp.getSource(), offsets, sizes, strides, bdIdOp,
          sourceChannelOp);
      rewriter.create<AMDAIE::NpuDmaWaitOp>(rewriter.getUnknownLoc(),
                                            dmaOp.getResult(0));

      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();

    // Erase all the `NpuControlPacketOp`. This is done after the walk so that
    // no op is erased while it is being visited.
    for (AMDAIE::NpuControlPacketOp ctrlPktOp : ctrlPktOps)
      rewriter.eraseOp(ctrlPktOp);

    // Store the control packet sequence in the `WorkgroupOp`.
    workgroupOp.setCtrlpktSequenceAttr(DenseUI32ResourceElementsAttr::get(
        RankedTensorType::get(
            ctrlPktSequence.size(),
            IntegerType::get(rewriter.getContext(), 32, IntegerType::Unsigned)),
        "ctrlpkt_sequence",
        HeapAsmResourceBlob::allocateAndCopyInferAlign(
            ArrayRef<uint32_t>(ctrlPktSequence))));
    return success();
  }
};

/// Pass that converts control packet ops into half DMA copy ops. It derives
/// from the `impl::AMDAIEControlPacketToHalfDmaCpyNdBase` class template and
/// is constructed from an `AMDAIEControlPacketToHalfDmaCpyNdOptions` value.
class AMDAIEControlPacketToHalfDmaCpyNdPass
: public impl::AMDAIEControlPacketToHalfDmaCpyNdBase<
AMDAIEControlPacketToHalfDmaCpyNdPass> {
public:
// Forwards the pass options to the base class constructor.
AMDAIEControlPacketToHalfDmaCpyNdPass(
const AMDAIEControlPacketToHalfDmaCpyNdOptions &options)
: AMDAIEControlPacketToHalfDmaCpyNdBase(options) {}
// Registers the AMDAIE dialect as a dependency of this pass.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AMDAIEDialect>();
}

// Pass entry point; defined out of line.
void runOnOperation() override;
};

/// Pass entry point. Looks up the AMDAIE device from the executable target
/// attribute of the root operation, builds the matching device model, and then
/// converts the control packets of every nested `WorkgroupOp`. Signals pass
/// failure if no device is configured or if any conversion fails. When the
/// `dumpSequence` option is set, the control packet sequence of each
/// successfully converted workgroup is printed.
void AMDAIEControlPacketToHalfDmaCpyNdPass::runOnOperation() {
  Operation *parentOp = getOperation();
  IRRewriter rewriter(parentOp->getContext());

  // Get `AMDAIEDeviceModel`.
  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
  std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
  if (!maybeDevice) {
    parentOp->emitOpError() << "has no AMDAIEDevice in the target "
                               "attribute configuration.";
    return signalPassFailure();
  }
  AMDAIE::AMDAIEDeviceModel deviceModel =
      AMDAIE::getDeviceModel(maybeDevice.value());
  ControlPacketDmaBuilder ctrlPktDmaBuilder(std::move(deviceModel));

  // The same builder is reused for every workgroup; `convert` resets the
  // accumulated sequence at the start of each call.
  WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
    if (failed(ctrlPktDmaBuilder.convert(rewriter, workgroupOp)))
      return WalkResult::interrupt();

    if (dumpSequence) ctrlPktDmaBuilder.dumpSequenceAsHex();

    return WalkResult::advance();
  });
  if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

/// Factory for the control-packet-to-half-DMA-copy pass, configured with the
/// given `options`.
std::unique_ptr<Pass> createAMDAIEControlPacketToHalfDmaCpyNdPass(
    AMDAIEControlPacketToHalfDmaCpyNdOptions options) {
  std::unique_ptr<Pass> pass =
      std::make_unique<AMDAIEControlPacketToHalfDmaCpyNdPass>(options);
  return pass;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ iree_cc_library(
"AMDAIEControlCodeLowering.cpp"
"AMDAIEControlCodeLoopUnroll.cpp"
"AMDAIEControlCodeToTransaction.cpp"
"AMDAIEControlPacketToHalfDmaCpyNd.cpp"
"AMDAIEConvertCoreForallToFor.cpp"
"AMDAIEConvertDeviceToControlPackets.cpp"
"AMDAIECreateAIEWorkgroup.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIECONTROLCODELOOPUNROLL
#define GEN_PASS_DEF_AMDAIECONTROLCODELOWERING
#define GEN_PASS_DEF_AMDAIECONTROLCODETOTRANSACTION
#define GEN_PASS_DEF_AMDAIECONTROLPACKETTOHALFDMACPYND
#define GEN_PASS_DEF_AMDAIECONVERTCOREFORALLTOFOR
#define GEN_PASS_DEF_AMDAIECONVERTDEVICETOCONTROLPACKETS
#define GEN_PASS_DEF_AMDAIECREATEAIEWORKGROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ std::unique_ptr<Pass> createAMDAIEControlCodeLoweringPass();
std::unique_ptr<Pass> createAMDAIEControlCodeToTransactionPass(
AMDAIEControlCodeToTransactionOptions options = {});

/// Pass to convert `amdaie.npu.control_packet` to
/// `amdaie.npu.half_dma_cpy_nd` operations.
std::unique_ptr<Pass> createAMDAIEControlPacketToHalfDmaCpyNdPass(
AMDAIEControlPacketToHalfDmaCpyNdOptions options = {});

/// Pass to convert `scf.forall` to `scf.for` within `aie.core`.
std::unique_ptr<Pass> createAMDAIEConvertCoreForallToForPass();

/// Pass to convert `aie.device`to a sequence of `aie.npu.control_packet` ops.
/// Pass to convert `aie.device` to a sequence of `amdaie.npu.control_packet`
/// ops.
std::unique_ptr<Pass> createAMDAIEConvertDeviceToControlPacketsPass(
AMDAIEConvertDeviceToControlPacketsOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,24 @@ def AMDAIEControlCodeToTransaction :
];
}

// Pass definition for `iree-amdaie-control-packet-to-half-dma-cpy-nd`.
// Exposes a single boolean option, `dump-sequence` (default: false), which is
// surfaced in the pass as `dumpSequence`.
def AMDAIEControlPacketToHalfDmaCpyNd :
Pass<"iree-amdaie-control-packet-to-half-dma-cpy-nd", ""> {
let summary = "Convert control packets to half DMA copy operations.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEControlPacketToHalfDmaCpyNdPass()";
let options = [
Option<"dumpSequence", "dump-sequence", "bool", /*default=*/"false",
"Dump the generated control packet sequence, including the header and data. (Used for tests)">
];
}

def AMDAIEConvertCoreForallToFor :
Pass<"iree-amdaie-convert-core-forall-to-for", ""> {
let summary = "Converts `scf.forall` to `scf.for` within `aie.core`.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertCoreForallToForPass()";
}

def AMDAIEConvertDeviceToControlPackets: Pass<"iree-amdaie-convert-device-to-control-packets"> {
let summary = "Convert `aie.device` to `amd.npu.control_packet` operations";
let summary = "Convert `aie.device` to `amdaie.npu.control_packet` operations";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertDeviceToControlPacketsPass()";
let options = [
Option<"pathToElfs", "path-to-elfs", "std::string", /*default=*/"", "Path to ELF files.">,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"controlcode_loop_unrolling.mlir"
"controlcode_lowering.mlir"
"controlcode_to_transaction.mlir"
"control_packet_to_half_dma_cpy_nd.mlir"
"convert_core_forall_to_for.mlir"
"convert_device_to_control_packets.mlir"
"create_aie_workgroup.mlir"
Expand Down
Loading

0 comments on commit cc25a25

Please sign in to comment.