Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDAIESinkIntoCore] Generalize sinking for reuse with other regioned ops #1117

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "AIEDialect.h"
#include "Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -192,6 +193,22 @@ struct AMDAIECoreToStandardPass
return true;
}

// Ensure that all aie.core ops are isolated from above, i.e. that all
// operands of ops within an aie.core are produced inside the aie.core (or are
// block arguments of the core). The exception is ops in the aie dialect --
// operands produced by for example an aie.buffer may be outside the core.
static void isolateCores(ModuleOp m) {
  IRRewriter rewriter(m->getContext());
  // Predicate handed to `sinkInto`: every producer may be sunk except ops of
  // the aie dialect. The namespace is read from the operation name instead of
  // through `op->getDialect()`, because the latter returns null for an op
  // whose dialect is not loaded and would be dereferenced here.
  auto notAieDialect = [](Operation *op) -> bool {
    return op->getName().getDialectNamespace() !=
           AIEDialect::getDialectNamespace();
  };
  m->walk([&](CoreOp coreOp) {
    sinkInto(coreOp.getRegion(), rewriter, notAieDialect);
  });
}

void runOnOperation() override {
ModuleOp m = getOperation();

Expand Down Expand Up @@ -222,8 +239,11 @@ struct AMDAIECoreToStandardPass
m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
rewriter.getStringAttr(targetArchStr));

if (failed(lockToStd(rewriter, m, targetArchStr)))
isolateCores(m);

if (failed(lockToStd(rewriter, m, targetArchStr))) {
return signalPassFailure();
}

m.walk([&](BufferOp buffer) { bufferToStd(m, buffer, rewriter); });

Expand Down
2 changes: 2 additions & 0 deletions compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ iree_cc_library(
MLIRMemRefDialect
MLIRIR
MLIREmitCDialect
iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::HAL::IR::HALDialect
)

###############################################################################
Expand Down
62 changes: 43 additions & 19 deletions compiler/plugins/target/AMD-AIE/aie/test/lower_buffer.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// RUN: iree-opt --amdaie-standard-lowering %s | FileCheck %s
// RUN: iree-opt --amdaie-standard-lowering --split-input-file %s | FileCheck %s

// CHECK: memref.global "public" @a : memref<4xi32>
// CHECK-LABEL: func.func @core_4_3() {
// CHECK-LABEL: @basic_test
// CHECK-DAG: memref.global "public" @a : memref<4xi32>
// CHECK-DAG: memref.global "public" @b : memref<4xi32>
// CHECK: func.func @core_3_4() {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_0:.*]] = memref.get_global @a : memref<4xi32>
// CHECK: %[[VAL_0:.*]] = memref.get_global @b : memref<4xi32>
// CHECK: memref.assume_alignment %[[VAL_0]], 32 : memref<4xi32>
// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[C0]]] : memref<4xi32>
// CHECK: return
Expand All @@ -18,22 +20,44 @@
// CHECK: return
// CHECK: }

module @codegen1 {
aie.device(xcvc1902) {
%t33 = aie.tile(3, 3)
%a = aie.buffer(%t33) { sym_name = "a" } : memref<4xi32>
%core33 = aie.core(%t33) {
%0 = arith.constant 0 : index
%377 = arith.constant 377 : i32
memref.store %377, %a[%0] : memref<4xi32>
aie.end
module @basic_test {
aie.device(xcvc1902) {
%tile_3_3 = aie.tile(3, 3)
%buffer_3_3 = aie.buffer(%tile_3_3) {sym_name = "a"} : memref<4xi32>
%core_3_3 = aie.core(%tile_3_3) {
%c0 = arith.constant 0 : index
%c377_i32 = arith.constant 377 : i32
memref.store %c377_i32, %buffer_3_3[%c0] : memref<4xi32>
aie.end
}
%tile_3_4 = aie.tile(3, 4)
%buffer_3_4 = aie.buffer(%tile_3_4) {sym_name = "b"} : memref<4xi32>
%core_3_4 = aie.core(%tile_3_4) {
%c0 = arith.constant 0 : index
%0 = memref.load %buffer_3_4[%c0] : memref<4xi32>
aie.end
}
}
%t34 = aie.tile(4, 3)
}

// -----

// CHECK: func.func @core_4_3() {
// CHECK-DAG: %[[C44:.*]] = arith.constant 44 : index
// CHECK-DAG: %[[VAL_0:.*]] = memref.get_global @a : memref<4xi32>
// CHECK: %[[VAL_1:.*]] = memref.load %[[VAL_0]]{{\[}}%[[C44]]] : memref<4xi32>
// CHECK: return
// CHECK: }

%core34 = aie.core(%t34) {
%0 = arith.constant 0 : index
%1 = memref.load %a[%0] : memref<4xi32>
aie.end
// Check that the constant 44 is hoisted into the core/function.
module @isolation_test {
aie.device(xcvc1902) {
%tile_4_3 = aie.tile(4, 3)
%c44 = arith.constant 44 : index
%buffer_4_3 = aie.buffer(%tile_4_3) {sym_name = "a"} : memref<4xi32>
%core_4_3 = aie.core(%tile_4_3) {
%0 = memref.load %buffer_4_3[%c44] : memref<4xi32>
aie.end
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"

#define DEBUG_TYPE "iree-amdaie-sink-into-core"
Expand All @@ -23,62 +22,6 @@ namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Clone into `coreOp` the producers of values that are used inside the core
/// but defined outside of it, and redirect the uses inside the core to the
/// clones. Producers belonging to the amdaie dialect are left where they are.
///
/// Only the ops that are nested in the core when this function is entered are
/// inspected; the operands of freshly cloned ops are not visited during the
/// same call.
///
/// \return true if at least one op was cloned into the core.
bool sinkInto(AMDAIE::CoreOp coreOp, PatternRewriter &rewriter) {
  // Snapshot the ops nested in the amdaie.core op (excluding the core itself)
  // before any clone is inserted.
  SmallVector<Operation *> nestedOps;
  coreOp->walk([&](Operation *op) {
    if (op != coreOp.getOperation()) nestedOps.push_back(op);
  });

  // Records whether any op was sunk into the core.
  bool changed = false;
  for (Operation *nestedOp : nestedOps) {
    for (Value operand : nestedOp->getOperands()) {
      // Block arguments (and null values) have no producer to clone.
      Operation *producer = operand ? operand.getDefiningOp() : nullptr;
      if (!producer) continue;

      // Nothing to do if the producer is already in the core.
      if (coreOp->isAncestor(producer)) continue;

      // Ops in the amdaie dialect are probably related to data movement
      // and should not be sunk into the core. This might need adjustment
      // later. The namespace comes from the operation name because
      // `getDialect()` returns null for an op whose dialect is not loaded.
      if (producer->getName().getDialectNamespace() ==
          AMDAIE::AMDAIEDialect::getDialectNamespace()) {
        continue;
      }

      // Clone the producer at the very start of the core's single block.
      Region &body = coreOp->getRegion(0);
      assert(body.getBlocks().size() == 1 && "expected single block region");
      rewriter.setInsertionPointToStart(&body.front());
      Operation *clonedProducer = rewriter.clone(*producer);

      // Redirect only the uses located inside the core to the clone.
      producer->replaceUsesWithIf(clonedProducer, [&](OpOperand &use) {
        return coreOp->isAncestor(use.getOwner());
      });
      changed = true;
    }
  }
  return changed;
}

/// Rewrite pattern over amdaie.core ops that delegates to `sinkInto` and
/// reports success exactly when `sinkInto` reports that it changed the core.
class SinkingPattern : public OpRewritePattern<AMDAIE::CoreOp> {
 public:
  using OpRewritePattern<AMDAIE::CoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AMDAIE::CoreOp coreOp,
                                PatternRewriter &rewriter) const override {
    const bool sunkAnything = sinkInto(coreOp, rewriter);
    if (sunkAnything) return success();
    return failure();
  }
};

class AMDAIESinkIntoCorePass
: public impl::AMDAIESinkIntoCoreBase<AMDAIESinkIntoCorePass> {
public:
Expand All @@ -87,10 +30,23 @@ class AMDAIESinkIntoCorePass
xilinx::AIE::AIEDialect, AMDAIE::AMDAIEDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<SinkingPattern>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
auto shouldSink = [&](Operation *op) -> bool {
// Ops in the amdaie dialect are probably related to data movement
// and should not be sunk into the core. This might need adjustment
// later.
if (op->getDialect()->getNamespace() ==
AMDAIE::AMDAIEDialect::getDialectNamespace()) {
return false;
}
return true;
};
IRRewriter rewriter(getOperation());
SmallVector<AMDAIE::CoreOp> coreOps;

getOperation()->walk(
[&](AMDAIE::CoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
sinkInto(coreOp.getRegion(), rewriter, shouldSink);
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
namespace mlir::iree_compiler::AMDAIE {

std::string getConstantIntValuesString(ArrayRef<OpFoldResult> ofrs) {
auto maybeValues = mlir::getConstantIntValues(ofrs);
std::optional<SmallVector<int64_t>> maybeValues =
mlir::getConstantIntValues(ofrs);
if (maybeValues.has_value())
return getArrayString<int64_t>(maybeValues.value());
return "[not all constant integers]";
Expand All @@ -26,7 +27,7 @@ template <typename T>
std::optional<T> getConfigAttr(IREE::HAL::ExecutableTargetAttr targetAttr,
StringRef name) {
if (!targetAttr) return std::nullopt;
auto config = targetAttr.getConfiguration();
DictionaryAttr config = targetAttr.getConfiguration();
if (!config) return std::nullopt;
std::optional<T> attr = config.getAs<T>(name);
return attr;
Expand All @@ -41,7 +42,8 @@ std::optional<AMDAIEDevice> getConfigAMDAIEDevice(
}

std::optional<AMDAIEDevice> getConfigAMDAIEDevice(Operation *op) {
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(op);
IREE::HAL::ExecutableTargetAttr targetAttr =
IREE::HAL::ExecutableTargetAttr::lookup(op);
if (!targetAttr) return std::nullopt;
return getConfigAMDAIEDevice(targetAttr);
}
Expand Down Expand Up @@ -116,7 +118,7 @@ static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
static BlockArgument checkOptionalExtOps(Value val) {
BlockArgument blockArg;
if (!(blockArg = dyn_cast<BlockArgument>(val))) {
auto defOp = val.getDefiningOp();
Operation *defOp = val.getDefiningOp();
if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
!dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
!dyn_cast_if_present<arith::ExtUIOp>(defOp)) {
Expand Down Expand Up @@ -397,4 +399,47 @@ int detail::findLargestFactor(int num, int max, int multiple) {
return factor ? factor : detail::findLargestFactor(num, max);
}

/// Sink into `region` the producers of values that are used inside `region`
/// but defined outside of the operation owning it.
///
/// Each block of `region` is processed independently. Every op nested in the
/// block is inspected; when one of its operands is produced by an op that is
/// not nested in the region's parent op and `shouldSink` accepts that
/// producer, the producer is cloned at the start of the block and all uses of
/// it that are nested in the block are redirected to the clone. The clones,
/// together with any ops nested in their regions, are inspected in turn, so
/// chains of producers are sunk transitively.
///
/// \param region      Region to sink into; it must have a parent operation.
/// \param rewriter    Rewriter used to create the clones; its insertion point
///                    is moved by this function.
/// \param shouldSink  Predicate deciding whether a given producer may be
///                    cloned into the region.
/// \return true if at least one op was cloned into `region`.
bool sinkInto(Region &region, IRRewriter &rewriter,
              std::function<bool(Operation *)> shouldSink) {
  Operation *parentOfRegion = region.getParentOp();
  assert(parentOfRegion && "Region has no parent operation");
  bool regionChanged = false;
  for (Block &block : region.getBlocks()) {
    // True if the owner of `use` is in `block`, either directly or nested in
    // an op that is in `block`.
    auto isInBlock = [&block](OpOperand &use) {
      for (Operation *owner = use.getOwner(); owner;
           owner = owner->getParentOp()) {
        if (owner->getBlock() == &block) return true;
      }
      return false;
    };

    // Start from all ops nested in the block.
    SmallVector<Operation *> worklist;
    SmallVector<Operation *> nextWorklist;
    block.walk([&](Operation *nested) { worklist.push_back(nested); });
    while (!worklist.empty()) {
      for (Operation *user : worklist) {
        for (Value operand : user->getOperands()) {
          // Block arguments (and null values) have no producer to clone.
          if (!operand || !operand.getDefiningOp()) continue;
          Operation *dependencyOp = operand.getDefiningOp();
          // Skip if the dependency is already inside the region's parent op.
          if (parentOfRegion->isAncestor(dependencyOp)) continue;
          if (!shouldSink(dependencyOp)) continue;
          // Insert at the start of the block so that the clone dominates all
          // of its users, including clones created earlier.
          rewriter.setInsertionPointToStart(&block);
          Operation *sunkOp = rewriter.clone(*dependencyOp);
          // The clone may itself depend on values defined above the region,
          // and so may ops nested in its regions: queue all of them. For an
          // op without regions this queues just the clone.
          sunkOp->walk(
              [&](Operation *nested) { nextWorklist.push_back(nested); });
          // Redirect the uses of the dependency that are in `block`.
          dependencyOp->replaceUsesWithIf(sunkOp, isInBlock);
          regionChanged = true;
        }
      }
      std::swap(worklist, nextWorklist);
      nextWorklist.clear();
    }
  }
  return regionChanged;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ std::string getArrayString(ArrayRef<T> vs) {
/// "[not all constant integers]".
std::string getConstantIntValuesString(ArrayRef<OpFoldResult> opFoldResults);

/// Consider all operations in the region, recursively. If the operation
/// has an operand that is not in the region, and the `shouldSink` function
/// returns true for that operand's producer, then replace all uses of the
/// operand inside the region with a clone of the operand in the block.
///
/// If `shouldSink` returns true for all operations, then this function will
/// make the region isolated from above. So this function essentially makes
/// the region isolated from above with respect to the set of operation types
/// defined by `shouldSink`.
///
/// \return true if the region was changed.
bool sinkInto(Region &, IRRewriter &,
std::function<bool(Operation *)> shouldSink);

} // namespace mlir::iree_compiler::AMDAIE

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
module {
// CHECK-LABEL: func @sink_into_single_core
func.func @sink_into_single_core(%arg0: index) {
// CHECK-NOT: arith.constant 3 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%0 = arith.addi %arg0, %c3 : index
%tile = amdaie.tile(%c0, %c2)
// CHECK: amdaie.core
%1 = amdaie.core(%tile, in : [], out : []) {
// CHECK: arith.constant 3 : index
// CHECK: arith.addi
// CHECK: linalg.fill
// CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %arg0, %[[C3]] : index
// CHECK: linalg.fill ins(%[[ADD]] : index)
%alloc = memref.alloc() : memref<2x2xindex>
linalg.fill ins(%0 : index) outs(%alloc : memref<2x2xindex>)
amdaie.end
Expand All @@ -25,16 +24,8 @@ module {
// -----

module {
// Constants 0 and 1 are cloned into the cores, but not removed, because
// they are still used outside of the cores. Constants 2 and 3 are used only
// inside the cores, so they are cloned into the cores but then removed from
// the outer function.
// CHECK-LABEL: func @sink_into_pair_of_cores
func.func @sink_into_pair_of_cores(%arg0 : index) {
// CHECK-NOT: arith.constant 3 : index
// CHECK-NOT: arith.constant 2 : index
// CHECK-DAG: arith.constant 1 : index
// CHECK-DAG: arith.constant 0 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
Expand All @@ -43,9 +34,14 @@ module {
%tile_0 = amdaie.tile(%c0, %c1)
// CHECK: amdaie.core
%0 = amdaie.core(%tile, in : [], out : []) {
// CHECK-DAG: arith.constant 3 : index
// CHECK-DAG: arith.constant 2 : index
// CHECK-DAG: arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[A0:.*]] = arith.addi %arg0, %[[C1]] : index
// CHECK: %[[A1:.*]] = arith.addi %[[C1]], %[[A0]] : index
// CHECK: %[[A2:.*]] = arith.addi %[[A1]], %[[C2]] : index
// CHECK: %[[A3:.*]] = arith.addi %[[A2]], %[[C3]] : index
// CHECK: linalg.fill ins(%[[A3]] : index)
%1 = arith.addi %arg0, %c1 : index
%2 = arith.addi %c1, %1 : index
%3 = arith.addi %2, %c2 : index
Expand Down
Loading