Skip to content

Commit

Permalink
[LinalgFunctionOutlining] add none, all and balanced outlining strate…
Browse files Browse the repository at this point in the history
…gies
  • Loading branch information
jtuyls committed Jan 29, 2025
1 parent 11b51d7 commit 471205e
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 53 deletions.
14 changes: 11 additions & 3 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct AMDAIEOptions {
bool enableVectorizationPasses{true};
bool enableCoalescingLoops{false};
bool enableCollapsingUnitDims{false};
bool enableFunctionOutlining{true};
OutliningStrategy enableFunctionOutlining{OutliningStrategy::Balanced};
bool replaceOutlinedFunctionsWithEmpty{false};
bool insertLoopAroundCoreBlock{false};
bool matmulElementwiseFusion{false};
Expand Down Expand Up @@ -197,11 +197,19 @@ struct AMDAIEOptions {
"unit dims of a tensor/memref depending on this pass flag. It is "
"intended for development purposes only."));

binder.opt<bool>(
binder.opt<OutliningStrategy>(
"iree-amdaie-enable-function-outlining", enableFunctionOutlining,
llvm::cl::cat(category),
llvm::cl::desc("Flag to enable/disable linalg-function-outlining pass."
"It is intended for development purposes only."));
"It is intended for development purposes only."),
llvm::cl::values(clEnumValN(OutliningStrategy::None, "none",
"No linalg ops will be outlined."),
clEnumValN(OutliningStrategy::All, "all",
"All linalg ops will be outlined."),
clEnumValN(OutliningStrategy::Balanced, "balanced",
"Will outline some ops, to try to achieve "
"a good balance between "
"performance and program size.")));

binder.opt<bool>(
"iree-amdaie-replace-outlined-functions-with-empty",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
// clang-format off
// https://github.com/llvm/llvm-project/blob/6b0785390d02193d81d8db7fb12279ffa4651afe/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td#L475
// clang-format on
auto type = dyn_cast<MemRefType>(operand.getType());
assert(type && "we've already checked that all operands are memrefs");
MemRefLayoutAttrInterface layout = type.getLayout();
assert(layout &&
"MemRefType layout attribute interface should always be present");
if (!layout.isIdentity()) return failure();
if (auto type = dyn_cast<MemRefType>(operand.getType())) {
MemRefLayoutAttrInterface layout = type.getLayout();
assert(layout &&
"MemRefType layout attribute interface should always be present");
if (!layout.isIdentity()) return failure();
}
}
auto funcType = FunctionType::get(
rewriter.getContext(), computeOp->getOperandTypes(), /*outputTypes=*/{});
Expand Down Expand Up @@ -72,8 +72,9 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp,
return func;
}

/// Utility to check if the linalg op is one we know should be outlined.
static bool mustOutline(linalg::LinalgOp linalgOp) {
/// Utility to check whether the linalg op should be outlined when the balanced
/// strategy is enabled.
static bool mustOutlineBalanced(linalg::LinalgOp linalgOp) {
if (isa<linalg::CopyOp, linalg::FillOp>(linalgOp)) return false;
if (isElementwise(linalgOp)) return false;
// TODO(newling) not all remaining ops should be outlined, not even all
Expand Down Expand Up @@ -158,16 +159,22 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);

if (outliningStrategy == OutliningStrategy::None) {
if (emptyFunctions) {
moduleOp.emitWarning()
<< "The option to empty outlined functions is enabled while the "
"outlining strategy specifies to not outline any functions, so no "
"transformation will happen. This combination might not result in "
"the intended behaviour.";
}
return;
}

SmallVector<Operation *> toBeErased;
moduleOp.walk([&](linalg::LinalgOp computeOp) {
if (!mustOutline(computeOp)) return WalkResult::skip();

// Assert that we're in reference semantics, ie that all operands of
// computeOp have MemRefType:
if (!llvm::all_of(computeOp->getOperandTypes(),
[](Type t) { return isa<MemRefType>(t); })) {
computeOp->emitError("expected all operands to be of MemRefType");
return WalkResult::interrupt();
if (outliningStrategy == OutliningStrategy::Balanced &&
!mustOutlineBalanced(computeOp)) {
return WalkResult::skip();
}

FailureOr<func::FuncOp> maybeFunc =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ enum class BufferizeOperand {
/// Kinds of hardware mapping attributes.
enum class HardwareMapping {
  Core,
  Block,
  None,
};

/// Selects which ops the linalg function outlining pass moves into functions.
enum class OutliningStrategy {
  /// Outline nothing.
  None,
  /// Attempt to outline every op.
  All,
  /// Outline a subset of the ops, aiming for good performance while staying
  /// within the program memory limits.
  Balanced,
};

LogicalResult initAIELaunchConfig(FunctionOpInterface funcOp,
TilePassPipeline useTilePipeline,
LowerToAIEPassPipeline useLowerToAIEPipeline,
Expand Down
24 changes: 9 additions & 15 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ void buildAMDAIETransformPassPipeline(
LowerToAIEPassPipeline useLowerToAIEPipeline, bool matmulElementwiseFusion,
bool enableVectorizationPasses, const std::string &pathToUkernels,
bool enablePacketFlow, bool enableCoalescingLoops,
bool enableCollapsingUnitDims, bool enableFunctionOutlining,
bool enableCollapsingUnitDims, OutliningStrategy enableFunctionOutlining,
bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock) {
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
{
Expand Down Expand Up @@ -767,8 +767,9 @@ void addAMDAIEObjectFifoLoweringPasses(
OpPassManager &passManager, bool enablePacketFlow,
TilePassPipeline useTilePipeline, bool enableVectorizationPasses,
bool enableCoalescingLoops, bool enableCollapsingUnitDims,
bool enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty,
bool insertLoopAroundCoreBlock, uint32_t numCols) {
OutliningStrategy enableFunctionOutlining,
bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock,
uint32_t numCols) {
passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass());
passManager.addPass(memref::createFoldMemRefAliasOpsPass());

Expand All @@ -794,18 +795,11 @@ void addAMDAIEObjectFifoLoweringPasses(
passManager.addPass(createAMDAIENormalizeLoopBoundsPass());
passManager.addPass(createAMDAIEInsertCoresPass());

if (enableFunctionOutlining) {
// Create function outlining options object, etc.
AMDAIELinalgFunctionOutliningOptions options;
if (replaceOutlinedFunctionsWithEmpty) {
options.emptyFunctions = true;
}
passManager.addPass(createAMDAIELinalgFunctionOutliningPass(options));
} else {
assert(!replaceOutlinedFunctionsWithEmpty &&
"`replaceOutlinedFunctionsWithEmpty` is only valid when "
"`enableFunctionOutlining` is true.");
}
// Create function outlining options object, etc.
AMDAIELinalgFunctionOutliningOptions options;
options.outliningStrategy = enableFunctionOutlining;
options.emptyFunctions = replaceOutlinedFunctionsWithEmpty;
passManager.addPass(createAMDAIELinalgFunctionOutliningPass(options));

{
// Vectorization passes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ void addAMDAIEObjectFifoLoweringPasses(
OpPassManager &passManager, bool enablePacketFlow,
TilePassPipeline useTilePipeline, bool enableVectorizationPasses,
bool enableCoalescingLoops, bool enableCollapsingUnitDims,
bool enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty,
bool insertLoopAroundCoreBlock, uint32_t numCols);
OutliningStrategy enableFunctionOutlining,
bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock,
uint32_t numCols);

/// Add passes to lower from MLIR-AIR through AIE. This is
/// currently the default passes used for lowering after IREEs tiling.
Expand All @@ -42,7 +43,7 @@ void buildAMDAIETransformPassPipeline(
LowerToAIEPassPipeline useLowerToAIEPipeline, bool matmulElementwiseFusion,
bool enableVectorizationPasses, const std::string &pathToUkernels,
bool enablePacketFlow, bool enableCoalescingLoops,
bool enableCollapsingUnitDims, bool enableFunctionOutlining,
bool enableCollapsingUnitDims, OutliningStrategy enableFunctionOutlining,
bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock);

/// Populates passes needed to lower the IR via a Pack-Peel based approach.
Expand Down
17 changes: 15 additions & 2 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,21 @@ def AMDAIELinalgFunctionOutlining :
"Replace all outlined functions with a function that does nothing, "
"i.e. it just returns. Useful for measuring the performance of data "
"movement to/from the device -- by doing zero compute, all time is spent "
"moving data to/from the AIE cores.">
];
"moving data to/from the AIE cores.">,
// Pass option selecting which linalg ops get outlined into functions. The
// accepted values are `none`, `all` and `balanced`, with `balanced` as the
// default.
Option<"outliningStrategy", "outlining-strategy",
  "mlir::iree_compiler::AMDAIE::OutliningStrategy",
  /*default=*/"mlir::iree_compiler::AMDAIE::OutliningStrategy::Balanced",
  "The strategy to be used for outlining. The default is balanced to "
  "achieve a good tradeoff in performance and program size.",
  [{::llvm::cl::values(
    clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::None, "none",
      "No linalg ops will be outlined."),
    clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::All, "all",
      "Try to outline all linalg ops."),
    clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::Balanced, "balanced",
      "Will outline some ops, to try to achieve a good balance between performance and program size.")
  )}]>,
];
}

def AMDAIEFoldDmaWaits :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
// pipeline. We check 3 paths:
//
// 1) Explicitly disabling linalg function outlining with
// --iree-amdaie-enable-function-outlining=0
// --iree-amdaie-enable-function-outlining=none
//
// 2) Explicitly enabling linalg function outlining with
// --iree-amdaie-enable-function-outlining=1
// 2) Explicitly enabling linalg function outlining for all linalg ops with
// --iree-amdaie-enable-function-outlining=all
//
// 3) Not specifying the flag at all, which should use the default value (1).
// 3) Not specifying the flag at all, which should use the default value (balanced).


// 1) Explicitly disabled:
// RUN: iree-compile --iree-hal-target-backends=amd-aie \
// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=0 %s | FileCheck %s -check-prefix=CHECK-DISABLED
// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=none %s | FileCheck %s -check-prefix=CHECK-DISABLED

// 2) Explicitly enabled:
// RUN: iree-compile --iree-hal-target-backends=amd-aie \
// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=1 %s | FileCheck %s -check-prefix=CHECK-ENABLED
// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=all %s | FileCheck %s -check-prefix=CHECK-ENABLED

// 3) Default value:
// 3) Default value (balanced):
// RUN: iree-compile --iree-hal-target-backends=amd-aie \
// RUN: --compile-to=executable-targets %s | FileCheck %s -check-prefix=CHECK-DEFAULT

Expand All @@ -33,6 +33,6 @@ func.func @matmul(%lhs: tensor<64x64xbf16>,
return %res : tensor<64x64xf32>
}

// CHECK-DISABLED-NOT: func.call
// CHECK-ENABLED: func.call
// CHECK-DEFAULT: func.call
// CHECK-DISABLED-NOT: func.call
// CHECK-ENABLED-COUNT-2: func.call
// CHECK-DEFAULT-COUNT-1: func.call
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: iree-opt --split-input-file --iree-amdaie-linalg-function-outlining --verify-diagnostics --split-input-file %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{outlining-strategy=all})" --verify-diagnostics --split-input-file %s | FileCheck %s -check-prefix=CHECK-ALL
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{outlining-strategy=none})" --verify-diagnostics --split-input-file %s | FileCheck %s -check-prefix=CHECK-NONE

// Test demonstrating multiple matmuls using different SSAs.

Expand Down Expand Up @@ -27,6 +29,15 @@
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-ALL: func.func private @[[MATMUL_FUNC:.*]]({{.*}}memref<4x8xbf16>, {{.*}}memref<8x4xbf16>, {{.*}}memref<4x4xf32>)
// CHECK-ALL: func.func @repeated_identical_matmul
// CHECK-ALL-COUNT-2: func.call @[[MATMUL_FUNC]]
// CHECK-ALL-NOT: func.call @[[MATMUL_FUNC]]

// CHECK-NONE: @repeated_identical_matmul
// CHECK-NONE-NOT: func.call

func.func @repeated_identical_matmul(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -90,6 +101,18 @@ func.func @repeated_identical_matmul(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>,
// CHECK-NEXT: func.call @[[MATMUL_K6]](%[[A1]], %[[B1]], %[[C]])
// CHECK-NEXT: amdaie.end
// CHECK: return

// CHECK-ALL-DAG: func.func private @[[MATMUL_K6:.+]]({{.*}}memref<4x6xbf16>, {{.*}}memref<6x4xbf16>, {{.*}}memref<4x4xf32>)
// CHECK-ALL-DAG: func.func private @[[MATMUL_K4:.+]]({{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xf32>)
// CHECK-ALL: func.func @distinct_matmul_shapes
// CHECK-ALL-COUNT-3: func.call @[[MATMUL_K4]]
// CHECK-ALL: func.call @[[MATMUL_K6]]
// CHECK-ALL-NOT: func.call @[[MATMUL_K4]]
// CHECK-ALL-NOT: func.call @[[MATMUL_K6]]

// CHECK-NONE: @distinct_matmul_shapes
// CHECK-NONE-NOT: func.call

func.func @distinct_matmul_shapes(%A0: memref<4x4xbf16>, %B0: memref<4x4xbf16>,
%A1: memref<4x6xbf16>, %B1: memref<6x4xbf16>,
%C: memref<4x4xf32>) {
Expand All @@ -112,16 +135,30 @@ func.func @distinct_matmul_shapes(%A0: memref<4x4xbf16>, %B0: memref<4x4xbf16>,
// -----

// CHECK-LABEL: @linalg_fill_copy
// CHECK: linalg.fill
// CHECK-NOT: func.call @fill_elementwise_0_outlined
// CHECK: linalg.copy
// CHECK-NOT: func.call @copy_elementwise_1_outlined

// CHECK-ALL-DAG: func.func private @[[FILL_FUNC:.*]]({{.*}}: f32, {{.*}}: memref<4xf32>)
// CHECK-ALL-DAG: linalg.fill
// CHECK-ALL-DAG: func.func private @[[COPY_FUNC:.*]]({{.*}}: memref<4xf32>, {{.*}}: memref<4xf32>)
// CHECK-ALL-DAG: linalg.copy
// CHECK-ALL: @linalg_fill_copy(%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>)
// CHECK-ALL: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-ALL: func.call @[[FILL_FUNC]](%[[C0]], %[[ARG0]])
// CHECK-ALL: func.call @[[COPY_FUNC]](%[[ARG0]], %[[ARG1]])

// CHECK-NONE: @linalg_fill_copy
// CHECK-NONE-NOT: func.call

func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%tile = amdaie.tile(%c1, %c2)
%0 = amdaie.core(%tile, in : [], out : []) {
// CHECK: linalg.fill
// CHECK-NOT: func.call @fill_elementwise_0_outlined
// CHECK: linalg.copy
// CHECK-NOT: func.call @copy_elementwise_1_outlined

linalg.fill ins(%cst : f32) outs(%A : memref<4xf32>)
linalg.copy ins(%A : memref<4xf32>) outs(%B : memref<4xf32>)
amdaie.end
Expand All @@ -146,6 +183,15 @@ func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) {
// CHECK: func.call @generic_0_outlined
// CHECK-SAME: (memref<4xbf16>, memref<bf16>) -> ()
// CHECK: return

// CHECK-ALL-DAG: func.func private @[[GENERIC:.+]]({{.*}}memref<4xbf16>, {{.*}}memref<bf16>)
// CHECK-ALL: func.func @reduction
// CHECK-ALL: func.call @[[GENERIC]]
// CHECK-ALL-NOT: func.call @[[GENERIC]]

// CHECK-NONE: @reduction
// CHECK-NONE-NOT: func.call

func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%tile = amdaie.tile(%c2, %c2)
Expand All @@ -170,6 +216,13 @@ func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) {

// CHECK-COUNT-1: func.func
// CHECK-NOT: func.func

// CHECK-ALL-LABEL: func.func @unoutlineable_layout_with_offset
// CHECK-ALL-NOT: func.func

// CHECK-NONE: @unoutlineable_layout_with_offset
// CHECK-NONE-NOT: func.call

func.func @unoutlineable_layout_with_offset(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%tile = amdaie.tile(%c2, %c2)
Expand All @@ -195,6 +248,13 @@ func.func @unoutlineable_layout_with_offset(%A: memref<4x8xbf16, strided<[8,1],

// CHECK-COUNT-1: func.func
// CHECK-NOT: func.func

// CHECK-ALL-LABEL: func.func @unoutlineable_strided_layout
// CHECK-ALL-NOT: func.func

// CHECK-NONE: @unoutlineable_strided_layout
// CHECK-NONE-NOT: func.call

func.func @unoutlineable_strided_layout(%A: memref<4x8xbf16, strided<[9,1]>>, %B: memref<bf16>) {
%c2 = arith.constant 2 : index
%tile = amdaie.tile(%c2, %c2)
Expand Down

0 comments on commit 471205e

Please sign in to comment.