diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETarget.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETarget.h index 35da26360..47a6f1f94 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETarget.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETarget.h @@ -56,7 +56,7 @@ struct AMDAIEOptions { bool enableVectorizationPasses{true}; bool enableCoalescingLoops{false}; bool enableCollapsingUnitDims{false}; - bool enableFunctionOutlining{true}; + OutliningStrategy enableFunctionOutlining{OutliningStrategy::Balanced}; bool replaceOutlinedFunctionsWithEmpty{false}; bool insertLoopAroundCoreBlock{false}; bool matmulElementwiseFusion{false}; @@ -197,11 +197,19 @@ struct AMDAIEOptions { "unit dims of a tensor/memref depending on this pass flag. It is " "intended for development purposes only.")); - binder.opt<bool>( + binder.opt<OutliningStrategy>( "iree-amdaie-enable-function-outlining", enableFunctionOutlining, llvm::cl::cat(category), llvm::cl::desc("Flag to enable/disable linalg-function-outlining pass." - "It is intended for development purposes only.")); + "It is intended for development purposes only."), + llvm::cl::values(clEnumValN(OutliningStrategy::None, "none", + "No linalg ops will be outlined."), + clEnumValN(OutliningStrategy::All, "all", + "All linalg ops will be outlined."), + clEnumValN(OutliningStrategy::Balanced, "balanced", + "Will outline some ops, to try to achieve " + "a good balance between " + "performance and program size."))); binder.opt<bool>( "iree-amdaie-replace-outlined-functions-with-empty", diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp index 20ba6df47..cd458e413 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELinalgFunctionOutlining.cpp @@ -34,12 +34,12 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp, // clang-format off // https://github.com/llvm/llvm-project/blob/6b0785390d02193d81d8db7fb12279ffa4651afe/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td#L475 // clang-format on - auto type = dyn_cast<MemRefType>(operand.getType()); - assert(type && "we've already checked that all operands are memrefs"); - MemRefLayoutAttrInterface layout = type.getLayout(); - assert(layout && - "MemRefType layout attribute interface should always be present"); - if (!layout.isIdentity()) return failure(); + if (auto type = dyn_cast<MemRefType>(operand.getType())) { + MemRefLayoutAttrInterface layout = type.getLayout(); + assert(layout && + "MemRefType layout attribute interface should always be present"); + if (!layout.isIdentity()) return failure(); + } } auto funcType = FunctionType::get( rewriter.getContext(), computeOp->getOperandTypes(), /*outputTypes=*/{}); @@ -72,8 +72,9 @@ static FailureOr<func::FuncOp> outline(IRRewriter &rewriter, ModuleOp moduleOp, return func; } -/// Utility to check if the linalg op is one we know should be outlined. -static bool mustOutline(linalg::LinalgOp linalgOp) { +/// Utility to check whether the linalg should be outlined in case the balanced +/// strategy is enabled. +static bool mustOutlineBalanced(linalg::LinalgOp linalgOp) { if (isa<linalg::CopyOp, linalg::FillOp>(linalgOp)) return false; if (isElementwise(linalgOp)) return false; // TODO(newling) not all remaining ops should be outlined, not even all @@ -158,16 +159,22 @@ void AMDAIELinalgFunctionOutliningPass::runOnOperation() { MLIRContext *context = &getContext(); IRRewriter rewriter(context); + if (outliningStrategy == OutliningStrategy::None) { + if (emptyFunctions) { + moduleOp.emitWarning() + << "The option to empty outlined functions is enabled while the " + "outlining strategy specifies to not outline any functions, so no " + "transformation will happen. This combination might not result in " + "the intended behaviour."; + } + return; + } + SmallVector<Operation *> toBeErased; moduleOp.walk([&](linalg::LinalgOp computeOp) { - if (!mustOutline(computeOp)) return WalkResult::skip(); - - // Assert that we're in reference semantics, ie that all operands of - // computeOp have MemRefType: - if (!llvm::all_of(computeOp->getOperandTypes(), - [](Type t) { return isa<MemRefType>(t); })) { - computeOp->emitError("expected all operands to be of MemRefType"); - return WalkResult::interrupt(); + if (outliningStrategy == OutliningStrategy::Balanced && + !mustOutlineBalanced(computeOp)) { + return WalkResult::skip(); } FailureOr<func::FuncOp> maybeFunc = diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.h index 3d5ed1d9d..b64066c20 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.h @@ -41,6 +41,16 @@ enum class BufferizeOperand { /// Enum for hardware mapping attributes. enum class HardwareMapping { Core, Block, None }; +enum class OutliningStrategy { + // No outlining. + None, + // Try outlining all ops. + All, + // A balanced strategy trying to achieve good performance with the program + // memory limits. + Balanced, +}; + LogicalResult initAIELaunchConfig(FunctionOpInterface funcOp, TilePassPipeline useTilePipeline, LowerToAIEPassPipeline useLowerToAIEPipeline, diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 610e63d2d..dbcf615d0 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -710,7 +710,7 @@ void buildAMDAIETransformPassPipeline( LowerToAIEPassPipeline useLowerToAIEPipeline, bool matmulElementwiseFusion, bool enableVectorizationPasses, const std::string &pathToUkernels, bool enablePacketFlow, bool enableCoalescingLoops, - bool enableCollapsingUnitDims, bool enableFunctionOutlining, + bool enableCollapsingUnitDims, OutliningStrategy enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock) { OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>(); { @@ -767,8 +767,9 @@ void addAMDAIEObjectFifoLoweringPasses( OpPassManager &passManager, bool enablePacketFlow, TilePassPipeline useTilePipeline, bool enableVectorizationPasses, bool enableCoalescingLoops, bool enableCollapsingUnitDims, - bool enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty, - bool insertLoopAroundCoreBlock, uint32_t numCols) { + OutliningStrategy enableFunctionOutlining, + bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock, + uint32_t numCols) { passManager.addPass(createEraseHALDescriptorTypeFromMemRefPass()); passManager.addPass(memref::createFoldMemRefAliasOpsPass()); @@ -794,18 +795,11 @@ void addAMDAIEObjectFifoLoweringPasses( passManager.addPass(createAMDAIENormalizeLoopBoundsPass()); passManager.addPass(createAMDAIEInsertCoresPass()); - if (enableFunctionOutlining) { - // Create function outlining options object, etc. - AMDAIELinalgFunctionOutliningOptions options; - if (replaceOutlinedFunctionsWithEmpty) { - options.emptyFunctions = true; - } - passManager.addPass(createAMDAIELinalgFunctionOutliningPass(options)); - } else { - assert(!replaceOutlinedFunctionsWithEmpty && - "`replaceOutlinedFunctionsWithEmpty` is only valid when " - "`enableFunctionOutlining` is true."); - } + // Create function outlining options object, etc. + AMDAIELinalgFunctionOutliningOptions options; + options.outliningStrategy = enableFunctionOutlining; + options.emptyFunctions = replaceOutlinedFunctionsWithEmpty; + passManager.addPass(createAMDAIELinalgFunctionOutliningPass(options)); { // Vectorization passes diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h index 93029415e..b5b23826f 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h @@ -19,8 +19,9 @@ void addAMDAIEObjectFifoLoweringPasses( OpPassManager &passManager, bool enablePacketFlow, TilePassPipeline useTilePipeline, bool enableVectorizationPasses, bool enableCoalescingLoops, bool enableCollapsingUnitDims, - bool enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty, - bool insertLoopAroundCoreBlock, uint32_t numCols); + OutliningStrategy enableFunctionOutlining, + bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock, + uint32_t numCols); /// Add passes to lower from MLIR-AIR through AIE. This is /// currently the default passes used for lowering after IREEs tiling. @@ -42,7 +43,7 @@ void buildAMDAIETransformPassPipeline( LowerToAIEPassPipeline useLowerToAIEPipeline, bool matmulElementwiseFusion, bool enableVectorizationPasses, const std::string &pathToUkernels, bool enablePacketFlow, bool enableCoalescingLoops, - bool enableCollapsingUnitDims, bool enableFunctionOutlining, + bool enableCollapsingUnitDims, OutliningStrategy enableFunctionOutlining, bool replaceOutlinedFunctionsWithEmpty, bool insertLoopAroundCoreBlock); /// Populates passes needed to lower the IR via a Pack-Peel based approach. diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td index 49c40d76d..5cd20473a 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.td @@ -308,8 +308,21 @@ def AMDAIELinalgFunctionOutlining : "Replace all outlined functions with a function that does nothing, " "i.e. it just returns. Useful for measuring the performance of data " "movement to/from the device -- by doing zero compute, all time is spent " - "moving data to/from the AIE cores."> - ]; + "moving data to/from the AIE cores.">, + Option<"outliningStrategy", "outlining-strategy", + "mlir::iree_compiler::AMDAIE::OutliningStrategy", + /*default=*/"mlir::iree_compiler::AMDAIE::OutliningStrategy::Balanced", + "The strategy to be used for outlining. The default is balanced to " + "achieve a good tradeoff in performance and program size.", + [{::llvm::cl::values( + clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::None, "none", + "No outlining."), + clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::All, "all", + "Use the pack-peel based lowering strategy with 4 tiling levels for matmul-like ops."), + clEnumValN(mlir::iree_compiler::AMDAIE::OutliningStrategy::Balanced, "balanced", + "A strategy that tries to achieve a balanced tradeoff between performance and program size.") + )}]>, + ]; } def AMDAIEFoldDmaWaits : diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/disable_linalg_function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/disable_linalg_function_outlining.mlir index d95869a75..a85b88df9 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/disable_linalg_function_outlining.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/disable_linalg_function_outlining.mlir @@ -2,23 +2,23 @@ // pipeline. We check 3 paths: // // 1) Explicitly disabling linalg function outlining with -// --iree-amdaie-enable-function-outlining=0 +// --iree-amdaie-enable-function-outlining=none // -// 2) Explicitly enabling linalg function outlining with -// --iree-amdaie-enable-function-outlining=1 +// 2) Explicitly enabling linalg function outlining for all linalg ops with +// --iree-amdaie-enable-function-outlining=all // -// 3) Not specifying the flag at all, which should use the default value (1). +// 3) Not specifying the flag at all, which should use the default value (balanced). // 1) Explicitly disabled: // RUN: iree-compile --iree-hal-target-backends=amd-aie \ -// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=0 %s | FileCheck %s -check-prefix=CHECK-DISABLED +// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=none %s | FileCheck %s -check-prefix=CHECK-DISABLED // 2) Explicitly enabled: // RUN: iree-compile --iree-hal-target-backends=amd-aie \ -// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=1 %s | FileCheck %s -check-prefix=CHECK-ENABLED +// RUN: --compile-to=executable-targets --iree-amdaie-enable-function-outlining=all %s | FileCheck %s -check-prefix=CHECK-ENABLED -// 3) Default value: +// 3) Default value (balanced): // RUN: iree-compile --iree-hal-target-backends=amd-aie \ // RUN: --compile-to=executable-targets %s | FileCheck %s -check-prefix=CHECK-DEFAULT @@ -33,6 +33,6 @@ func.func @matmul(%lhs: tensor<64x64xbf16>, return %res : tensor<64x64xf32> } -// CHECK-DISABLED-NOT: func.call -// CHECK-ENABLED: func.call -// CHECK-DEFAULT: func.call +// CHECK-DISABLED-NOT: func.call +// CHECK-ENABLED-COUNT-2: func.call +// CHECK-DEFAULT-COUNT-1: func.call diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir index 99cbb2b69..b7b29b4b2 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/linalg_function_outlining.mlir @@ -1,4 +1,6 @@ // RUN: iree-opt --split-input-file --iree-amdaie-linalg-function-outlining --verify-diagnostics --split-input-file %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{outlining-strategy=all})" --verify-diagnostics --split-input-file %s | FileCheck %s -check-prefix=CHECK-ALL +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(iree-amdaie-linalg-function-outlining{outlining-strategy=none})" --verify-diagnostics --split-input-file %s | FileCheck %s -check-prefix=CHECK-NONE // Test demonstrating multiple matmuls using different SSAs. @@ -27,6 +29,15 @@ // CHECK: } // CHECK: return // CHECK: } + +// CHECK-ALL: func.func private @[[MATMUL_FUNC:.*]]({{.*}}memref<4x8xbf16>, {{.*}}memref<8x4xbf16>, {{.*}}memref<4x4xf32>) +// CHECK-ALL: func.func @repeated_identical_matmul +// CHECK-ALL-COUNT-2: func.call @[[MATMUL_FUNC]] +// CHECK-ALL-NOT: func.call @[[MATMUL_FUNC]] + +// CHECK-NONE: @repeated_identical_matmul +// CHECK-NONE-NOT: func.call + func.func @repeated_identical_matmul(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, %C: memref<4x4xf32>) { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index @@ -90,6 +101,18 @@ func.func @repeated_identical_matmul(%A: memref<4x8xbf16>, %B: memref<8x4xbf16>, // CHECK-NEXT: func.call @[[MATMUL_K6]](%[[A1]], %[[B1]], %[[C]]) // CHECK-NEXT: amdaie.end // CHECK: return + +// CHECK-ALL-DAG: func.func private @[[MATMUL_K6:.+]]({{.*}}memref<4x6xbf16>, {{.*}}memref<6x4xbf16>, {{.*}}memref<4x4xf32>) +// CHECK-ALL-DAG: func.func private @[[MATMUL_K4:.+]]({{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xbf16>, {{.*}}memref<4x4xf32>) +// CHECK-ALL: func.func @distinct_matmul_shapes +// CHECK-ALL-COUNT-3: func.call @[[MATMUL_K4]] +// CHECK-ALL: func.call @[[MATMUL_K6]] +// CHECK-ALL-NOT: func.call @[[MATMUL_K4]] +// CHECK-ALL-NOT: func.call @[[MATMUL_K6]] + +// CHECK-NONE: @distinct_matmul_shapes +// CHECK-NONE-NOT: func.call + func.func @distinct_matmul_shapes(%A0: memref<4x4xbf16>, %B0: memref<4x4xbf16>, %A1: memref<4x6xbf16>, %B1: memref<6x4xbf16>, %C: memref<4x4xf32>) { @@ -112,16 +135,30 @@ func.func @distinct_matmul_shapes(%A0: memref<4x4xbf16>, %B0: memref<4x4xbf16>, // ----- // CHECK-LABEL: @linalg_fill_copy +// CHECK: linalg.fill +// CHECK-NOT: func.call @fill_elementwise_0_outlined +// CHECK: linalg.copy +// CHECK-NOT: func.call @copy_elementwise_1_outlined + +// CHECK-ALL-DAG: func.func private @[[FILL_FUNC:.*]]({{.*}}: f32, {{.*}}: memref<4xf32>) +// CHECK-ALL-DAG: linalg.fill +// CHECK-ALL-DAG: func.func private @[[COPY_FUNC:.*]]({{.*}}: memref<4xf32>, {{.*}}: memref<4xf32>) +// CHECK-ALL-DAG: linalg.copy +// CHECK-ALL: @linalg_fill_copy(%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>) +// CHECK-ALL: %[[C0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-ALL: func.call @[[FILL_FUNC]](%[[C0]], %[[ARG0]]) +// CHECK-ALL: func.call @[[COPY_FUNC]](%[[ARG0]], %[[ARG1]]) + +// CHECK-NONE: @linalg_fill_copy +// CHECK-NONE-NOT: func.call + func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index %cst = arith.constant 0.0 : f32 %tile = amdaie.tile(%c1, %c2) %0 = amdaie.core(%tile, in : [], out : []) { - // CHECK: linalg.fill - // CHECK-NOT: func.call @fill_elementwise_0_outlined - // CHECK: linalg.copy - // CHECK-NOT: func.call @copy_elementwise_1_outlined + linalg.fill ins(%cst : f32) outs(%A : memref<4xf32>) linalg.copy ins(%A : memref<4xf32>) outs(%B : memref<4xf32>) amdaie.end @@ -146,6 +183,15 @@ func.func @linalg_fill_copy(%A: memref<4xf32>, %B: memref<4xf32>) { // CHECK: func.call @generic_0_outlined // CHECK-SAME: (memref<4xbf16>, memref<bf16>) -> () // CHECK: return + +// CHECK-ALL-DAG: func.func private @[[GENERIC:.+]]({{.*}}memref<4xbf16>, {{.*}}memref<bf16>) +// CHECK-ALL: func.func @reduction +// CHECK-ALL: func.call @[[GENERIC]] +// CHECK-ALL-NOT: func.call @[[GENERIC]] + +// CHECK-NONE: @reduction +// CHECK-NONE-NOT: func.call + func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2) @@ -170,6 +216,13 @@ func.func @reduction(%A: memref<4xbf16>, %B: memref<bf16>) { // CHECK-COUNT-1: func.func // CHECK-NOT: func.func + +// CHECK-ALL-LABEL: func.func @unoutlineable_layout_with_offset +// CHECK-ALL-NOT: func.func + +// CHECK-NONE: @unoutlineable_layout_with_offset +// CHECK-NONE-NOT: func.call + func.func @unoutlineable_layout_with_offset(%A: memref<4x8xbf16, strided<[8,1], offset:?>>, %B: memref<bf16>) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2) @@ -195,6 +248,13 @@ func.func @unoutlineable_layout_with_offset(%A: memref<4x8xbf16, strided<[8,1], // CHECK-COUNT-1: func.func // CHECK-NOT: func.func + +// CHECK-ALL-LABEL: func.func @unoutlineable_strided_layout +// CHECK-ALL-NOT: func.func + +// CHECK-NONE: @unoutlineable_strided_layout +// CHECK-NONE-NOT: func.call + func.func @unoutlineable_strided_layout(%A: memref<4x8xbf16, strided<[9,1]>>, %B: memref<bf16>) { %c2 = arith.constant 2 : index %tile = amdaie.tile(%c2, %c2)