Skip to content

Commit

Permalink
[AssignChannels] Prioritize channel assignment for control packets (#…
Browse files Browse the repository at this point in the history
…1106)

Changes in this PR:
- Previously, channels for data flows were assigned first using the
`AssignChannels` pass, followed by control flow channel assignment with
the `GenerateControlOverla`y pass. This PR reverses the order, ensuring
that control flow channels are assigned before data flow channels.
- Previously, in the runtime `ChannelGenerator`, packet flow channels
were not explicitly marked as assigned (so that they can be reused),
while only circuit flow channels were tracked. However, this could
result in a packet flow channel later being mistakenly reused by another
circuit flow. Now, both circuit and packet flows track their assigned
channels and store them separately. This ensures that circuit flow
channels remain exclusive, while packet flow channels can be reused but
only by other packets.
  • Loading branch information
Yu-Zhewen authored Feb 15, 2025
1 parent 13870fd commit 54b89c2
Show file tree
Hide file tree
Showing 9 changed files with 433 additions and 259 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,62 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes channel generators for tiles by detecting DMA channels
/// previously assigned by other passes (e.g., for control packets) and
/// registering them to prevent conflicts.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
DenseMap<Value, ChannelGenerator> &tileToGeneratorMap) {
// Get the number of producer and consumer channels for each tile.
workgroupOp.walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
AMDAIETileType tileType = deviceModel.getTileType(col, row);
uint8_t numDmaChannels =
deviceModel.getDmaProp<uint8_t>(tileType, AMDAIEDmaProp::NumChannels);
tileToGeneratorMap[tileOp.getResult()] =
ChannelGenerator(numDmaChannels, numDmaChannels);
});

WalkResult res = workgroupOp.walk([&](AMDAIE::ConnectionOp connectionOp) {
ChannelAssignmentMode mode =
(connectionOp.getConnectionType() == AMDAIE::ConnectionType::Packet)
? ChannelAssignmentMode::RoundRobinPacketFlow
: ChannelAssignmentMode::FirstAvailableCircuitFlow;
// Check source DMA channels previously assigned by other passes,
// and register them in `ChannelGenerator` using `assignProducerDMAChannel`.
for (Value source : connectionOp.getSourceChannels()) {
auto channelOp = dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp());
if (!channelOp) {
connectionOp.emitOpError() << "expected a `amdaie.channel` op source";
return WalkResult::interrupt();
}
if (channelOp.getPortType() == StrmSwPortType::DMA) {
Value tile = channelOp.getTileOp().getResult();
tileToGeneratorMap[tile].assignProducerDMAChannel(channelOp.getValue(),
mode);
}
}
// Check target DMA channels previously assigned by other passes,
// and register them in `ChannelGenerator` using `assignConsumerDMAChannel`.
for (Value target : connectionOp.getTargetChannels()) {
auto channelOp = dyn_cast<AMDAIE::ChannelOp>(target.getDefiningOp());
if (!channelOp) {
connectionOp.emitOpError() << "expected a `amdaie.channel` op target";
return WalkResult::interrupt();
}
if (channelOp.getPortType() == StrmSwPortType::DMA) {
Value tile = channelOp.getTileOp().getResult();
tileToGeneratorMap[tile].assignConsumerDMAChannel(channelOp.getValue(),
mode);
}
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
return success();
}

/// Assign channels to `amdaie.connection` ops.
LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
IRRewriter rewriter(workgroupOp->getContext());
Expand All @@ -27,19 +83,13 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
<< "could not find an AMDAIEDevice attribute";
}
AMDAIEDeviceModel deviceModel = AMDAIE::getDeviceModel(device.value());

// Get the number of producer and consumer channels for each tile.
// Initialize channel generators for tiles.
DenseMap<Value, ChannelGenerator> tileToGeneratorMap;
workgroupOp.walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
uint32_t row = getConstantIndexOrAssert(tileOp.getRow());
AMDAIETileType tileType = deviceModel.getTileType(col, row);
uint8_t numDmaChannels =
deviceModel.getDmaProp<uint8_t>(tileType, AMDAIEDmaProp::NumChannels);
tileToGeneratorMap[tileOp.getResult()] =
ChannelGenerator(numDmaChannels, numDmaChannels);
});

if (failed(initializeChannelsGenerators(workgroupOp, deviceModel,
tileToGeneratorMap))) {
return failure();
}
// Get all `amdaie.connection` ops.
SmallVector<AMDAIE::ConnectionOp> connectionOps;
workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
connectionOps.push_back(connectionOp);
Expand All @@ -59,48 +109,49 @@ LogicalResult assignChannels(AMDAIE::WorkgroupOp workgroupOp) {
return connectionOp.emitOpError()
<< "expected a `LogicalObjFifoOpInterface` target";
}
std::optional<AMDAIE::ConnectionType> connectionType =
connectionOp.getConnectionType();
bool isPacketFlow = connectionType && connectionType.value() ==
AMDAIE::ConnectionType::Packet;

ChannelAssignmentMode mode =
(connectionOp.getConnectionType() == AMDAIE::ConnectionType::Packet)
? ChannelAssignmentMode::RoundRobinPacketFlow
: ChannelAssignmentMode::FirstAvailableCircuitFlow;
rewriter.setInsertionPoint(connectionOp);
SmallVector<Value> sourceChannels;
for (Value tile : sourceLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getProducerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
SmallVector<Value> sourceChannels = connectionOp.getSourceChannels();
// Assign source (producer) DMA channels if not already assigned.
if (sourceChannels.empty()) {
for (Value tile : sourceLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignProducerDMAChannel(mode);
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no producer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
sourceChannels.push_back(channelOp.getResult());
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignProducerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::MM2S);
sourceChannels.push_back(channelOp.getResult());
}
SmallVector<Value> targetChannels;
for (Value tile : targetLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getConsumerDMAChannel();
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
// Assign target (consumer) DMA channels if not already assigned.
SmallVector<Value> targetChannels = connectionOp.getTargetChannels();
if (targetChannels.empty()) {
for (Value tile : targetLogicalObjFifo.getTiles()) {
assert(tileToGeneratorMap.contains(tile) &&
"no channel generator found for tile");
std::optional<uint8_t> maybeChannel =
tileToGeneratorMap[tile].getAndAssignConsumerDMAChannel(mode);
if (!maybeChannel) {
return connectionOp.emitOpError()
<< "no consumer DMA channel available";
}
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
targetChannels.push_back(channelOp.getResult());
}
// Only assign the channel if it is for circuit flow.
if (!isPacketFlow)
tileToGeneratorMap[tile].assignConsumerDMAChannel(maybeChannel.value());
auto channelOp = rewriter.create<AMDAIE::ChannelOp>(
rewriter.getUnknownLoc(), tile, maybeChannel.value(),
StrmSwPortType::DMA, AMDAIE::DMAChannelDir::S2MM);
targetChannels.push_back(channelOp.getResult());
}
// Replace the `amdaie.connection` op with newly assigned `sourceChannels`
// and `targetChannels`.
rewriter.replaceOpWithNewOp<AMDAIE::ConnectionOp>(
connectionOp, connectionOp.getTarget(), targetChannels,
connectionOp.getSource(), sourceChannels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Initializes the channel generators for the shim tiles, excluding any
/// channels that are already in use by existing circuit-mode connections.
/// Initializes channel generators for shim tiles, ensuring that no shim DMA
/// MM2S channels have been assigned before. This guarantees priority for the
/// control overlay.
LogicalResult initializeChannelsGenerators(
AMDAIE::WorkgroupOp workgroupOp, const AMDAIEDeviceModel &deviceModel,
const DenseSet<TileOp> &shimTileOps,
Expand All @@ -29,40 +30,19 @@ LogicalResult initializeChannelsGenerators(
shimTileToGeneratorMap[shimTileOp.getResult()] =
ChannelGenerator(numShimDmaChannels, numShimDmaChannels);
});
// Exclude those channels that are already used by a circuit-mode connection.
workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
std::optional<AMDAIE::ConnectionType> connectionType =
connectionOp.getConnectionType();
bool isPacketFlow = connectionType && connectionType.value() ==
AMDAIE::ConnectionType::Packet;
if (isPacketFlow) return WalkResult::advance();
SmallVector<AMDAIE::ChannelOp> sourceChannels;
for (Value source : connectionOp.getSourceChannels()) {
if (auto channelOp =
dyn_cast<AMDAIE::ChannelOp>(source.getDefiningOp())) {
sourceChannels.push_back(channelOp);
}
}
for (AMDAIE::ChannelOp channelOp : sourceChannels) {
AMDAIE::TileOp tileOp = channelOp.getTileOp();
uint8_t channel = channelOp.getValue();
StrmSwPortType portType = channelOp.getPortType();
AMDAIE::DMAChannelDir direction = channelOp.getDirection();
if (shimTileOps.contains(tileOp) && portType == StrmSwPortType::DMA) {
// Assign to exclude.
if (direction == AMDAIE::DMAChannelDir::MM2S) {
shimTileToGeneratorMap[tileOp.getResult()].assignProducerDMAChannel(
channel);
} else if (direction == AMDAIE::DMAChannelDir::S2MM) {
shimTileToGeneratorMap[tileOp.getResult()].assignConsumerDMAChannel(
channel);
} else {
assert(false && "unexpected DMA channel direction");
}
}
// Ensure that shim DMA MM2S channels are not already assigned.
WalkResult res = workgroupOp->walk([&](AMDAIE::ChannelOp channelOp) {
if (shimTileOps.contains(channelOp.getTileOp()) &&
channelOp.getPortType() == StrmSwPortType::DMA &&
channelOp.getDirection() == AMDAIE::DMAChannelDir::MM2S) {
channelOp.emitOpError()
<< "shim DMA MM2S channel must remain unassigned before "
"control overlay generation.";
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (res.wasInterrupted()) return failure();
return success();
}

Expand Down Expand Up @@ -114,11 +94,12 @@ LogicalResult generateControlOverlay(AMDAIE::WorkgroupOp workgroupOp,
WalkResult res = workgroupOp->walk([&](AMDAIE::TileOp tileOp) {
uint32_t col = getConstantIndexOrAssert(tileOp.getCol());
TileOp shimTileOp = columnToShimTile[col];
// Get the available channel, but do not assign it. Allow it to be
// shared across multiple packet-mode connections as needed.
// Get the available DMA channel for the shim tile, and assign it for the
// packet flow.
std::optional<uint8_t> maybeChannel =
shimTileToGeneratorMap[shimTileOp.getResult()]
.getProducerDMAChannel();
.getAndAssignProducerDMAChannel(
ChannelAssignmentMode::RoundRobinPacketFlow);
if (!maybeChannel) {
shimTileOp.emitOpError() << "no producer DMA channel available";
return WalkResult::interrupt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,8 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createCanonicalizerPass());
passManager.addPass(createAMDAIEDmaCSEPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());

passManager.addPass(createAMDAIEAssignChannelsPass());
passManager.addPass(createCSEPass());
passManager.addPass(createCanonicalizerPass());
Expand All @@ -860,8 +862,6 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIEObjFifoBufferizationPass());
passManager.addPass(createAMDAIETemporaryAllocBufferizationPass());

passManager.addPass(createAMDAIEGenerateControlOverlayPass());

passManager.addPass(createAMDAIEConnectionToFlowPass());
passManager.addPass(createAMDAIEAssignPacketIdsPass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,83 @@ module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb}
return
}
}

// -----

// In the input IR:
// - Tile (0,0) has its DMA MM2S channel 0 already assigned to a circuit flow.
// - Tile (0,1) has its DMA S2MM channel 0 assigned to the same circuit flow.
// As a result, channel assignment starts from channel 1 for both tiles.
// CHECK-LABEL: @previously_assigned_circuit
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: amdaie.workgroup
// CHECK: %[[tile_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[tile_0_1:.+]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[CHANNEL_0:.+]] = amdaie.channel(%[[tile_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.+]] = amdaie.channel(%[[tile_0_1]], 1, port_type = DMA, direction = S2MM)
// CHECK: amdaie.connection(%{{.+}} {%[[CHANNEL_1]]}, %{{.+}} {%[[CHANNEL_0]]})
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @previously_assigned_circuit(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_1} : memref<1x1x8x16xi32, 1> -> !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%2 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_0 = amdaie.channel(%tile_0_1, 0, port_type = DMA, direction = S2MM)
%3 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<16x8xi32>>
%4 = amdaie.logicalobjectfifo.placeholder{%tile_0_1} : !amdaie.logicalobjectfifo<memref<1x1x16x8xi32, 1>>
%5 = amdaie.connection(%4 {%channel_0}, %3 {%channel}) {connection_type = #amdaie<connection_type Circuit>} : (!amdaie.logicalobjectfifo<memref<1x1x16x8xi32, 1>>, !amdaie.logicalobjectfifo<memref<16x8xi32>>)
amdaie.controlcode {
amdaie.end
}
}
return
}
}

// -----

// In the input IR:
// - Tile (0,0) has its DMA MM2S channel 0 already assigned to a control packet flow.
// - Tile (0,1) has its CTRL S2MM channel 0 assigned to the same flow.
// Therefore, the next available channels are:
// - Tile (0,0): DMA MM2S channel 1
// - Tile (0,1): DMA S2MM channel 0
// CHECK-LABEL: @previously_assigned_packet
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: amdaie.workgroup
// CHECK: %[[tile_0_0:.+]] = amdaie.tile(%[[C0]], %[[C0]])
// CHECK: %[[tile_0_1:.+]] = amdaie.tile(%[[C0]], %[[C1]])
// CHECK: %[[CHANNEL_0:.+]] = amdaie.channel(%[[tile_0_0]], 1, port_type = DMA, direction = MM2S)
// CHECK: %[[CHANNEL_1:.+]] = amdaie.channel(%[[tile_0_1]], 0, port_type = DMA, direction = S2MM)
// CHECK: amdaie.connection(%{{.+}} {%[[CHANNEL_1]]}, %{{.+}} {%[[CHANNEL_0]]})
#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_device = "npu1_4col", ukernels = "none"}>
module attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
func.func @previously_assigned_packet(%arg0: memref<1x1x8x16xi32, 1>, %arg1: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
amdaie.workgroup {
%tile_0_0 = amdaie.tile(%c0, %c0)
%tile_0_1 = amdaie.tile(%c0, %c1)
%0 = amdaie.logicalobjectfifo.from_memref %arg0, {%tile_0_1} : memref<1x1x8x16xi32, 1> -> !amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>
%1 = amdaie.logicalobjectfifo.from_memref %arg1, {%tile_0_0} : memref<8x16xi32> -> !amdaie.logicalobjectfifo<memref<8x16xi32>>
%2 = amdaie.connection(%0, %1) : (!amdaie.logicalobjectfifo<memref<1x1x8x16xi32, 1>>, !amdaie.logicalobjectfifo<memref<8x16xi32>>)
%channel = amdaie.channel(%tile_0_0, 0, port_type = DMA, direction = MM2S)
%channel_0 = amdaie.channel(%tile_0_0, 0, port_type = CTRL, direction = S2MM)
%3 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<?xi32>>
%4 = amdaie.logicalobjectfifo.placeholder{%tile_0_0} : !amdaie.logicalobjectfifo<memref<?xi32>>
%5 = amdaie.connection(%4 {%channel_0}, %3 {%channel}) {connection_type = #amdaie<connection_type Packet>} : (!amdaie.logicalobjectfifo<memref<?xi32>>, !amdaie.logicalobjectfifo<memref<?xi32>>)
amdaie.controlcode {
amdaie.end
}
}
return
}
}
Loading

0 comments on commit 54b89c2

Please sign in to comment.