diff --git a/mlir/compiler-course/chizhov_m_loop_pass/CMakeLists.txt b/mlir/compiler-course/chizhov_m_loop_pass/CMakeLists.txt new file mode 100644 index 0000000000000..86b5f7d3fa118 --- /dev/null +++ b/mlir/compiler-course/chizhov_m_loop_pass/CMakeLists.txt @@ -0,0 +1,16 @@ +set(Title "LoopPassBeginEnd") +set(Student "Chizhov_Maxim") +set(Group "FIIT3") +set(TARGET_NAME "${Title}_${Student}_${Group}_MLIR") + +file(GLOB_RECURSE SOURCES *.cpp *.h *.hpp) + +add_llvm_pass_plugin(${TARGET_NAME} + ${SOURCES} + DEPENDS + intrinsics_gen + MLIRBuiltinLocationAttributesIncGen + BUILDTREE_ONLY +) + +set(MLIR_TEST_DEPENDS ${TARGET_NAME} ${MLIR_TEST_DEPENDS} PARENT_SCOPE) diff --git a/mlir/compiler-course/chizhov_m_loop_pass/chizhov_m_loop_pass.cpp b/mlir/compiler-course/chizhov_m_loop_pass/chizhov_m_loop_pass.cpp new file mode 100644 index 0000000000000..41655a8e47a32 --- /dev/null +++ b/mlir/compiler-course/chizhov_m_loop_pass/chizhov_m_loop_pass.cpp @@ -0,0 +1,76 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Tools/Plugins/PassPlugin.h" + +using namespace mlir; + +namespace { +class TraceLoopIterPass + : public PassWrapper> { +public: + StringRef getArgument() const final { + return "LoopPassBeginEnd_Chizhov_Maxim_FIIT3_MLIR"; + } + + StringRef getDescription() const final { + return "Inserts `@trace_loop_iter_begin` and `@trace_loop_iter_end` calls " + "into each loop iteration"; + } + + void insertTraceCalls(Block &body, OpBuilder &builder, Location loc) { + builder.setInsertionPointToStart(&body); + builder.create(loc, + builder.getStringAttr("trace_loop_iter_begin"), + TypeRange{}, ValueRange{}); + + builder.setInsertionPoint(body.getTerminator()); + builder.create(loc, + builder.getStringAttr("trace_loop_iter_end"), + TypeRange{}, ValueRange{}); + } + + void processAffineFor(affine::AffineForOp op, OpBuilder &builder) { + insertTraceCalls(*op.getBody(), builder, op.getLoc()); + } + + void processSCFFor(scf::ForOp op, OpBuilder &builder) { + insertTraceCalls(*op.getBody(), builder, op.getLoc()); + } + + void processSCFWhile(scf::WhileOp op, OpBuilder &builder) { + Block &afterBlock = op.getAfter().front(); + insertTraceCalls(afterBlock, builder, op.getLoc()); + } + + void runOnOperation() override { + ModuleOp moduleOp = getOperation(); + OpBuilder builder(moduleOp.getContext()); + + moduleOp.walk([&](Operation *op) { + if (auto affineFor = dyn_cast(op)) { + processAffineFor(affineFor, builder); + } else if (auto scfFor = dyn_cast(op)) { + processSCFFor(scfFor, builder); + } else if (auto scfWhile = dyn_cast(op)) { + processSCFWhile(scfWhile, builder); + } + }); + } +}; +} // namespace + +MLIR_DECLARE_EXPLICIT_TYPE_ID(TraceLoopIterPass) +MLIR_DEFINE_EXPLICIT_TYPE_ID(TraceLoopIterPass) + +mlir::PassPluginLibraryInfo getFunctionCallCounterPassPluginInfo() { + return {MLIR_PLUGIN_API_VERSION, "LoopPassBeginEnd_Chizhov_Maxim_FIIT3_MLIR", + "1.0", []() { mlir::PassRegistration(); }}; +} + +extern "C" LLVM_ATTRIBUTE_WEAK mlir::PassPluginLibraryInfo +mlirGetPassPluginInfo() { + return getFunctionCallCounterPassPluginInfo(); +} diff --git a/mlir/test/compiler-course/chizhov_m_loop_pass/test.mlir b/mlir/test/compiler-course/chizhov_m_loop_pass/test.mlir new file mode 100644 index 0000000000000..9bb70f7e9839f --- /dev/null +++ b/mlir/test/compiler-course/chizhov_m_loop_pass/test.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt -load-pass-plugin=%mlir_lib_dir/LoopPassBeginEnd_Chizhov_Maxim_FIIT3_MLIR%shlibext \ +// RUN: --pass-pipeline="builtin.module(LoopPassBeginEnd_Chizhov_Maxim_FIIT3_MLIR)" %s | FileCheck %s + +module { + func.func private @trace_loop_iter_begin() -> () + func.func private @trace_loop_iter_end() -> () + + // CHECK-LABEL: func.func @affine_loop + func.func @affine_loop() { + %c0 = arith.constant 0 : index + %c10 = arith.constant 10 : index + + affine.for %i = %c0 to %c10 { + // CHECK: call @trace_loop_iter_begin + "test.op"() : () -> () + // CHECK: call @trace_loop_iter_end + } + return + } +} + +module { + func.func private @trace_loop_iter_begin() -> () + func.func private @trace_loop_iter_end() -> () + + // CHECK-LABEL: func.func @scf_loop + func.func @scf_loop() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + scf.for %i = %c0 to %c10 step %c1 { + // CHECK: call @trace_loop_iter_begin + "test.op"() : () -> () + // CHECK: call @trace_loop_iter_end + } + + return + } +} +// CHECK-LABEL: func.func @scf_while +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %0 = scf.while (%arg0 = %c0) : (index) -> index { +// CHECK-NEXT: %1 = arith.cmpi slt, %arg0, %c10 : index +// CHECK-NEXT: scf.condition(%1) %arg0 : index +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%arg0: index): +// CHECK-NEXT: call @trace_loop_iter_begin() : () -> () +// CHECK-NEXT: "test.op"() : () -> () +// CHECK-NEXT: %1 = arith.addi %arg0, %c1 : index +// CHECK-NEXT: call @trace_loop_iter_end() : () -> () +// CHECK-NEXT: scf.yield %1 : index + +module { + func.func private @trace_loop_iter_begin() -> () + func.func private @trace_loop_iter_end() -> () + + func.func @scf_while() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + %init = scf.while (%i = %c0) : (index) -> (index) { + %cond = arith.cmpi "slt", %i, %c10 : index + scf.condition(%cond) %i : index + } do { + ^bb0(%i_in: index): + "test.op"() : () -> () + %inc = arith.addi %i_in, %c1 : index + scf.yield %inc : index + } + + return + } +} \ No newline at end of file