From bfafdd544c6cf95eb57a96b0a7c34d31f00119b1 Mon Sep 17 00:00:00 2001 From: MatveyK Date: Sun, 18 May 2025 12:03:36 +0300 Subject: [PATCH 1/2] trace_loop --- .../kurakin_trace_loop/CMakeLists.txt | 16 ++ .../kurakin_trace_loop/kurakin_trace_loop.cpp | 85 +++++++ .../lit_test_kurakin_trace_loop.mlir | 236 ++++++++++++++++++ 3 files changed, 337 insertions(+) create mode 100644 mlir/compiler-course/kurakin_trace_loop/CMakeLists.txt create mode 100644 mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp create mode 100644 mlir/test/compiler-course/kurakin_trace_loop/lit_test_kurakin_trace_loop.mlir diff --git a/mlir/compiler-course/kurakin_trace_loop/CMakeLists.txt b/mlir/compiler-course/kurakin_trace_loop/CMakeLists.txt new file mode 100644 index 0000000000000..8b41cc51be04f --- /dev/null +++ b/mlir/compiler-course/kurakin_trace_loop/CMakeLists.txt @@ -0,0 +1,16 @@ +set(Title "TraceLoopPass") +set(Student "Kurakin_Matvey") +set(Group "FIIT1") +set(TARGET_NAME "${Title}_${Student}_${Group}_MLIR") + +file(GLOB_RECURSE SOURCES *.cpp *.h *.hpp) + +add_llvm_pass_plugin(${TARGET_NAME} + ${SOURCES} + DEPENDS + intrinsics_gen + MLIRBuiltinLocationAttributesIncGen + BUILDTREE_ONLY +) + +set(MLIR_TEST_DEPENDS ${TARGET_NAME} ${MLIR_TEST_DEPENDS} PARENT_SCOPE) diff --git a/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp b/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp new file mode 100644 index 0000000000000..8d50251db8932 --- /dev/null +++ b/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp @@ -0,0 +1,85 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Tools/Plugins/PassPlugin.h" + +namespace { + +struct TraceLoopPass + : public mlir::PassWrapper> { + void insertTraceLoop(mlir::Block &blockBefore, mlir::Block &blockAfter, + mlir::Location loc, mlir::OpBuilder &opBuilder) { + opBuilder.setInsertionPointToStart(&blockBefore); + opBuilder.create(loc, "trace_loop_iter_begin", + mlir::TypeRange(), mlir::ValueRange()); + for (auto &op : blockAfter) { + if (mlir::isa(op) || + mlir::isa(op)) { + opBuilder.setInsertionPoint(&op); + opBuilder.create( + loc, "trace_loop_iter_end", mlir::TypeRange(), mlir::ValueRange()); + } + } + } + +public: + mlir::StringRef getArgument() const final { + return "TraceLoopPass_Kurakin_Matvey_FIIT1_MLIR"; + } + + mlir::StringRef getDescription() const final { return "Trace Loop Pass"; } + + void runOnOperation() override { + mlir::ModuleOp module = getOperation(); + mlir::MLIRContext *context = module.getContext(); + mlir::OpBuilder opBuilder(context); + + mlir::SymbolTable symbolTable(module); + if (!symbolTable.lookup("trace_loop_iter_begin")) { + auto funcType = mlir::FunctionType::get(context, {}, {}); + auto newFunc = mlir::func::FuncOp::create( + module.getLoc(), "trace_loop_iter_begin", funcType); + newFunc.setVisibility(mlir::SymbolTable::Visibility::Private); + symbolTable.insert(newFunc); + } + if (!symbolTable.lookup("trace_loop_iter_end")) { + auto funcType = mlir::FunctionType::get(context, {}, {}); + auto newFunc = mlir::func::FuncOp::create( + module.getLoc(), "trace_loop_iter_end", funcType); + newFunc.setVisibility(mlir::SymbolTable::Visibility::Private); + symbolTable.insert(newFunc); + } + + module.walk([&](mlir::Operation *op) { + if (auto affineForOp = mlir::dyn_cast(op)) { + insertTraceLoop(*affineForOp.getBody(), *affineForOp.getBody(), + affineForOp->getLoc(), opBuilder); + } else if (auto forOp = mlir::dyn_cast(op)) { + insertTraceLoop(*forOp.getBody(), *forOp.getBody(), forOp->getLoc(), + opBuilder); + } else if (auto whileOp = mlir::dyn_cast(op)) { + insertTraceLoop(whileOp.getBefore().front(), whileOp.getAfter().front(), + whileOp->getLoc(), opBuilder); + } + }); + } +}; +} // namespace + +MLIR_DECLARE_EXPLICIT_TYPE_ID(TraceLoopPass) +MLIR_DEFINE_EXPLICIT_TYPE_ID(TraceLoopPass) + +mlir::PassPluginLibraryInfo getFunctionCallCounterPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "TraceLoopPass", "1.0", + []() { mlir::PassRegistration(); }}; +} + +extern "C" LLVM_ATTRIBUTE_WEAK mlir::PassPluginLibraryInfo +mlirGetPassPluginInfo() { + return getFunctionCallCounterPassPluginInfo(); +} diff --git a/mlir/test/compiler-course/kurakin_trace_loop/lit_test_kurakin_trace_loop.mlir b/mlir/test/compiler-course/kurakin_trace_loop/lit_test_kurakin_trace_loop.mlir new file mode 100644 index 0000000000000..d7e50560797f0 --- /dev/null +++ b/mlir/test/compiler-course/kurakin_trace_loop/lit_test_kurakin_trace_loop.mlir @@ -0,0 +1,236 @@ +// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/TraceLoopPass_Kurakin_Matvey_FIIT1_MLIR%shlibext --pass-pipeline="builtin.module(TraceLoopPass_Kurakin_Matvey_FIIT1_MLIR)" %s | FileCheck %s + +module { + +// CHECK: func.func @affine_for() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %0 = affine.for %arg0 = 0 to 10 iter_args(%arg1 = %c0_i32) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %1 = arith.index_cast %arg0 : index to i32 +// CHECK-NEXT: %2 = arith.addi %arg1, %1 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: affine.yield %2 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } + + func.func @affine_for() -> i32 { + %sum_init = arith.constant 0 : i32 + %res = affine.for %i = 0 to 10 iter_args(%arg = %sum_init) -> i32 { + %i_32 = arith.index_cast %i : index to i32 + %sum = arith.addi %arg, %i_32 : i32 + affine.yield %sum : i32 + } + return %res : i32 + } + +// CHECK: func.func @scf_for() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c10 step %c1 iter_args(%arg1 = %c0_i32) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %1 = arith.index_cast %arg0 : index to i32 +// CHECK-NEXT: %2 = arith.addi %arg1, %1 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %2 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } + + func.func @scf_for() -> i32 { + %sum_init = arith.constant 0 : i32 + %begin = arith.constant 0 : index + %end = arith.constant 10 : index + %step = arith.constant 1 : index + + %result = scf.for %i = %begin to %end step %step iter_args(%arg = %sum_init) -> i32 { + %i_32 = arith.index_cast %i : index to i32 + %sum = arith.addi %arg, %i_32 : i32 + scf.yield %sum : i32 + } + return %result : i32 + } + +// CHECK: func.func @scf_while() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c10_i32 = arith.constant 10 : i32 +// CHECK-NEXT: %0:2 = scf.while (%arg0 = %c0_i32, %arg1 = %c0_i32_0) : (i32, i32) -> (i32, i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %1 = arith.cmpi slt, %arg1, %c10_i32 : i32 +// CHECK-NEXT: scf.condition(%1) %arg0, %arg1 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg0: i32, %arg1: i32): +// CHECK-NEXT: %1 = arith.addi %arg0, %arg1 : i32 +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %2 = arith.addi %arg1, %c1_i32 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %1, %2 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0#0 : i32 +// CHECK-NEXT: } + + func.func @scf_while() -> i32 { + %sum_init = arith.constant 0 : i32 + %i_init = arith.constant 0 : i32 + %sum_limit = arith.constant 10 : i32 + + %result:2 = scf.while (%sum = %sum_init, %i = %i_init) : (i32, i32) -> (i32, i32) { + %cmp = arith.cmpi slt, %i, %sum_limit : i32 + scf.condition(%cmp) %sum, %i : i32, i32 + } do { + ^bb0(%sum_arg: i32, %i_arg: i32): + %sum = arith.addi %sum_arg, %i_arg : i32 + %step = arith.constant 1 : i32 + %new_i = arith.addi %i_arg, %step : i32 + scf.yield %sum, %new_i : i32, i32 + } + return %result#0 : i32 + } + +// CHECK: func.func @affine_for_and_scf_for() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %0 = affine.for %arg0 = 0 to 10 iter_args(%arg1 = %c0_i32) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %1 = arith.index_cast %arg0 : index to i32 +// CHECK-NEXT: %2 = scf.for %arg2 = %c0 to %c5 step %c1 iter_args(%arg3 = %arg1) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %3 = arith.index_cast %arg2 : index to i32 +// CHECK-NEXT: %4 = arith.addi %arg3, %3 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %4 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: affine.yield %2 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } + + func.func @affine_for_and_scf_for() -> i32 { + %sum_init = arith.constant 0 : i32 + %begin = arith.constant 0 : index + %end = arith.constant 5 : index + %step = arith.constant 1 : index + + %outer_result = affine.for %i = 0 to 10 iter_args(%outer_sum = %sum_init) -> i32 { + %i_i32 = arith.index_cast %i : index to i32 + + %inner_result = scf.for %j = %begin to %end step %step iter_args(%inner_sum = %outer_sum) -> i32 { + %j_i32 = arith.index_cast %j : index to i32 + %sum = arith.addi %inner_sum, %j_i32 : i32 + scf.yield %sum : i32 + } + affine.yield %inner_result : i32 + } + return %outer_result : i32 + } + +// CHECK: func.func @affine_for_and_scf_while() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32 +// CHECK-NEXT: %0 = affine.for %arg0 = 0 to 10 iter_args(%arg1 = %c0_i32) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %1 = arith.index_cast %arg0 : index to i32 +// CHECK-NEXT: %2:2 = scf.while (%arg2 = %arg1, %arg3 = %c0_i32_0) : (i32, i32) -> (i32, i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %3 = arith.cmpi slt, %arg2, %c5_i32 : i32 +// CHECK-NEXT: scf.condition(%3) %arg2, %arg3 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg2: i32, %arg3: i32): +// CHECK-NEXT: %3 = arith.addi %arg2, %arg3 : i32 +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %4 = arith.addi %arg3, %c1_i32 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %3, %4 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: affine.yield %2#0 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } + + func.func @affine_for_and_scf_while() -> i32 { + %sum_init = arith.constant 0 : i32 + %j_init = arith.constant 0 : i32 + %sum_limit = arith.constant 5 : i32 + + %outer_result = affine.for %i = 0 to 10 iter_args(%outer_sum = %sum_init) -> i32 { + %i_i32 = arith.index_cast %i : index to i32 + %inner_result:2 = scf.while (%inner_sum = %outer_sum, %j = %j_init) : (i32, i32) -> (i32, i32) { + %cmp = arith.cmpi slt, %inner_sum, %sum_limit : i32 + scf.condition(%cmp) %inner_sum, %j : i32, i32 + } do { + ^bb0(%sum_arg: i32, %j_arg: i32): + %sum = arith.addi %sum_arg, %j_arg : i32 + %step = arith.constant 1 : i32 + %new_j = arith.addi %j_arg, %step : i32 + scf.yield %sum, %new_j : i32, i32 + } + affine.yield %inner_result#0 : i32 + } + return %outer_result : i32 + } + +// CHECK: func.func @scf_for_and_scf_while() -> i32 { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c5 = arith.constant 5 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 +// CHECK-NEXT: %c5_i32 = arith.constant 5 : i32 +// CHECK-NEXT: %0 = scf.for %arg0 = %c0 to %c5 step %c1 iter_args(%arg1 = %c0_i32) -> (i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %c0_i32_1 = arith.constant 0 : i32 +// CHECK-NEXT: %1:2 = scf.while (%arg2 = %arg1, %arg3 = %c0_i32_0) : (i32, i32) -> (i32, i32) { +// CHECK-NEXT: func.call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: %2 = arith.cmpi slt, %arg2, %c5_i32 : i32 +// CHECK-NEXT: scf.condition(%2) %arg2, %arg3 : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg2: i32, %arg3: i32): +// CHECK-NEXT: %2 = arith.addi %arg2, %arg3 : i32 +// CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 +// CHECK-NEXT: %3 = arith.addi %arg3, %c1_i32 : i32 +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %2, %3 : i32, i32 +// CHECK-NEXT: } +// CHECK-NEXT: func.call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %1#0 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %0 : i32 +// CHECK-NEXT: } + + func.func @scf_for_and_scf_while() -> i32 { + %sum_init = arith.constant 0 : i32 + %begin = arith.constant 0 : index + %end = arith.constant 5 : index + %step = arith.constant 1 : index + %j_init = arith.constant 0 : i32 + %sum_limit = arith.constant 5 : i32 + + %outer_result = scf.for %i = %begin to %end step %step iter_args(%outer_sum = %sum_init) -> i32 { + %counter = arith.constant 0 : i32 + %inner_result:2 = scf.while (%inner_sum = %outer_sum, %j = %j_init) : (i32, i32) -> (i32, i32) { + %cmp = arith.cmpi slt, %inner_sum, %sum_limit : i32 + scf.condition(%cmp) %inner_sum, %j : i32, i32 + } do { + ^bb0(%sum_arg: i32, %j_arg: i32): + %sum = arith.addi %sum_arg, %j_arg : i32 + %step_j = arith.constant 1 : i32 + %new_j = arith.addi %j_arg, %step_j : i32 + scf.yield %sum, %new_j : i32, i32 + } + scf.yield %inner_result#0 : i32 + } + return %outer_result : i32 + } + +// CHECK: func.func private @trace_loop_iter_begin() +// CHECK-NEXT: func.func private @trace_loop_iter_end() + +} From 324563e4a86c3398730b91df625fe52806aaef8c Mon Sep 17 00:00:00 2001 From: MatveyK Date: Sun, 18 May 2025 12:25:33 +0300 Subject: [PATCH 2/2] fix clang --- .../kurakin_trace_loop/kurakin_trace_loop.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp b/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp index 8d50251db8932..898f673f70a1b 100644 --- a/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp +++ b/mlir/compiler-course/kurakin_trace_loop/kurakin_trace_loop.cpp @@ -13,7 +13,7 @@ struct TraceLoopPass : public mlir::PassWrapper> { void insertTraceLoop(mlir::Block &blockBefore, mlir::Block &blockAfter, - mlir::Location loc, mlir::OpBuilder &opBuilder) { + mlir::Location loc, mlir::OpBuilder &opBuilder) { opBuilder.setInsertionPointToStart(&blockBefore); opBuilder.create(loc, "trace_loop_iter_begin", mlir::TypeRange(), mlir::ValueRange()); @@ -58,13 +58,13 @@ struct TraceLoopPass module.walk([&](mlir::Operation *op) { if (auto affineForOp = mlir::dyn_cast(op)) { insertTraceLoop(*affineForOp.getBody(), *affineForOp.getBody(), - affineForOp->getLoc(), opBuilder); + affineForOp->getLoc(), opBuilder); } else if (auto forOp = mlir::dyn_cast(op)) { insertTraceLoop(*forOp.getBody(), *forOp.getBody(), forOp->getLoc(), - opBuilder); + opBuilder); } else if (auto whileOp = mlir::dyn_cast(op)) { insertTraceLoop(whileOp.getBefore().front(), whileOp.getAfter().front(), - whileOp->getLoc(), opBuilder); + whileOp->getLoc(), opBuilder); } }); }