diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ae68d38..bc827b37 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ add_subdirectory(test) add_subdirectory(tools/triton-shared-opt) if (TRITON_SHARED_BUILD_CPU_BACKEND) - add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonToLinalg TritonTilingExtIR) + add_triton_plugin(TritonShared ${CMAKE_CURRENT_SOURCE_DIR}/triton_shared.cc LINK_LIBS TritonSharedAnalysis TritonTilingExtIR) target_link_libraries(TritonShared PRIVATE Python3::Module pybind11::headers) endif() diff --git a/README.md b/README.md index 8d405534..9802ee7d 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ As part of the conversion process, there are three important analyses: ### Conversion strategy -We introduce the `TritonToLinalg` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like: +We introduce the `TritonToLinalgExperimental` pass that converts the `triton` dialect to the `linalg` dialect on *tensors*. This means the resulting IR is fully compatible with `linalg` tiling and fusion transformation passes. As mentioned in the `Pointer analysis`'s description, we do however have to deal with memref instructions at the load and store boundaries and have to convert them to tensors using `bufferization.to_tensor`. Here's a simple example of what the IR looks like: ```mlir tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) { diff --git a/include/triton-shared/Conversion/CMakeLists.txt b/include/triton-shared/Conversion/CMakeLists.txt index a4a03949..f8e180a7 100644 --- a/include/triton-shared/Conversion/CMakeLists.txt +++ b/include/triton-shared/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonArithToLinalg) diff --git a/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt b/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index 74ccdd39..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -#===------------------------------------------------------------------------===# -# -# Copyright (c) Triton Project Contributors. -# -#===------------------------------------------------------------------------===# - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToLinalg) -add_public_tablegen_target(TritonToLinalgConversionPassIncGen) diff --git a/include/triton-shared/Conversion/TritonToLinalg/Passes.h b/include/triton-shared/Conversion/TritonToLinalg/Passes.h deleted file mode 100644 index 404af080..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/Passes.h +++ /dev/null @@ -1,22 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_TO_LINALG_CONVERSION_PASSES_H -#define TRITON_TO_LINALG_CONVERSION_PASSES_H - -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" - -namespace mlir { -namespace triton { - -#define GEN_PASS_REGISTRATION -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -} // namespace triton -} // namespace mlir - -#endif diff --git a/include/triton-shared/Conversion/TritonToLinalg/Passes.td b/include/triton-shared/Conversion/TritonToLinalg/Passes.td deleted file mode 100644 index 627077e3..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/Passes.td +++ /dev/null @@ -1,18 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_TO_LINALG_CONVERSION_PASSES -#define TRITON_TO_LINALG_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def TritonToLinalg : Pass<"triton-to-linalg", "mlir::ModuleOp"> { - let summary = "Convert Triton to Linalg dialect"; - let constructor = "triton::createTritonToLinalgPass()"; -} - -#endif diff --git a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h b/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h deleted file mode 100644 index 4c58e992..00000000 --- a/include/triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H -#define TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace mlir { -namespace triton { - -std::unique_ptr> createTritonToLinalgPass(); - -void populateTritonToLinalgCanonicalizationPatterns( - RewritePatternSet &patterns); - -void populateTritonToLinalgConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - unsigned int launchGridRank); - -} // namespace triton -} // namespace mlir - -#endif // TRITON_CONVERSION_TRITONTOLINALG_TRITONTOLINALG_H diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 358b4f92..2a591e97 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(TritonToLinalg) add_subdirectory(TritonToLinalgExperimental) add_subdirectory(TritonToStructured) add_subdirectory(TritonToUnstructured) diff --git a/lib/Conversion/TritonToLinalg/CMakeLists.txt b/lib/Conversion/TritonToLinalg/CMakeLists.txt deleted file mode 100644 index acc3c4fb..00000000 --- a/lib/Conversion/TritonToLinalg/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -#===------------------------------------------------------------------------===# -# -# Copyright (c) Triton Project Contributors. -# -#===------------------------------------------------------------------------===# - -add_triton_library(TritonToLinalg - TritonToLinalg.cpp - TritonToLinalgPass.cpp - - DEPENDS - TritonToLinalgConversionPassIncGen - - LINK_LIBS PUBLIC - TritonTilingExtIR - MLIRArithDialect - MLIRDialectUtils - MLIRIR - MLIRMathDialect - MLIRPass - MLIRTensorDialect - MLIRTransforms - MLIRSupport - TritonIR - TritonTransforms - TritonSharedAnalysis -) diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp deleted file mode 100644 index 1c8ed9cf..00000000 --- a/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp +++ /dev/null @@ -1,95 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MathExtras.h" - -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" - -#define DEBUG_TYPE "triton-to-linalg" -#include "triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp" - -using namespace mlir; -using namespace triton; - -#define GEN_PASS_CLASSES -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -void mlir::triton::populateTritonToLinalgCanonicalizationPatterns( - RewritePatternSet &patterns) { - patterns.add, MinMaxConverter>( - patterns.getContext()); -} - -void mlir::triton::populateTritonToLinalgConversionPatterns( - TypeConverter &typeConverter, RewritePatternSet &patterns, - unsigned int launchGridRank) { - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - populateExternElementwiseOpToMLIROps(patterns); - - // Reduce converters - // Triton's reduce op is idential to linalg.reduce op, so we can clone - // `tt.reduce` body to `linalg.reduce`. Unfortunately, we still need to - // perform pattern matching to know what reduce ops we are dealing with - // so that we know how to initialize the initial reduce values correctly. - // - // We can do this in a generic way without pattern matching by always using - // the first elements along the reduction axis and perform the reduction on - // the remaining elements. However, this results in creatings sub-tensors that - // aren't always multiple of 2s, which are sub-optimal for certain hardwares. - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); - - // Note: the ordering here matters! - // MetaOpConverter has PatternBenefit == 10 which should take precedence over - // these linalg patterns, but to be safe, add these patterns last so that they - // will be tried last. Incorrect ordering or having MetaOpConverter has lower - // PatternBenefit will result in element-wise meta ops being converted to - // linalg.generic ops. - linalg::populateElementwiseToLinalgConversionPatterns(patterns); -} diff --git a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp b/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp deleted file mode 100644 index 25b7db85..00000000 --- a/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp +++ /dev/null @@ -1,229 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. -// -//===----------------------------------------------------------------------===// - -#include "triton-shared/Analysis/UseAnalysis.h" -#include "triton-shared/Conversion/TritonToLinalg/TritonToLinalg.h" -#include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h" - -#include "triton/Dialect/Triton/IR/Dialect.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "triton-to-linalg" - -using namespace mlir; -using namespace triton; - -#define GEN_PASS_CLASSES -#include "triton-shared/Conversion/TritonToLinalg/Passes.h.inc" - -namespace { - -class TritonTypeConverter : public TypeConverter { -public: - TritonTypeConverter() { - // The order of type conversion is important: later ones are tried earlier. - addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrType) { - return UnrankedMemRefType::get(ptrType.getPointeeType(), 0); - }); - addConversion([](TensorType tensorType) -> Type { - auto elemType = tensorType.getElementType(); - if (auto ptrType = dyn_cast(elemType)) { - elemType = ptrType.getPointeeType(); - } - return MemRefType::get(tensorType.getShape(), elemType); - }); - } -}; - -class TritonToLinalgPass : public TritonToLinalgBase { - - static auto constexpr LAUNCH_GRID_RANK = getMaxEnumValForProgramIDDim() + 1; - static unsigned int constexpr TRITON_PROGRAM_INFO_ARG_COUNT = - LAUNCH_GRID_RANK * 2; - - // Add additional I32 arguments to represent: - // - num_programs, 3 in total, one for each axis of the launch grid - // - program_id, 3 in total, one for each axis of the launch grid - static void addProgramInfo(triton::FuncOp func) { - OpBuilder b(func); - - auto origFuncType = func.getFunctionType(); - auto origInputTypes = origFuncType.getInputs(); - SmallVector newInputTypes(origInputTypes); - newInputTypes.append(TRITON_PROGRAM_INFO_ARG_COUNT, b.getI32Type()); - - auto newFuncType = - b.getFunctionType(newInputTypes, origFuncType.getResults()); - - func.setFunctionType(newFuncType); - - // Add empty attributes for each new argument if needed - if (func.getAllArgAttrs()) { - SmallVector newArgAttrs; - func.getAllArgAttrs(newArgAttrs); - newArgAttrs.append(TRITON_PROGRAM_INFO_ARG_COUNT, DictionaryAttr()); - func.setAllArgAttrs(newArgAttrs); - } - - // Add the corresponding arguments to function body - for (unsigned int i = 0; i < TRITON_PROGRAM_INFO_ARG_COUNT; i++) { - func.getBody().front().addArgument(b.getI32Type(), func.getLoc()); - } - } - -public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); - } - - void runOnOperation() override { - auto moduleOp = getOperation(); - - { - RewritePatternSet patterns(&getContext()); - populateTritonToLinalgCanonicalizationPatterns(patterns); - if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) { - signalPassFailure(); - } - } - - moduleOp.walk([this](triton::FuncOp op) { - if (failed(runUseAnalysis(op))) { - signalPassFailure(); - } - }); - - RewritePatternSet patterns(&getContext()); - ConversionTarget target(getContext()); - TritonTypeConverter tritonTypeConverter; - - target.addLegalDialect< - func::FuncDialect, arith::ArithDialect, math::MathDialect, - linalg::LinalgDialect, affine::AffineDialect, scf::SCFDialect, - cf::ControlFlowDialect, tensor::TensorDialect, - bufferization::BufferizationDialect, memref::MemRefDialect, - ttx::TritonTilingExtDialect>(); - - target.addLegalOp(); - - // Update function signature to use memrefs - target.addDynamicallyLegalOp([&](triton::FuncOp op) { - return tritonTypeConverter.isSignatureLegal(op.getFunctionType()); - }); - - // Lower dense constant to linalg.fill - target.addDynamicallyLegalOp([](arith::ConstantOp op) { - if (!isa(op.getResult().getType())) { - return true; - } - - if (auto denseAttr = dyn_cast(op.getValue())) { - if (denseAttr.isSplat() && - isa(denseAttr.getElementType())) { - return false; - } - } - return true; - }); - - target.addDynamicallyLegalOp([](Operation *op) { - return llvm::all_of(op->getOperandTypes(), [](Type t) { - if (isa(t)) { - return false; - } - if (auto shapedType = dyn_cast(t)) { - return shapedType.getElementType().isIntOrFloat(); - } - assert(t.isIntOrIndexOrFloat()); - return true; - }); - }); - - target.addDynamicallyLegalDialect( - [](Operation *op) { - if (op->hasAttr("MetaUse")) { - return false; - } - - if (isa(op)) { - return true; - } - - bool operateOnTensors = - llvm::all_of(op->getOperandTypes(), [](Type type) { - return isa(type); - }); - - return !operateOnTensors; - }); - - triton::populateTritonToLinalgConversionPatterns( - tritonTypeConverter, patterns, LAUNCH_GRID_RANK); - - for (auto func : getOperation().getOps()) - addProgramInfo(func); - - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) - signalPassFailure(); - - // Convert tt.func and tt.return into func's counterparts - moduleOp.walk([&](triton::FuncOp func) { - OpBuilder builder(func); - - auto name = func.getName(); - auto type = func.getFunctionType(); - - SmallVector argAttrs, resAttrs; - func.getAllArgAttrs(argAttrs); - func.getAllResultAttrs(resAttrs); - - auto funcFunc = builder.create(func.getLoc(), name, type); - funcFunc.setAllArgAttrs(argAttrs); - funcFunc.setAllResultAttrs(resAttrs); - - auto &funcFuncBody = funcFunc.getBody(); - auto &funcBody = func.getBody(); - - IRMapping map; - funcBody.cloneInto(&funcFuncBody, map); - - for (Block &block : funcFuncBody.getBlocks()) { - auto term = block.getTerminator(); - builder.setInsertionPoint(term); - builder.create(func.getLoc(), term->getOperands()); - term->erase(); - } - func.erase(); - }); - - // Erase dead code and fold constants created during lowering - PassManager pm(&getContext(), moduleOp.getOperationName()); - pm.addPass(createCanonicalizerPass()); - if (failed(runPipeline(pm, getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> triton::createTritonToLinalgPass() { - return std::make_unique(); -} diff --git a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir b/test/Conversion/TritonToLinalg/addptr_2d_example.mlir deleted file mode 100644 index f0f7d1c7..00000000 --- a/test/Conversion/TritonToLinalg/addptr_2d_example.mlir +++ /dev/null @@ -1,69 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr, - %arg3 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> - // offset = [%arg3,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}: tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %6 = arith.constant 5 : i32 - %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> - // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %10 = tt.load %9 : tensor<4x256x!tt.ptr> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %13 = tt.load %12 : tensor<4x256x!tt.ptr> - %14 = arith.addf %10, %13 : tensor<4x256xbf16> - %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> - %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - tt.store %16, %14 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = arith.constant 5 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_15]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: bf16, %[[VAL_18:.*]]: bf16, %[[VAL_19:.*]]: bf16): -// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : bf16 -// CHECK: linalg.yield %[[VAL_20]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_21]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_add_value.mlir b/test/Conversion/TritonToLinalg/addptr_add_value.mlir deleted file mode 100644 index 0ed60796..00000000 --- a/test/Conversion/TritonToLinalg/addptr_add_value.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32, - %arg3 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> - %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> - // offset = [%arg2,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %offset2, %arg3splat : tensor<4x256xi32> - // offset = [%arg2+%arg3,0], size = [4,256], stride = [1,0] - %c10 = arith.constant 10 : i32 - %c10splat = tt.splat %c10 : i32 -> tensor<4x256xi32> - %offset4 = arith.addi %offset3, %c10splat : tensor<4x256xi32> - // offset = [%arg2+%arg3+10,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,6] - %7 = arith.addi %offset4, %scale5: tensor<4x256xi32> - // offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>,tensor<4x256xi32> - // source = %arg0, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source = %arg1, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6] - %12 = tt.load %9 : tensor<4x256x!tt.ptr> - tt.store %11, %12 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_9]], %[[VAL_10]] : index -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : index -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_8]] : index -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_19]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_20:.*]] = bufferization.to_tensor %[[VAL_19]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_20]] in writable %[[VAL_18]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_dim1.mlir b/test/Conversion/TritonToLinalg/addptr_dim1.mlir deleted file mode 100644 index 0e314fa4..00000000 --- a/test/Conversion/TritonToLinalg/addptr_dim1.mlir +++ /dev/null @@ -1,113 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// XFAIL: * -// This test crashes because tt.broadcast's folder tries to cast -// the src operand to a RankedTensorType value, but the TritonToLinalg -// pass has already replaced the src with a value of a different type. -// We're going to retire the monolith triton-to-linalg pass which prevents -// this problem. xfailing the test for now. - -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : i32 - ) - { - %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - - %splat_arg0 = tt.splat %arg0 : !tt.ptr -> tensor<1x256x!tt.ptr> - %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - - // 1x256 pointer should have meaningful stride in outer dimension - %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256x!tt.ptr> - - %4 = tt.splat %arg1 : i32 -> tensor<1x256xi32> - // 1x256 pointer should have meaningful stride in outer dimension - %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - tt.store %5, %3 : tensor<1x256x!tt.ptr>, tensor<1x256x!tt.ptr> - - %10 = arith.constant 0.0 : bf16 - %11 = tt.splat %10 : bf16 -> tensor<4x256xbf16> - - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %c256 = arith.constant 256 : i32 - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { - %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> - - %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - %i_i32 = arith.index_cast %i : index to i32 - %21 = arith.muli %c256, %i_i32 : i32 - %22 = tt.splat %21 : i32 -> tensor<4xi32> - %23 = arith.muli %20, %22 : tensor<4xi32> - %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %25 = tt.broadcast %24 : tensor<4x1xi32> -> tensor<4x256xi32> - - // %bptr should have zero stride and %30 should have correct stride - %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16> - - %40 = tt.splat %c256 : i32 -> tensor<1x256xi32> - %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr>, tensor<1x256xi32> - - scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr> - } - - %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - %splat_c256 = tt.splat %c256 : i32 -> tensor<4xi32> - %32 = arith.muli %31, %splat_c256 : tensor<4xi32> - %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %34 = tt.broadcast %33 : tensor<4x1xi32> -> tensor<4x256xi32> - %35 = tt.broadcast %2 : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> - %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - tt.store %36, %sum_out : tensor<4x256x!tt.ptr>, tensor<4x256x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_256_1_:%.+]] = arith.constant 256 : i32 -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : bf16) outs([[VAR_0_]] : tensor<4x256xbf16>) -> tensor<4x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<1x256xbf16, strided<[256, 1]>> to memref<1x256xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<1x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_1_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1, 256], strides: [256, 1] : memref<*xbf16> to memref<1x256xbf16, strided<[256, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK-DAG: [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[CST_0_]], [[VAR_arg8_:%.+]] = [[CST_0_]]) -> (tensor<4x256xbf16>, index, index) { -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg5_]] : index to i32 -// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[CST_256_1_]] : i32 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : i32 to index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_arg7_]], [[VAR_arg8_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [4, 256], strides: {{.}}[[VAR_7_]], [[CST_1_]]{{.}} : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_1_]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: [[VAR_9_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<4x256xbf16> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_arg6_]], [[VAR_9_]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs([[VAR_arg6_]] : tensor<4x256xbf16>) { -// CHECK: ^bb0([[in1:%.+]]: bf16, [[in2:%.+]]: bf16, [[out:%.+]]: bf16): -// CHECK: [[VAR_13_:%.+]] = arith.addf [[in1]], [[in2]] : bf16 -// CHECK: linalg.yield [[VAR_13_]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_256_]] : index -// CHECK: [[VAR_12_:%.+]] = arith.addi [[VAR_11_]], [[VAR_arg8_]] : index -// CHECK: scf.yield [[VAR_10_]], [[VAR_12_]], [[CST_0_]] : tensor<4x256xbf16>, index, index -// CHECK: } -// CHECK: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<4x256xbf16, strided<[?, 1]>> -// CHECK: bufferization.materialize_in_destination [[VAR_4_]]#0 in writable [[VAR_reinterpret_cast_1_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir b/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir deleted file mode 100644 index 89cb4590..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr, - %arg3 : i32, - %arg4 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> - // offset = [%arg3,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c5 = arith.constant 5 : i32 - %splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32> - // scalar = 5 - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? - // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> // Why is the input unknown - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %19 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> // this will be replaced with a memref.copy - %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { - %20 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> - // pointer updates - %17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32> - // offset: [3, 0], size = [4, 256], stride [0, 0] - %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] - scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> - } - %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> - %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - tt.store %16, %sum_out : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xbf16>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_1O:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: %[[VAL_19:.*]]:3 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) { -// CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_21]], %[[VAL_26]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_21]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_28:.*]]: bf16, %[[VAL_29:.*]]: bf16, %[[VAL_30:.*]]: bf16): -// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16 -// CHECK: linalg.yield %[[VAL_31]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index -// CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>> -// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index -// CHECK: } -// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in writable %[[VAL_37]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir b/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir deleted file mode 100644 index 67d82948..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_expand_ptr.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - - // gep operand is another gep' output, which is passed into the loop as varible, used after update - %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - - %8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 0], strides: [1, 0] - - %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32> - %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - - %11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 256], strides: [0, 1] - - %12 = arith.addi %8, %11 : tensor<256x256xi32> - // sizes: [256, 256], offsets: [0, 256], strides: [1, 1] - - %13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr> -> tensor<256x1x!tt.ptr> - %14 = tt.broadcast %13 : tensor<256x1x!tt.ptr> -> tensor<256x256x!tt.ptr> - - %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr>, tensor<256x256xi32> - // source: arg0, sizes: [256, 256], offsets: [1024 + i, 256], strides: [2, 1] - - // perform load - %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr> - tt.store %15, %16 : tensor<256x256x!tt.ptr> - // pointer updates - %17 = tt.splat %i_c3 : i32 -> tensor<256xi32> - // sizes: 256, offsets: 3, strides: 0 - %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i, strides: 4 - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (index) { -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index -// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [256, 256], strides: {{\[}}%[[VAL_4]], 1] : memref<*xbf16> to memref<256x256xbf16, strided<[?, 1], offset: ?>> -// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<256x256xbf16> -// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<256x256xbf16, strided<[?, 1], offset: ?>> to memref<256x256xbf16> -// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<256x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_16]] in writable %[[VAL_14]] -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_12]], %[[VAL_9]] : index -// CHECK: scf.yield %[[VAL_17]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir b/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir deleted file mode 100644 index 4d77760e..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_more_init_args.mlir +++ /dev/null @@ -1,71 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c12 = arith.constant 12 : index - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - %3 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> - %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg1, sizes: 256, offsets: 1024, strides: 1 - %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { - // perform load - %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_st, %5 : tensor<256x!tt.ptr> - // pointer updates - %cast3 = arith.index_cast %c3 : index to i32 - %6 = tt.splat %cast3 : i32 -> tensor<256xi32> - %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i*3, strides: 1 - %arg2_iter = arith.addi %arg2, %c3 : index - %arg3_iter = arith.addi %arg3, %c3 : index - %arg4_iter = arith.addi %arg4, %c3 : index - %7 = arith.addi %arg2_iter, %arg3_iter : index - %8 = arith.addi %7, %arg4_iter : index - %cast8 = arith.index_cast %8 : index to i32 - %9 = tt.splat %cast8 : i32 -> tensor<256xi32> - %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg1, sizes: 256, offsets: 1024 + loop-carry variable*i, strides: 1 - scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index - } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 12 : index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_6]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]]:7 = scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_9]], %[[VAL_19:.*]] = %[[VAL_13]], %[[VAL_20:.*]] = %[[VAL_10]], %[[VAL_21:.*]] = %[[VAL_6]], %[[VAL_22:.*]] = %[[VAL_6]]) -> (index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index) { -// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_17]], %[[VAL_23]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_24]] in writable %[[VAL_19]] -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_21]], %[[VAL_10]] : index -// CHECK: %[[VAL_26:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_25]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_10]] : index -// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_10]] : index -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_20]], %[[VAL_10]] : index -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index -// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_29]] : index -// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_22]], %[[VAL_31]] : index -// CHECK: %[[VAL_33:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_32]]], sizes: [256], strides: {{\[}}%[[VAL_8]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_27]], %[[VAL_26]], %[[VAL_28]], %[[VAL_33]], %[[VAL_29]], %[[VAL_25]], %[[VAL_32]] : index, memref<256xbf16, strided<[?], offset: ?>>, index, memref<256xbf16, strided<[?], offset: ?>>, index, index, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir deleted file mode 100644 index 60b0b7fc..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_used_after_update.mlir +++ /dev/null @@ -1,98 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - // gep operand is another gep' output, which is passed into the loop as varible, used after update - %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - // pointer updates - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - // sizes: 256, offsets: 3, strides: 0 - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024 + i, strides: 1 - // perform load - %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_iter, %3 : tensor<256x!tt.ptr> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // %subview = memref.subview %arg0, [%4][256][4] : memref<> -> memref<> <- generate subview on getelementptr (already done) - // ... - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - // TODO: examples below are not supported since scf.for does not support returning a tensor type - // Example 3, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used after update - //%_ptr3 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { - // // offset update - // %3 = tt.splat %c3 : i32 -> tensor<256xi32> - // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> - // // generate pointer - // %gep_ptr = tt.addptr %0, %ptr_iter : tensor<256x!tt.ptr> - // // perform load - // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> - // scf.yield %ptr_iter : tensor<256xi32> - //} - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - //// Example 4, gep operand is a vector of i32, which is passed into the loop as variable, pointer updated using step, used before update - //%_ptr4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %1) -> (tensor<256xi32>) { - // // generate pointer - // %gep_ptr = tt.addptr %0, %ptr : tensor<256x!tt.ptr> - // - // // perform load - // %4 = tt.load %gep_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - // tt.store %gep_ptr, %4 : tensor<256x!tt.ptr> - // // offset update - // %3 = tt.splat %c3 : i32 -> tensor<256xi32> - // %ptr_iter = arith.addi %3, %ptr : tensor<256xi32> - // scf.yield %ptr_iter : tensor<256xi32> - //} - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_11:.*]] = %[[VAL_5]]) -> (index) { -// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_12]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_13]] -// CHECK: scf.yield %[[VAL_12]] : index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir b/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir deleted file mode 100644 index 7855730a..00000000 --- a/test/Conversion/TritonToLinalg/addptr_for_used_before_update.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - // source: null, sizes: 256, offsets: 1024, strides: 1 - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - // source: arg0, sizes: 256, offsets: 1024, strides: 1 - // Example 2, gep operand is another gep's output, which is passed into the loop as varible, used before update - %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - // perform load - %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr, %3 : tensor<256x!tt.ptr> - // pointer updates - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - // Expected output - // %offset_dim0 = arith.constant 1024 <- insert instructions to initialize init arg(new) - // for iter_args (%offset_dim0_iter = %offset_dim0) { <- replace varibles passed in as init arg (new) - // %subview = memref.subview %arg0, [%offset_dim0_iter][256][4] : memref<> -> memref<> <- generate subview on load (new) - // ... - // %4 = %offset_dim0_iter + %c3 <- replace gep of splat with add (already done) - // scf.yield %4 <- replace yielding an gep output with the corresponding dim variable (new) - // } - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) { -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1024 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_5]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_10:.*]]:2 = scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[VAL_12:.*]] = %[[VAL_9]], %[[VAL_13:.*]] = %[[VAL_5]]) -> (memref<256xbf16, strided<[?], offset: ?>>, index) { -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<256xbf16> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_14]] : memref<256xbf16, strided<[?], offset: ?>> to memref<256xbf16> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_12]] -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_8]] : index -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_16]]], sizes: [256], strides: {{\[}}%[[VAL_4]]] : memref<*xbf16> to memref<256xbf16, strided<[?], offset: ?>> -// CHECK: scf.yield %[[VAL_17]], %[[VAL_16]] : memref<256xbf16, strided<[?], offset: ?>>, index -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_loopback.mlir deleted file mode 100644 index ee5cb2cc..00000000 --- a/test/Conversion/TritonToLinalg/addptr_loopback.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32> - %offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32> - // offset = [%arg2,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,6] - %7 = arith.addi %offset2, %scale5: tensor<4x256xi32> - // offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: arg0, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %10 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: arg1, offset = [%arg2, 0], size = [4, 256], stride = [1, 6] - %12 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - tt.store %11, %12 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_7]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<4x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_10]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir deleted file mode 100644 index 61ddea4f..00000000 --- a/test/Conversion/TritonToLinalg/addptr_mul_const_const.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> - %2 = tt.splat %0 : i32 -> tensor<1024xi32> - %3 = arith.addi %2, %1 : tensor<1024xi32> - //%3: splat(%0) + range(0, 1024) - //%3: offset = %0, size = 1024, stride = 1 - // vector and scalar are both constant - %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> - %c10 = arith.constant 10 : i32 - %5 = tt.splat %c10 : i32 -> tensor<1024xi32> - %6 = arith.muli %5, %4 : tensor<1024xi32> - //%6: splat(%c10)*range(2048, 4096); - //%6: offset = %c10*2048, size = 1024, stride = %c10*1 - %7 = arith.addi %3, %6 : tensor<1024xi32> - //%7: offset = %c10*2048 + %0, size = 1024, stride = %c10*1+1 - %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg0 offset = %c10*2048 + pid0, size = 1024, stride = %c10*1+1 - %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> - %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg1, offset = pid0, size = 1024, stride = 1 - %16 = tt.load %9 : tensor<1024x!tt.ptr> - tt.store %11, %16 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 11 : index -// CHECK: %[[VAL_7:.*]] = arith.constant 20480 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_7]] : index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: {{\[}}%[[VAL_6]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_3]] : i32 to index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<1024xbf16> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_13]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> -// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<1024xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_14]] in writable %[[VAL_12]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir b/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir deleted file mode 100644 index 77907e06..00000000 --- a/test/Conversion/TritonToLinalg/addptr_mul_value_const.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) - { - %0 = tt.get_program_id x : i32 - %1 = tt.make_range {end = 1024 : i32, start = 0 : i32}:tensor<1024xi32> - %2 = tt.splat %0 : i32 -> tensor<1024xi32> - %3 = arith.addi %2, %1 : tensor<1024xi32> - //%3: splat(%0) + range(0, 1024) - //%3: offset = %0, size = 1024, stride = 1 - // vector is constant, scalar is value - %4 = tt.make_range {end = 3072 : i32, start = 2048 : i32}:tensor<1024xi32> - %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> - %6 = arith.muli %5, %4 : tensor<1024xi32> - //%6: splat(%arg2)*range(2048, 3072); - //%6: offset = %arg2*2048, size = 1024, stride = %arg2*1 - %7 = arith.addi %3, %6 : tensor<1024xi32> - //%7: offset = %arg2*2048 + %0, size = 1024, stride = %arg2*1+1 - %8 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg0: offset = %arg2*2048 + pid0, size = 1024, stride = %arg2*1+1 - %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> - %11 = tt.addptr %10, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - //source=%arg1: offset = pid0, size = 1024, stride = 1 - %16 = tt.load %9 : tensor<1024x!tt.ptr> - tt.store %11, %16 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 2048 : index -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[ARG_6]] : i32 to index -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_2]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_9]], %[[VAL_11]] : index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_10]], %[[VAL_6]] : index -// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_13]]], sizes: [1024], strides: {{\[}}%[[VAL_14]]] : memref<*xbf16> to memref<1024xbf16, strided<[?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = arith.index_cast %[[ARG_6]] : i32 to index -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_16]]], sizes: [1024], strides: [1] : memref<*xbf16> to memref<1024xbf16, strided<[1], offset: ?>> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xbf16> -// CHECK: memref.copy %[[VAL_15]], %[[VAL_18]] : memref<1024xbf16, strided<[?], offset: ?>> to memref<1024xbf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_19]] in writable %[[VAL_17]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_nested.mlir b/test/Conversion/TritonToLinalg/addptr_nested.mlir deleted file mode 100644 index bbbc0b22..00000000 --- a/test/Conversion/TritonToLinalg/addptr_nested.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg1splat = tt.splat %arg1 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg1splat : tensor<4x256xi32> - // offset = [%arg1,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %6 = arith.constant 5 : i32 - %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32> - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> - // offset = [%arg1, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1, 0], size = [4, 256], stride = [1, 5] - %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %12 = tt.addptr %9, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1+%arg1, 0], size = [4, 256], stride = [2, 10] - %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr> - %14 = arith.addf %10, %13 : tensor<4x256xbf16> - %16 = tt.addptr %12, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg1+%arg1+%arg1, 0], size = [4, 256], stride = [3, 15] - tt.store %16, %14 : tensor<4x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 10 : index -// CHECK: %[[VAL_8:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_8]]], sizes: [4, 256], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_10]] : memref<4x256xbf16, strided<[1, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : index -// CHECK: %[[VAL_15:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_14]]], sizes: [4, 256], strides: [2, %[[VAL_7]]] : memref<*xbf16> to memref<4x256xbf16, strided<[2, ?], offset: ?>> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<4x256xbf16> -// CHECK: memref.copy %[[VAL_15]], %[[VAL_16]] : memref<4x256xbf16, strided<[2, ?], offset: ?>> to memref<4x256xbf16> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<4x256xbf16> -// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_17]] : tensor<4x256xbf16>, tensor<4x256xbf16>) outs(%[[VAL_11]] : tensor<4x256xbf16>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: bf16, %[[VAL_20:.*]]: bf16, %[[VAL_21:.*]]: bf16): -// CHECK: %[[VAL_22:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : bf16 -// CHECK: linalg.yield %[[VAL_22]] : bf16 -// CHECK: } -> tensor<4x256xbf16> -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index -// CHECK: %[[VAL_26:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index -// CHECK: %[[VAL_28:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_27]]], sizes: [4, 256], strides: [3, %[[VAL_5]]] : memref<*xbf16> to memref<4x256xbf16, strided<[3, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_29:.*]] in writable %[[VAL_28]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir deleted file mode 100644 index 2f508262..00000000 --- a/test/Conversion/TritonToLinalg/addptr_reshape_broadcast.mlir +++ /dev/null @@ -1,43 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// TODO: expand this example to 3D -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: {{\[}}1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir deleted file mode 100644 index 2af087ce..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_broadcast.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = arg1, offset = %1, size = 1, strides = 0 - %3 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg1, offset = %1, size = 1024, strides = 0 - %4 = tt.expand_dims %3 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> - // source = arg1, offset = [%1, 0], size = [1024, 1], strides = [0, 0] - %5 = tt.broadcast %4 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> - // source = arg1, offset = [%1, 0], size = [1024, 1024], strides = [0, 0] - %6 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %7 = tt.expand_dims %6 {axis = 0 : i32} : tensor<1024xi32> -> tensor<1x1024xi32> - // offset = [0, 0], size = [1, 1024], strides = [0, 1] - %8 = tt.broadcast %7 : tensor<1x1024xi32> -> tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [0, 1] - %9 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<1024xi32> -> tensor<1024x1xi32> - // offset = [0, 0], size = [1024, 1], strides = [1, 0] - %11 = tt.broadcast %10 : tensor<1024x1xi32> -> tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [1, 0] - %12 = arith.addi %8, %11 : tensor<1024x1024xi32> - // offset = [0, 0], size = [1024, 1024], strides = [1, 1] - %13 = tt.addptr %5, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> - // source = arg1, offset = [pid * %arg2, 0], size = [1024, 1024], strides = [1, 1] - %14 = tt.load %13 : tensor<1024x1024x!tt.ptr> - %17 = math.exp %14 : tensor<1024x1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = pid+arg3, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg0, offset = pid+arg3, size = 1024, strides = 0 - %21 = tt.expand_dims %20 {axis = 1 : i32} : tensor<1024x!tt.ptr> -> tensor<1024x1x!tt.ptr> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1], strides = [0, 0] - %22 = tt.broadcast %21 : tensor<1024x1x!tt.ptr> -> tensor<1024x1024x!tt.ptr> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [0, 0] - %23 = tt.addptr %22, %12 : tensor<1024x1024x!tt.ptr>, tensor<1024x1024xi32> - // source = arg0, offset = [pid+arg3, 0], size = [1024, 1024], strides = [1, 1] - tt.store %23, %17 : tensor<1024x1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024x1024xf32> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024x1024xf32, strided<[1, 1], offset: ?>> to memref<1024x1024xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024x1024xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_12]] : tensor<1024x1024xf32>) outs(%[[VAL_12]] : tensor<1024x1024xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024x1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024, 1024], strides: [1, 1] : memref<*xf32> to memref<1024x1024xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir deleted file mode 100644 index 466778f4..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = %arg1, offset = %1, size = 1, strides = 0 - %cf0 = arith.constant 0.000000e+00 : f32 - %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<1024xf32> - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %_ptr, %sum_out = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr_iter = %2, %sum_iter = %tensor_cf0) -> (!tt.ptr, tensor<1024xf32>) { - %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %4 = tt.splat %ptr_iter : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg1, offset = %1, size = 1024, strides = 0 - %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg1, offset = %1, size = 1024, strides = 1 - %8 = tt.load %5 : tensor<1024x!tt.ptr> - %9 = math.exp %8 : tensor<1024xf32> - %sum_next = arith.addf %sum_iter, %9 : tensor<1024xf32> - %cast_i = arith.index_cast %i : index to i32 - %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 - // source = %arg1, offset = %1 + %i, size = 1, strides = 0 - scf.yield %ptr_next, %sum_next : !tt.ptr, tensor<1024xf32> - } - %10 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - %21 = tt.addptr %20, %10 : tensor<1024x!tt.ptr>, tensor<1024xi32> - tt.store %21, %sum_out : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_15:.*]] = linalg.fill ins(%[[VAL_11]] : f32) outs(%[[VAL_14]] : tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: %[[VAL_12:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : i32 to index -// CHECK: %[[VAL_16:.*]]:2 = scf.for %[[VAL_17:.*]] = %[[VAL_10]] to %[[VAL_9]] step %[[VAL_8]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_13]]) -> (tensor<1024xf32>, index) { -// CHECK: %[[VAL_20:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_19]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_21:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_20]], %[[VAL_21]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_21]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_23:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_22]] : tensor<1024xf32>) outs(%[[VAL_22]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32): -// CHECK: %[[VAL_26:.*]] = math.exp %[[VAL_24]] : f32 -// CHECK: linalg.yield %[[VAL_26]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18]], %[[VAL_28:.*]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_29:.*]]: f32, %[[VAL_30:.*]]: f32, %[[VAL_31:.*]]: f32): -// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_29]], %[[VAL_30]] : f32 -// CHECK: linalg.yield %[[VAL_32]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index -// CHECK: scf.yield %[[VAL_34:.*]], %[[VAL_33]] : tensor<1024xf32>, index -// CHECK: } -// CHECK: %[[VAL_35:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_35]] : i32 to index -// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_36]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_38:.*]]#0 in writable %[[VAL_37]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir deleted file mode 100644 index 39f3913b..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_for_2d.mlir +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - %cf0 = arith.constant 0.000000e+00 : f32 - %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<128x128xf32> - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %tensor_cf0, %ptr_iter = %2) -> (tensor<128x128xf32>, !tt.ptr ) { - %3 = tt.splat %ptr_iter : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] - %12 = tt.load %11 : tensor<128x128x!tt.ptr> - %17 = math.exp %12 : tensor<128x128xf32> - %sum_next = arith.addf %sum_iter, %17 : tensor<128x128xf32> - %cast_i = arith.index_cast %i : index to i32 - %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr, i32 - // source = %arg1, offset = %1 + %i, size = 1, strides = 0 - scf.yield %sum_next, %ptr_next : tensor<128x128xf32>, !tt.ptr - } - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] - %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] - tt.store %21, %sum_out : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 128 : index -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_15]] : tensor<128x128xf32>) -> tensor<128x128xf32> -// CHECK: %[[VAL_13:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_13]] : i32 to index -// CHECK: %[[VAL_17:.*]]:2 = scf.for %[[VAL_18:.*]] = %[[VAL_11]] to %[[VAL_10]] step %[[VAL_9]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_14]]) -> (tensor<128x128xf32>, index) { -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_23:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_22]], %[[VAL_23]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32> -// CHECK: %[[VAL_24:.*]] = bufferization.to_tensor %[[VAL_23]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<128x128xf32>) outs(%[[VAL_24]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_29:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_19]], %[[VAL_30:.*]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_19]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: f32, %[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: f32): -// CHECK: %[[VAL_34:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32 -// CHECK: linalg.yield %[[VAL_34]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_35:.*]] = arith.addi %[[VAL_20]], %[[VAL_18]] : index -// CHECK: scf.yield %[[VAL_36:.*]], %[[VAL_35]] : tensor<128x128xf32>, index -// CHECK: } -// CHECK: %[[VAL_37:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_38:.*]] = arith.index_cast %[[VAL_37]] : i32 to index -// CHECK: %[[VAL_39:.*]] = arith.addi %[[VAL_38]], %[[VAL_8]] : index -// CHECK: %[[VAL_40:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_39]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_41:.*]]#0 in writable %[[VAL_40]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir deleted file mode 100644 index 567dd950..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_loopback.mlir +++ /dev/null @@ -1,27 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : i32 - ) { - %0 = tt.addptr %arg0, %arg2 : !tt.ptr, i32 - %1 = tt.addptr %arg1, %arg2 : !tt.ptr, i32 - %10 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: !tt.ptr - tt.store %1, %10 : !tt.ptr - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) { -// CHECK: %0 = arith.index_cast %arg2 : i32 to index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> -// CHECK: %1 = arith.index_cast %arg2 : i32 to index -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [%1], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1], offset: ?>> -// CHECK: %2 = affine.load %reinterpret_cast[0] : memref<1xbf16, strided<[1], offset: ?>> -// CHECK: affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir deleted file mode 100644 index 1bf5f031..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_nested.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = arg1, offset = %1, size = 1, strides = 0 - %3 = arith.muli %0, %arg3 : i32 - %4 = tt.addptr %2, %3 : !tt.ptr, i32 - // source = arg1, offset = %1+%3, size = 1, strides = 0 - %5 = arith.muli %0, %arg4 : i32 - %6 = tt.addptr %4, %5 : !tt.ptr, i32 - // source = arg1, offset = %1+%3+%5, size = 1, strides = 0 - %7 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %8 = tt.splat %6 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg1, offset = %1, size = 1024, strides = 0 - %9 = tt.addptr %8, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = arg1, offset = %1+%3+%5, size = 1024, strides = 1 - %10 = tt.load %9 : tensor<1024x!tt.ptr> - %17 = math.exp %10 : tensor<1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = arg0, offset = %18, size = 1024, strides = 0 - %21 = tt.addptr %20, %7 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = arg0, offset = %18, size = 1024, strides = 1 - tt.store %21, %17 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_10:.*]] = arith.muli %[[ARG_8]], %[[VAL_4]] : i32 -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_9]] : i32 to index -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index -// CHECK: %[[VAL_14:.*]] = arith.index_cast %[[VAL_10]] : i32 to index -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index -// CHECK: %[[VAL_16:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_15]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_17:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_16]], %[[VAL_17]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_18:.*]] = bufferization.to_tensor %[[VAL_17]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_19:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18]] : tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = math.exp %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_23:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index -// CHECK: %[[VAL_25:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_24]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_26:.*]] in writable %[[VAL_25]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir deleted file mode 100644 index ccb8ce49..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - // source = %arg1, offset = %1, size = 1, strides = 0 - %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // offset = 0, size = 1024, strides = 1 - %4 = tt.splat %2 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg1, offset = %1, size = 1024, strides = 0 - %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg1, offset = %1, size = 1024, strides = 1 - %8 = tt.load %5 : tensor<1024x!tt.ptr> - %17 = math.exp %8 : tensor<1024xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = %arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<1024x!tt.ptr> - // source = %arg0, offset = %18, size = 1024, strides = 0 - %21 = tt.addptr %20, %3 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // source = %arg0, offset = %18, size = 1024, strides = 1 - tt.store %21, %17 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : i32 to index -// CHECK: %[[VAL_10:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_9]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_10]], %[[VAL_11]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_12]] : tensor<1024xf32>) outs(%[[VAL_12]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = math.exp %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_17:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_18:.*]] = arith.index_cast %[[VAL_17]] : i32 to index -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_18]]], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_20:.*]] in writable %[[VAL_19]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir b/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir deleted file mode 100644 index 122d1c40..00000000 --- a/test/Conversion/TritonToLinalg/addptr_scalar_splat_2d.mlir +++ /dev/null @@ -1,56 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel (%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.addptr %arg1, %1 : !tt.ptr, i32 - %3 = tt.splat %2 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0] - %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32> - // offset = [0, 0], size = [128, 128], strides = [0, 1] - %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32> - // offset = 128, size = 128, strides = 1 - %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 0] - %10 = arith.addi %6, %9 : tensor<128x128xi32> - // offset = [128, 0], size = [128, 128], strides = [1, 1] - %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1] - %12 = tt.load %11 : tensor<128x128x!tt.ptr> - %17 = math.exp %12 : tensor<128x128xf32> - %18 = arith.muli %0, %arg3 : i32 - %19 = tt.addptr %arg0, %18 : !tt.ptr, i32 - // source = arg0, offset = %18, size = 1, strides = 0 - %20 = tt.splat %19 : !tt.ptr -> tensor<128x128x!tt.ptr> - // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0] - %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1] - tt.store %21, %17 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: memref<*xf32> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32, %[[ARG_7:.*]]: i32, %[[ARG_8:.*]]: i32, %[[ARG_9:.*]]: i32, %[[ARG_10:.*]]: i32) { -// CHECK: %[[VAL_8:.*]] = arith.constant 128 : index -// CHECK: %[[VAL_9:.*]] = arith.muli %[[ARG_8]], %[[VAL_2]] : i32 -// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index -// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_8]] : index -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_11]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_13]] : memref<128x128xf32, strided<[1, 1], offset: ?>> to memref<128x128xf32> -// CHECK: %[[VAL_14:.*]] = bufferization.to_tensor %[[VAL_13]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_14]] : tensor<128x128xf32>) outs(%[[VAL_14]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32): -// CHECK: %[[VAL_18:.*]] = math.exp %[[VAL_16]] : f32 -// CHECK: linalg.yield %[[VAL_18]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_19:.*]] = arith.muli %[[ARG_8]], %[[VAL_3]] : i32 -// CHECK: %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : i32 to index -// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_8]] : index -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: {{\[}}%[[VAL_21]]], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir b/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir deleted file mode 100644 index e0efdde0..00000000 --- a/test/Conversion/TritonToLinalg/arith_not_ptr_arith.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %5 = arith.addi %am, %bm : tensor<1024xi32> - tt.store %19, %5 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_5:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_5]], %[[VAL_7]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_9]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_11:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_8]], %[[VAL_10]] : tensor<1024xi32>, tensor<1024xi32>) outs(%[[VAL_8]] : tensor<1024xi32>) { -// CHECK: ^bb0(%[[VAL_12:.*]]: i32, %[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32): -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i32 -// CHECK: linalg.yield %[[VAL_15]] : i32 -// CHECK: } -> tensor<1024xi32> -// CHECK: bufferization.materialize_in_destination %[[VAL_16:.*]] in writable %[[VAL_6]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/bitcast.mlir b/test/Conversion/TritonToLinalg/bitcast.mlir deleted file mode 100644 index f838a901..00000000 --- a/test/Conversion/TritonToLinalg/bitcast.mlir +++ /dev/null @@ -1,44 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func @kernel(%a : !tt.ptr, %b : !tt.ptr) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - - %am = tt.load %9 : tensor<1024x!tt.ptr> - - // cast result before doing float add - %am_bitcast = tt.bitcast %am : tensor<1024xi32> -> tensor<1024xf32> - - - tt.store %19, %am_bitcast : tensor<1024x!tt.ptr> - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @kernel(%arg0: memref<*xi32>, %arg1: memref<*xf32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { -// CHECK: [[RC_:%.+]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: [[RC_0_:%.+]] = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1024], strides: [1]{{.*}} : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: [[ALLOC_:%.+]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy [[RC_]], [[ALLOC_]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: [[VAR_0_:%.+]] = bufferization.to_tensor [[ALLOC_]] restrict writable : memref<1024xi32> -// CHECK: [[VAR_1_:%.+]] = tensor.empty() : tensor<1024xf32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_0_]] : tensor<1024xi32>) outs([[VAR_1_]] : tensor<1024xf32>) { -// CHECK: ^bb0(%in: i32, %out: f32): -// CHECK: [[VAR_5_:%.+]] = arith.bitcast %in : i32 to f32 -// CHECK: linalg.yield [[VAR_5_]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_2_]] in writable [[RC_0_]] -// CHECK: return -// CHECK: } -// CHECK: } - diff --git a/test/Conversion/TritonToLinalg/block_ptr_advance.mlir b/test/Conversion/TritonToLinalg/block_ptr_advance.mlir deleted file mode 100644 index 8cf0fe7d..00000000 --- a/test/Conversion/TritonToLinalg/block_ptr_advance.mlir +++ /dev/null @@ -1,90 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @matmul_kernel_with_block_pointers_01234567891011(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32) { - %c64_i32 = arith.constant 64 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst = arith.constant 0.000000e+00 : bf16 - %c256_i32 = arith.constant 256 : i32 - %0 = arith.extsi %arg3 : i32 to i64 - %1 = arith.extsi %arg5 : i32 to i64 - %2 = arith.extsi %arg6 : i32 to i64 - %3 = arith.extsi %arg7 : i32 to i64 - %4 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %3], [%arg12, %c0_i32] {order = array} : > - %5 = tt.advance %4, [%c0_i32, %c64_i32] : > - %6 = tt.splat %cst : bf16 -> tensor<128x64xbf16> - %7:3 = scf.for %arg14 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg15 = %6, %arg16 = %5, %arg17 = %4) -> (tensor<128x64xbf16>, !tt.ptr>, !tt.ptr>) : i32 { - %13 = tt.load %arg16 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> - %14 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr> - %15 = arith.addf %13, %14 : tensor<128x64xbf16> - %16 = arith.addf %arg15, %15 : tensor<128x64xbf16> - %17 = tt.advance %arg16, [%c0_i32, %c64_i32] : > - %18 = tt.advance %arg17, [%c64_i32, %c0_i32] : > - scf.yield %16, %17, %18 : tensor<128x64xbf16>, !tt.ptr>, !tt.ptr> - } - %8 = arith.extsi %arg10 : i32 to i64 - %9 = arith.extsi %arg11 : i32 to i64 - %10 = arith.extsi %arg4 : i32 to i64 - %11 = arith.muli %arg13, %c256_i32 : i32 - %12 = tt.make_tensor_ptr %arg2, [%0, %10], [%8, %9], [%arg12, %11] {order = array} : > - tt.store %12, %7#0 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr> - tt.return - } -} - -// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: module { -// CHECK: func.func @matmul_kernel_with_block_pointers_01234567891011(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: memref<*xbf16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32) { -// CHECK: %c64 = arith.constant 64 : index -// CHECK: %c256_i32 = arith.constant 256 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %c64_i32 = arith.constant 64 : i32 -// CHECK: %cst = arith.constant 0.000000e+00 : bf16 -// CHECK: %0 = tensor.empty() : tensor<128x64xbf16> -// CHECK: %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -// CHECK: %2 = arith.index_cast %arg12 : i32 to index -// CHECK: %3 = arith.index_cast %arg6 : i32 to index -// CHECK: %4 = arith.index_cast %arg7 : i32 to index -// CHECK: %5 = arith.muli %2, %3 : index -// CHECK: %6 = arith.muli %4, %c64 : index -// CHECK: %7 = arith.addi %5, %6 : index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%7], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%5], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %8:5 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %5) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index) : i32 { -// CHECK: %alloc = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy %arg22, %alloc : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> -// CHECK: %17 = bufferization.to_tensor %alloc restrict writable : memref<128x64xbf16> -// CHECK: %alloc_2 = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy %arg23, %alloc_2 : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16> -// CHECK: %18 = bufferization.to_tensor %alloc_2 restrict writable : memref<128x64xbf16> -// CHECK: %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%17, %18 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%17 : tensor<128x64xbf16>) { -// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): -// CHECK: %25 = arith.addf %in, %in_5 : bf16 -// CHECK: linalg.yield %25 : bf16 -// CHECK: } -> tensor<128x64xbf16> -// CHECK: %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg21, %19 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%arg21 : tensor<128x64xbf16>) { -// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16): -// CHECK: %25 = arith.addf %in, %in_5 : bf16 -// CHECK: linalg.yield %25 : bf16 -// CHECK: } -> tensor<128x64xbf16> -// CHECK: %21 = arith.muli %4, %c64 : index -// CHECK: %22 = arith.addi %arg24, %21 : index -// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%22], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: %23 = arith.muli %3, %c64 : index -// CHECK: %24 = arith.addi %23, %arg25 : index -// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%24], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %22, %24 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index -// CHECK: } -// CHECK: %9 = arith.muli %arg13, %c256_i32 : i32 -// CHECK: %10 = arith.index_cast %arg12 : i32 to index -// CHECK: %11 = arith.index_cast %9 : i32 to index -// CHECK: %12 = arith.index_cast %arg10 : i32 to index -// CHECK: %13 = arith.index_cast %arg11 : i32 to index -// CHECK: %14 = arith.muli %10, %12 : index -// CHECK: %15 = arith.muli %11, %13 : index -// CHECK: %16 = arith.addi %14, %15 : index -// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [%16], sizes: [128, 64], strides: [%12, %13] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>> -// CHECK: bufferization.materialize_in_destination %8#0 in writable %reinterpret_cast_1 -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir deleted file mode 100644 index 03363130..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_binary.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : tensor<1024x!tt.ptr> - ) -> () { - %cst = arith.constant dense : tensor<1024xi1> - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %1 = arith.addf %am, %bm : tensor<1024xf32> - %2 = arith.subf %1, %bm : tensor<1024xf32> - %3 = arith.mulf %2, %bm : tensor<1024xf32> - %4 = arith.divf %3, %bm : tensor<1024xf32> - %5 = arith.cmpf "oeq", %4, %bm : tensor<1024xf32> - %6 = arith.select %5, %am, %bm : tensor<1024xi1>, tensor<1024xf32> - tt.store %c, %6 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<1024xf32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_8]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_12:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_9]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_17:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_18:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_18]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_23:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_24:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_24]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_25]], %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_29:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_30:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_30]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: f32, %[[VAL_32:.*]]: f32, %[[VAL_33:.*]]: f32): -// CHECK: %[[VAL_34:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32 -// CHECK: linalg.yield %[[VAL_34]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_35:.*]] = tensor.empty() : tensor<1024xi1> -// CHECK: %[[VAL_36:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_37:.*]], %[[VAL_11]] : tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_35]] : tensor<1024xi1>) { -// CHECK: ^bb0(%[[VAL_38:.*]]: f32, %[[VAL_39:.*]]: f32, %[[VAL_40:.*]]: i1): -// CHECK: %[[VAL_41:.*]] = arith.cmpf oeq, %[[VAL_38]], %[[VAL_39]] : f32 -// CHECK: linalg.yield %[[VAL_41]] : i1 -// CHECK: } -> tensor<1024xi1> -// CHECK: %[[VAL_42:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_43:.*]], %[[VAL_9]], %[[VAL_11]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_9]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_44:.*]]: i1, %[[VAL_45:.*]]: f32, %[[VAL_46:.*]]: f32, %[[VAL_47:.*]]: f32): -// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_44]], %[[VAL_45]], %[[VAL_46]] : f32 -// CHECK: linalg.yield %[[VAL_48]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_49:.*]] in writable %[[VAL_2]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir deleted file mode 100644 index 39f4d5ca..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_ternary.mlir +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : !tt.ptr, - %d : tensor<1024x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // c pointer - %28 = tt.splat %c : !tt.ptr -> tensor<1024x!tt.ptr> - %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %am = tt.load %9 : tensor<1024x!tt.ptr> - %bm = tt.load %19 : tensor<1024x!tt.ptr> - %cm = tt.load %29 : tensor<1024x!tt.ptr> - %10 = arith.select %am, %bm, %cm : tensor<1024xi1>, tensor<1024xf32> - tt.store %d, %10 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi1>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<*xf32>, %[[VAL_3:.*]]: memref<1024xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<1024xi1> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<1024xi1, strided<[1]>> to memref<1024xi1> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<1024xi1> -// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<1024xi1>, tensor<1024xf32>, tensor<1024xf32>) outs(%[[VAL_13]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): -// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 -// CHECK: linalg.yield %[[VAL_21]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir b/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir deleted file mode 100644 index 457647a1..00000000 --- a/test/Conversion/TritonToLinalg/convert_1d_elemwise_arith_unary.mlir +++ /dev/null @@ -1,88 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %f32ptr : !tt.ptr, - %intptr : !tt.ptr, - %f16ptr : !tt.ptr, - %save0 : tensor<1024x!tt.ptr>, - %save1 : tensor<1024x!tt.ptr>, - %save2 : tensor<1024x!tt.ptr>, - %save3 : tensor<1024x!tt.ptr>, - %save4 : tensor<1024x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> - // f32ptr pointer - %8 = tt.splat %f32ptr : !tt.ptr -> tensor<1024x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // intptr pointer - %18 = tt.splat %intptr : !tt.ptr -> tensor<1024x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - // f32ptr pointer - %28 = tt.splat %f16ptr : !tt.ptr -> tensor<1024x!tt.ptr> - %29 = tt.addptr %28, %0 : tensor<1024x!tt.ptr>, tensor<1024xi32> - %afm = tt.load %9 : tensor<1024x!tt.ptr> - %aim = tt.load %19 : tensor<1024x!tt.ptr> - %bfm = tt.load %29 : tensor<1024x!tt.ptr> - %5 = arith.truncf %afm : tensor<1024xf32> to tensor<1024xbf16> - %6 = math.exp %afm : tensor<1024xf32> - %7 = arith.sitofp %aim : tensor<1024xi32> to tensor<1024xf32> - %10 = arith.extf %bfm : tensor<1024xf16> to tensor<1024xf32> - %11 = math.sqrt %afm : tensor<1024xf32> - tt.store %save0, %5 : tensor<1024x!tt.ptr> - tt.store %save1, %6 : tensor<1024x!tt.ptr> - tt.store %save2, %7 : tensor<1024x!tt.ptr> - tt.store %save3, %10 : tensor<1024x!tt.ptr> - tt.store %save4, %11 : tensor<1024x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: memref<*xf16>, %[[VAL_3:.*]]: memref<1024xbf16>, %[[VAL_4:.*]]: memref<1024xf32>, %[[VAL_5:.*]]: memref<1024xf32>, %[[VAL_6:.*]]: memref<1024xf32>, %[[VAL_7:.*]]: memref<1024xf32>, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32) { -// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf16> to memref<1024xf16, strided<[1]>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<1024xf32> -// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<1024xf32, strided<[1]>> to memref<1024xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<1024xf32> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<1024xi32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<1024xi32, strided<[1]>> to memref<1024xi32> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<1024xi32> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<1024xf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<1024xf16, strided<[1]>> to memref<1024xf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<1024xf16> -// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<1024xbf16> -// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_20]] : tensor<1024xbf16>) { -// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): -// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 -// CHECK: linalg.yield %[[VAL_24]] : bf16 -// CHECK: } -> tensor<1024xbf16> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_17]] : tensor<1024xi32>) outs(%[[VAL_29]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): -// CHECK: %[[VAL_33:.*]] = arith.sitofp %[[VAL_31]] : i32 to f32 -// CHECK: linalg.yield %[[VAL_33]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_19]] : tensor<1024xf16>) outs(%[[VAL_34]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): -// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 -// CHECK: linalg.yield %[[VAL_38]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%[[VAL_15]] : tensor<1024xf32>) outs(%[[VAL_15]] : tensor<1024xf32>) { -// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): -// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 -// CHECK: linalg.yield %[[VAL_42]] : f32 -// CHECK: } -> tensor<1024xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] -// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] -// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] -// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] -// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir deleted file mode 100644 index 0f855fc7..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_binary.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : tensor<128x128x!tt.ptr>, - %d : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %af = tt.load %9 : tensor<128x128x!tt.ptr> - %bf = tt.load %19 : tensor<128x128x!tt.ptr> - %res0 = arith.addf %af, %bf : tensor<128x128xf32> - %res1 = arith.subf %af, %bf : tensor<128x128xf32> - tt.store %c, %res0 : tensor<128x128x!tt.ptr> - tt.store %d, %res1 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<128x128xf32>, %[[VAL_3:.*]]: memref<128x128xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_11:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_11]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_12:.*]] = bufferization.to_tensor %[[VAL_11]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_13:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32, %[[VAL_16:.*]]: f32): -// CHECK: %[[VAL_17:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: linalg.yield %[[VAL_17]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_10]], %[[VAL_12]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_10]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): -// CHECK: %[[VAL_22:.*]] = arith.subf %[[VAL_19]], %[[VAL_20]] : f32 -// CHECK: linalg.yield %[[VAL_22]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_2]] -// CHECK: bufferization.materialize_in_destination %[[VAL_24:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir deleted file mode 100644 index f0398736..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_ternary.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %a : !tt.ptr, - %b : !tt.ptr, - %c : !tt.ptr, - %d : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // a pointer - %8 = tt.splat %a : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // b pointer - %18 = tt.splat %b : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // c pointer - %28 = tt.splat %c : !tt.ptr -> tensor<128x128x!tt.ptr> - %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %am = tt.load %9 : tensor<128x128x!tt.ptr> - %bm = tt.load %19 : tensor<128x128x!tt.ptr> - %cm = tt.load %29 : tensor<128x128x!tt.ptr> - %100 = arith.select %am, %bm, %cm : tensor<128x128xi1>, tensor<128x128xf32> - tt.store %d, %100 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi1>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: memref<*xf32>, %[[VAL_3:.*]]: memref<128x128xf32>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi1> to memref<128x128xi1, strided<[1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_9:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_10:.*]] = memref.alloc() : memref<128x128xi1> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_10]] : memref<128x128xi1, strided<[1, 1]>> to memref<128x128xi1> -// CHECK: %[[VAL_11:.*]] = bufferization.to_tensor %[[VAL_10]] restrict writable : memref<128x128xi1> -// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_8]], %[[VAL_12]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_9]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<128x128xi1>, tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[VAL_13]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_17:.*]]: i1, %[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32): -// CHECK: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_19]] : f32 -// CHECK: linalg.yield %[[VAL_21]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_22:.*]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir b/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir deleted file mode 100644 index 835e4e18..00000000 --- a/test/Conversion/TritonToLinalg/convert_2d_elemwise_arith_unary.mlir +++ /dev/null @@ -1,94 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %f32ptr : !tt.ptr, - %intptr : !tt.ptr, - %f16ptr : !tt.ptr, - %save0 : tensor<128x128x!tt.ptr>, - %save1 : tensor<128x128x!tt.ptr>, - %save2 : tensor<128x128x!tt.ptr>, - %save3 : tensor<128x128x!tt.ptr>, - %save4 : tensor<128x128x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %moff = tt.broadcast %1 : tensor<128x1xi32> -> tensor<128x128xi32> - %3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %koff = tt.broadcast %4 : tensor<1x128xi32> -> tensor<128x128xi32> - %mkoff = arith.addi %moff, %koff : tensor<128x128xi32> - // f32ptr pointer - %8 = tt.splat %f32ptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // intptr pointer - %18 = tt.splat %intptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %19 = tt.addptr %18, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // f16ptr pointer - %28 = tt.splat %f16ptr : !tt.ptr -> tensor<128x128x!tt.ptr> - %29 = tt.addptr %28, %mkoff : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - %afm = tt.load %9 : tensor<128x128x!tt.ptr> - %aim = tt.load %19 : tensor<128x128x!tt.ptr> - %bfm = tt.load %29 : tensor<128x128x!tt.ptr> - %5 = arith.truncf %afm : tensor<128x128xf32> to tensor<128x128xbf16> - %6 = math.exp %afm : tensor<128x128xf32> - %7 = arith.sitofp %aim : tensor<128x128xi32> to tensor<128x128xf32> - %10 = arith.extf %bfm : tensor<128x128xf16> to tensor<128x128xf32> - %11 = math.sqrt %afm : tensor<128x128xf32> - tt.store %save0, %5 : tensor<128x128x!tt.ptr> - tt.store %save1, %6 : tensor<128x128x!tt.ptr> - tt.store %save2, %7 : tensor<128x128x!tt.ptr> - tt.store %save3, %10 : tensor<128x128x!tt.ptr> - tt.store %save4, %11 : tensor<128x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xi32>, %[[VAL_2:.*]]: memref<*xf16>, %[[VAL_3:.*]]: memref<128x128xbf16>, %[[VAL_4:.*]]: memref<128x128xf32>, %[[VAL_5:.*]]: memref<128x128xf32>, %[[VAL_6:.*]]: memref<128x128xf32>, %[[VAL_7:.*]]: memref<128x128xf32>, %[[VAL_8:.*]]: i32, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: i32) { -// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> -// CHECK: %[[VAL_12:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi32> to memref<128x128xi32, strided<[1, 1]>> -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf16> to memref<128x128xf16, strided<[1, 1]>> -// CHECK: %[[VAL_14:.*]] = memref.alloc() : memref<128x128xf32> -// CHECK: memref.copy %[[VAL_11]], %[[VAL_14]] : memref<128x128xf32, strided<[1, 1]>> to memref<128x128xf32> -// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_14]] restrict writable : memref<128x128xf32> -// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<128x128xi32> -// CHECK: memref.copy %[[VAL_12]], %[[VAL_16]] : memref<128x128xi32, strided<[1, 1]>> to memref<128x128xi32> -// CHECK: %[[VAL_17:.*]] = bufferization.to_tensor %[[VAL_16]] restrict writable : memref<128x128xi32> -// CHECK: %[[VAL_18:.*]] = memref.alloc() : memref<128x128xf16> -// CHECK: memref.copy %[[VAL_13]], %[[VAL_18]] : memref<128x128xf16, strided<[1, 1]>> to memref<128x128xf16> -// CHECK: %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_18]] restrict writable : memref<128x128xf16> -// CHECK: %[[VAL_20:.*]] = tensor.empty() : tensor<128x128xbf16> -// CHECK: %[[VAL_21:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_20]] : tensor<128x128xbf16>) { -// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: bf16): -// CHECK: %[[VAL_24:.*]] = arith.truncf %[[VAL_22]] : f32 to bf16 -// CHECK: linalg.yield %[[VAL_24]] : bf16 -// CHECK: } -> tensor<128x128xbf16> -// CHECK: %[[VAL_25:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): -// CHECK: %[[VAL_28:.*]] = math.exp %[[VAL_26]] : f32 -// CHECK: linalg.yield %[[VAL_28]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_29:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_30:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_17]] : tensor<128x128xi32>) outs(%[[VAL_29]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_31:.*]]: i32, %[[VAL_32:.*]]: f32): -// CHECK: %[[VAL_33:.*]] = arith.sitofp %[[VAL_31]] : i32 to f32 -// CHECK: linalg.yield %[[VAL_33]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_34:.*]] = tensor.empty() : tensor<128x128xf32> -// CHECK: %[[VAL_35:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_19]] : tensor<128x128xf16>) outs(%[[VAL_34]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_36:.*]]: f16, %[[VAL_37:.*]]: f32): -// CHECK: %[[VAL_38:.*]] = arith.extf %[[VAL_36]] : f16 to f32 -// CHECK: linalg.yield %[[VAL_38]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: %[[VAL_39:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_15]] : tensor<128x128xf32>) outs(%[[VAL_15]] : tensor<128x128xf32>) { -// CHECK: ^bb0(%[[VAL_40:.*]]: f32, %[[VAL_41:.*]]: f32): -// CHECK: %[[VAL_42:.*]] = math.sqrt %[[VAL_40]] : f32 -// CHECK: linalg.yield %[[VAL_42]] : f32 -// CHECK: } -> tensor<128x128xf32> -// CHECK: bufferization.materialize_in_destination %[[VAL_43:.*]] in writable %[[VAL_3]] -// CHECK: bufferization.materialize_in_destination %[[VAL_44:.*]] in writable %[[VAL_4]] -// CHECK: bufferization.materialize_in_destination %[[VAL_45:.*]] in writable %[[VAL_5]] -// CHECK: bufferization.materialize_in_destination %[[VAL_46:.*]] in writable %[[VAL_6]] -// CHECK: bufferization.materialize_in_destination %[[VAL_47:.*]] in writable %[[VAL_7]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir b/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir deleted file mode 100644 index f430ce9c..00000000 --- a/test/Conversion/TritonToLinalg/convert_addi_reduce.mlir +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @addi(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.addi %arg14, %arg15 : i32 - tt.reduce.return %69 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @addi( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_8]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_11:.*]] = tensor.insert %[[VAL_7]] into %[[VAL_10]][] : tensor -// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<4096xi32>) outs(%[[VAL_11]] : tensor) dimensions = [0] -// CHECK: (%[[VAL_13:.*]]: i32, %[[VAL_14:.*]]: i32) { -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : i32 -// CHECK: linalg.yield %[[VAL_15]] : i32 -// CHECK: } -// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_12]][] : tensor -// CHECK: %[[VAL_17:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_16]], %[[VAL_17]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir b/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir deleted file mode 100644 index c96738ee..00000000 --- a/test/Conversion/TritonToLinalg/convert_argmin_argmax.mlir +++ /dev/null @@ -1,141 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @argmax_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ - ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): - %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - %12 = arith.cmpi slt, %arg10, %arg12 : i32 - %13 = arith.andi %11, %12 : i1 - %14 = arith.cmpf ogt, %arg9, %arg11 : f32 - %15 = arith.ori %14, %13 : i1 - %16 = arith.select %15, %arg9, %arg11 : f32 - %17 = arith.select %15, %arg10, %arg12 : i32 - tt.reduce.return %16, %17 : f32, i32 - }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) - %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 - tt.store %9, %8#1 : !tt.ptr - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @argmax_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 -// CHECK: linalg.yield [[VAR_11_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] -// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.andi [[VAR_10_1_]], [[VAR_11_1_]] : i1 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf ogt, [[in]], [[init]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_13_]], [[VAR_12_]] : i1 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_0_]][0] : memref<1xi32, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @argmin_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8:2 = "tt.reduce"(%7, %2) <{axis = 0 : i32}> ({ - ^bb0(%arg9: f32, %arg10: i32, %arg11: f32, %arg12: i32): - %11 = arith.cmpf oeq, %arg9, %arg11 : f32 - %12 = arith.cmpi slt, %arg10, %arg12 : i32 - %13 = arith.andi %11, %12 : i1 - %14 = arith.cmpf olt, %arg9, %arg11 : f32 - %15 = arith.ori %14, %13 : i1 - %16 = arith.select %15, %arg9, %arg11 : f32 - %17 = arith.select %15, %arg10, %arg12 : i32 - tt.reduce.return %16, %17 : f32, i32 - }) : (tensor<4096xf32>, tensor<4096xi32>) -> (f32, i32) - %9 = tt.addptr %arg1, %0 : !tt.ptr, i32 - tt.store %9, %8#1 : !tt.ptr - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @argmin_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_2_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_1_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_10_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i32 -// CHECK: linalg.yield [[VAR_11_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_5_]] : tensor) -> tensor -// CHECK-DAG: [[VAR_7_:%.+]] = tensor.empty() : tensor -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_7_]] : tensor) -> tensor -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_2_]] : tensor<4096xf32>, tensor<4096xi32>) outs([[VAR_6_]], [[VAR_8_]] : tensor, tensor) dimensions = [0] -// CHECK: ([[in:.+]]: f32, [[in_1:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_10_1_:%.+]] = arith.cmpf oeq, [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_11_1_:%.+]] = arith.cmpi slt, [[in_1]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.andi [[VAR_10_1_]], [[VAR_11_1_]] : i1 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpf olt, [[in]], [[init]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.ori [[VAR_13_]], [[VAR_12_]] : i1 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.select [[VAR_14_]], [[in]], [[init]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[in_1]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_15_]], [[VAR_16_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]#1[] : tensor -// CHECK-DAG: [[VAR_9_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_9_]]{{.}}, sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: affine.store [[VAR_extracted_]], [[VAR_reinterpret_cast_0_]][0] : memref<1xi32, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir b/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir deleted file mode 100644 index 73e8fd6e..00000000 --- a/test/Conversion/TritonToLinalg/convert_argmin_argmax_2d.mlir +++ /dev/null @@ -1,215 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -// @triton.jit -// def test( -// a_ptr, c_ptr, stride_am, stride_an -// ): -// offs_am = tl.arange(0, 4) -// offs_an = tl.arange(0, 4) -// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) -// a = tl.load(a_ptrs) -// m = tl.argmax(a, axis=1) -// tl.store(c_ptr + tl.arange(0, 4), m) -// -// ret = triton.compiler.compile( -// test, -// signature=" *fp32,*fp32,i32,i32", -// print_triton_ir_only=True, -// ) - -module { - tt.func public @test_argmax(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> - %3 = arith.muli %1, %2 : tensor<4x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> - %6 = arith.muli %4, %5 : tensor<1x4xi32> - %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> - %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> - %9 = arith.addi %7, %8 : tensor<4x4xi32> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %12 = tt.load %11 : tensor<4x4x!tt.ptr> - %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> - %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): - %18 = arith.cmpf oeq, %arg4, %arg6 : f32 - %19 = arith.cmpi slt, %arg5, %arg7 : i32 - %20 = arith.andi %18, %19 : i1 - %21 = arith.cmpf ogt, %arg4, %arg6 : f32 - %22 = arith.ori %21, %20 : i1 - %23 = arith.select %22, %arg4, %arg6 : f32 - %24 = arith.select %22, %arg5, %arg7 : i32 - tt.reduce.return %23, %24 : f32, i32 - }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) - %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> - %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> - tt.store %16, %17 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @test_argmax -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 -// CHECK: linalg.yield [[VAR_14_]] : i32 -// CHECK: } -> tensor<4xi32> -// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_1_]] {{.}}[0, 1]{{.}} output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> -// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: i32): -// CHECK: linalg.yield [[in_]] : i32 -// CHECK: } -> tensor<4x4xi32> -// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_9_]] : tensor<4xi32>) -> tensor<4xi32> -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] -// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_13_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = arith.andi [[VAR_13_1_]], [[VAR_14_1_]] : i1 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf ogt, [[in_]], [[init]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_16_]], [[VAR_15_]] : i1 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> -// CHECK-DAG: [[VAR_11_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_11_]] : tensor<4xf32>) { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): -// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 -// CHECK: linalg.yield [[VAR_13_2_]] : f32 -// CHECK: } -> tensor<4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } - -// ----- - -// @triton.jit -// def test( -// a_ptr, c_ptr, stride_am, stride_an -// ): -// offs_am = tl.arange(0, 4) -// offs_an = tl.arange(0, 4) -// a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_an[None, :] * stride_an) -// a = tl.load(a_ptrs) -// m = tl.argmin(a, axis=1) -// tl.store(c_ptr + tl.arange(0, 4), m) -// -// ret = triton.compiler.compile( -// test, -// signature=" *fp32,*fp32,i32,i32", -// print_triton_ir_only=True, -// ) - -module { - tt.func public @test_argmin(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4x1xi32> - %3 = arith.muli %1, %2 : tensor<4x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %5 = tt.splat %arg3 : i32 -> tensor<1x4xi32> - %6 = arith.muli %4, %5 : tensor<1x4xi32> - %7 = tt.broadcast %3 : tensor<4x1xi32> -> tensor<4x4xi32> - %8 = tt.broadcast %6 : tensor<1x4xi32> -> tensor<4x4xi32> - %9 = arith.addi %7, %8 : tensor<4x4xi32> - %10 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %11 = tt.addptr %10, %9 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %12 = tt.load %11 : tensor<4x4x!tt.ptr> - %13 = tt.broadcast %4 : tensor<1x4xi32> -> tensor<4x4xi32> - %14:2 = "tt.reduce"(%12, %13) <{axis = 1 : i32}> ({ - ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32): - %18 = arith.cmpf oeq, %arg4, %arg6 : f32 - %19 = arith.cmpi slt, %arg5, %arg7 : i32 - %20 = arith.andi %18, %19 : i1 - %21 = arith.cmpf olt, %arg4, %arg6 : f32 - %22 = arith.ori %21, %20 : i1 - %23 = arith.select %22, %arg4, %arg6 : f32 - %24 = arith.select %22, %arg5, %arg7 : i32 - tt.reduce.return %23, %24 : f32, i32 - }) : (tensor<4x4xf32>, tensor<4x4xi32>) -> (tensor<4xf32>, tensor<4xi32>) - %15 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %16 = tt.addptr %15, %0 : tensor<4x!tt.ptr>, tensor<4xi32> - %17 = arith.sitofp %14#1 : tensor<4xi32> to tensor<4xf32> - tt.store %16, %17 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (0, d1)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @test_argmin -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_1_:%.+]] = arith.constant -1 : i32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<4xi32>) { -// CHECK: ^bb0([[out_:.+]]: i32): -// CHECK: [[VAR_13_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i32 -// CHECK: linalg.yield [[VAR_14_]] : i32 -// CHECK: } -> tensor<4xi32> -// CHECK-DAG: [[VAR_expanded_:%.+]] = tensor.expand_shape [[VAR_1_]] {{.}}[0, 1]{{.}} output_shape [1, 4] : tensor<4xi32> into tensor<1x4xi32> -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [4, 4], strides: {{.}}[[VAR_2_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4x4xf32, strided<[?, ?]>> to memref<4x4xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<4x4xi32> -// CHECK: [[VAR_6_:%.+]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins([[VAR_expanded_]] : tensor<1x4xi32>) outs([[VAR_5_]] : tensor<4x4xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: i32): -// CHECK: linalg.yield [[in_]] : i32 -// CHECK: } -> tensor<4x4xi32> -// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = linalg.fill ins([[CST_0_]] : f32) outs([[VAR_7_]] : tensor<4xf32>) -> tensor<4xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = tensor.empty() : tensor<4xi32> -// CHECK: [[VAR_10_:%.+]] = linalg.fill ins([[CST_minus_1_]] : i32) outs([[VAR_9_]] : tensor<4xi32>) -> tensor<4xi32> -// CHECK: [[VAR_reduced_:%.+]]:2 = linalg.reduce ins([[VAR_4_]], [[VAR_6_]] : tensor<4x4xf32>, tensor<4x4xi32>) outs([[VAR_8_]], [[VAR_10_]] : tensor<4xf32>, tensor<4xi32>) dimensions = [1] -// CHECK: ([[in_:.+]]: f32, [[in_1_:.+]]: i32, [[init:.+]]: f32, [[init_2:.+]]: i32) { -// CHECK-DAG: [[VAR_13_1_:%.+]] = arith.cmpf oeq, [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_14_1_:%.+]] = arith.cmpi slt, [[in_1_]], [[init_2]] : i32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_15_:%.+]] = arith.andi [[VAR_13_1_]], [[VAR_14_1_]] : i1 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf olt, [[in_]], [[init]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.ori [[VAR_16_]], [[VAR_15_]] : i1 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[in_]], [[init]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[in_1_]], [[init_2]] : i32 -// CHECK: linalg.yield [[VAR_18_]], [[VAR_19_]] : f32, i32 -// CHECK: } -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [4], strides: [1] : memref<*xf32> to memref<4xf32, strided<[1]>> -// CHECK-DAG: [[VAR_11_:%.+]] = tensor.empty() : tensor<4xf32> -// CHECK: [[VAR_12_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_reduced_]]#1 : tensor<4xi32>) outs([[VAR_11_]] : tensor<4xf32>) { -// CHECK: ^bb0([[in_:.+]]: i32, [[out_:.+]]: f32): -// CHECK: [[VAR_13_2_:%.+]] = arith.sitofp [[in_]] : i32 to f32 -// CHECK: linalg.yield [[VAR_13_2_]] : f32 -// CHECK: } -> tensor<4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_12_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir b/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir deleted file mode 100644 index a7ab57ab..00000000 --- a/test/Conversion/TritonToLinalg/convert_extern_elementwise.mlir +++ /dev/null @@ -1,809 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @atan2_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg3 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %12 = tt.load %11, %6 : tensor<32x!tt.ptr> - %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_atan2f"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %15, %13 : tensor<32x!tt.ptr> - tt.return - } -} -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atan2_kernel_0123 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]], [[VAR_2:%.+]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_3:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_4:%.+]] = math.atan2 [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_4]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @pow_kernel_0123(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg3 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %11 = tt.addptr %10, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %12 = tt.load %11, %6 : tensor<32x!tt.ptr> - %13 = tt.extern_elementwise %9, %12 {libname = "", libpath = "", pure = true, symbol = "__nv_powf"} : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> - %14 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %15 = tt.addptr %14, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %15, %13 : tensor<32x!tt.ptr> - tt.return - } -} -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @pow_kernel_0123 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]], [[VAR_2:%.+]] : tensor<32xf32>, tensor<32xf32>) outs([[VAR_3:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in_1:%.+]]: f32, [[in_2:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_4:%.+]] = math.powf [[in_1]], [[in_2]] : f32 -// CHECK: linalg.yield [[VAR_4]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @fabs_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_fabsf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @fabs_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.absf [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sin_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sin [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @cos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_cosf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @cos_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.cos [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @tan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @tan_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.tan [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @asin_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @asin_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.asin [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @acos_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acosf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @acos_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.acos [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @atan_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atan_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.atan [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sinh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sinhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sinh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sinh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @cosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_coshf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @cosh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.cosh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @tanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_tanhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @tanh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.tanh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @asinh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_asinhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @asinh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.asinh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @acosh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_acoshf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @acosh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.acosh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @atanh_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_atanhf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @atanh_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.atanh [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_logf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log10_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log10f"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log10_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log10 [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @log1p_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_log1pf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @log1p_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.log1p [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @exp_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_expf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @exp_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.exp [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @exp2_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_exp2f"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @exp2_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.exp2 [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @erf_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_erff"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @erf_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.erf [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @sqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_sqrtf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @sqrt_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.sqrt [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @rsqrt_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_rsqrtf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @rsqrt_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.rsqrt [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @ceil_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_ceilf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @ceil_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.ceil [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @floor_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_floorf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @floor_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.floor [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> - -// ----- - -module { - tt.func public @trunc_kernel_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.splat %arg2 : i32 -> tensor<32xi32> - %6 = arith.cmpi slt, %4, %5 : tensor<32xi32> - %7 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %8 = tt.addptr %7, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %9 = tt.load %8, %6 : tensor<32x!tt.ptr> - %10 = tt.extern_elementwise %9 {libname = "", libpath = "", pure = true, symbol = "__nv_truncf"} : (tensor<32xf32>) -> tensor<32xf32> - %11 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %12 = tt.addptr %11, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - tt.store %12, %10 : tensor<32x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @trunc_kernel_012 -// CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [[[MAP]], [[MAP]]], iterator_types = ["parallel"]} ins([[VAR_1:%.+]] : tensor<32xf32>) outs([[VAR_2:%.+]] : tensor<32xf32>) { -// CHECK: ^bb0([[in:%.+]]: f32, [[out:%.+]]: f32): -// CHECK: [[VAR_3:%.+]] = math.trunc [[in:%.+]] : f32 -// CHECK: linalg.yield [[VAR_3:%.+]] : f32 -// CHECK: } -> tensor<32xf32> diff --git a/test/Conversion/TritonToLinalg/convert_minmax.mlir b/test/Conversion/TritonToLinalg/convert_minmax.mlir deleted file mode 100644 index dd7edd3a..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s -module { - tt.func public @minmax_olt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf olt, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_olt -// CHECK: %[[VAL:.*]] = arith.minimumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_ole(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf ole, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_ole -// CHECK: %[[VAL:.*]] = arith.minimumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_ogt(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf ogt, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_ogt -// CHECK: %[[VAL:.*]] = arith.maximumf %arg1, %arg2 : f32 - -// ----- - -module { - tt.func public @minmax_oge(%arg0: !tt.ptr, %arg1: f32, %arg2: f32) { - %0 = arith.cmpf oge, %arg1, %arg2 : f32 - %1 = arith.select %0, %arg1, %arg2 : f32 - tt.store %arg0, %1 : !tt.ptr - tt.return - } -} -// CHECK: func.func @minmax_oge -// CHECK: %[[VAL:.*]] = arith.maximumf %arg1, %arg2 : f32 diff --git a/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir deleted file mode 100644 index 7d915e27..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax_fp_reduce.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s - -module { - tt.func public @maxnumf(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: f32, %arg15: f32): - %69 = arith.maxnumf %arg14, %arg15 : f32 - tt.reduce.return %69 : f32 - }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @maxnumf( -// CHECK-SAME: %arg0: memref<*xf32>, %[[ARG_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32) -// CHECK: %[[CST:.*]] = arith.constant 0xFF800000 : f32 -// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> -// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_3:.*]] = tensor.insert %[[CST]] into %[[VAL_2]][] : tensor -// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] -// CHECK: (%in: f32, %init: f32) { -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %in, %init : f32 -// CHECK: linalg.yield %[[VAL_5]] : f32 -// CHECK: } -// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor -// CHECK: %[[VAL_7:.*]] = memref.[[VAL_7]] %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1]>> -// CHECK: affine.store %[[VAL_6]], %[[VAL_7]][0] : memref<1xf32, strided<[1]>> -// CHECK: return -// CHECK:} - -// ----- - - -module { - tt.func public @minnumf(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0.000000e+00> : tensor<4096xf32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: f32, %arg15: f32): - %69 = arith.minnumf %arg14, %arg15 : f32 - tt.reduce.return %69 : f32 - }) {axis = 0 : i32} : (tensor<4096xf32>) -> f32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK-LABEL: func.func @minnumf( -// CHECK-SAME: %arg0: memref<*xf32>, %[[ARG_1:.*]]: i32, %[[ARG_2:.*]]: i32, %[[ARG_3:.*]]: i32, %[[ARG_4:.*]]: i32, %[[ARG_5:.*]]: i32, %[[ARG_6:.*]]: i32) -// CHECK: %[[CST:.*]] = arith.constant 0x7F800000 : f32 -// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<4096xf32> -// CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST_0]] : f32) outs(%[[VAL_0]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK: %[[VAL_2:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_3:.*]] = tensor.insert %[[CST]] into %[[VAL_2]][] : tensor -// CHECK: %[[VAL_4:.*]] = linalg.reduce ins(%[[VAL_1]] : tensor<4096xf32>) outs(%[[VAL_3]] : tensor) dimensions = [0] -// CHECK: (%in: f32, %init: f32) { -// CHECK: %[[VAL_5:.*]] = arith.minnumf %in, %init : f32 -// CHECK: linalg.yield %[[VAL_5]] : f32 -// CHECK: } -// CHECK: %[[VAL_6:.*]] = tensor.extract %[[VAL_4]][] : tensor -// CHECK: %[[VAL_7:.*]] = memref.[[VAL_7]] %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1]>> -// CHECK: affine.store %[[VAL_6]], %[[VAL_7]][0] : memref<1xf32, strided<[1]>> -// CHECK: return -// CHECK:} diff --git a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir deleted file mode 100644 index eaf30630..00000000 --- a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir +++ /dev/null @@ -1,126 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg --split-input-file %s | FileCheck %s -module { - tt.func public @minmax_sgt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi sgt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_sgt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c-2147483648{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.maxsi %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_ugt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi ugt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_ugt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c0{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.maxui %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_slt(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi slt, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_slt(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c2147483647{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.minsi %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } - -// ----- - -module { - tt.func public @minmax_ult(%arg0: !tt.ptr) { - %cst_0 = arith.constant dense<0> : tensor<4096xi32> - %63 = "tt.reduce"(%cst_0) ({ - ^bb0(%arg14: i32, %arg15: i32): - %69 = arith.cmpi ult, %arg14, %arg15 : i32 - %70 = arith.select %69, %arg14, %arg15 : i32 - tt.reduce.return %70 : i32 - }) {axis = 0 : i32} : (tensor<4096xi32>) -> i32 - tt.store %arg0, %63 : !tt.ptr - tt.return - } -} - -// CHECK: func.func @minmax_ult(%[[VAL_0:.*]]: memref<*xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4096xi32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%c0{{.*}} : i32) outs(%[[VAL_7]] : tensor<4096xi32>) -> tensor<4096xi32> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %c-1{{.*}} into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<4096xi32>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%in: i32, %init: i32) { -// CHECK: %[[VAL_12:.*]] = arith.minui %in, %init : i32 -// CHECK: linalg.yield %[[VAL_12]] : i32 -// CHECK: } -// CHECK: %[[VAL_12:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> -// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> -// CHECK: return -// CHECK: } \ No newline at end of file diff --git a/test/Conversion/TritonToLinalg/convert_splat_float.mlir b/test/Conversion/TritonToLinalg/convert_splat_float.mlir deleted file mode 100644 index f37b2107..00000000 --- a/test/Conversion/TritonToLinalg/convert_splat_float.mlir +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%fin : f32, - %bin : bf16, - %save0 : tensor<1024x!tt.ptr>, - %save1 : tensor<128x256x!tt.ptr>) -> () { - %0 = tt.splat %fin : f32 -> tensor<1024xf32> - %1 = tt.splat %bin : bf16 -> tensor<128x256xbf16> - tt.store %save0, %0 : tensor<1024x!tt.ptr> - tt.store %save1, %1 : tensor<128x256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: bf16, %[[VAL_2:.*]]: memref<1024xf32>, %[[VAL_3:.*]]: memref<128x256xbf16>, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32, %[[VAL_6:.*]]: i32) { -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<1024xf32> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : f32) outs(%[[VAL_7]] : tensor<1024xf32>) -> tensor<1024xf32> -// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_1]] : bf16) outs(%[[VAL_9]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_8]] in writable %[[VAL_2]] -// CHECK: bufferization.materialize_in_destination %[[VAL_10]] in writable %[[VAL_3]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir b/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir deleted file mode 100644 index 33e5e67f..00000000 --- a/test/Conversion/TritonToLinalg/convert_tensor_reshape.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func public @bcast_kernel_01(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { - %c32_i32 = arith.constant 32 : i32 - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %c32_i32 : i32 - %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %3 = tt.splat %1 : i32 -> tensor<32xi32> - %4 = arith.addi %3, %2 : tensor<32xi32> - %5 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32> - %6 = tt.splat %1 : i32 -> tensor<2048xi32> - %7 = arith.addi %6, %5 : tensor<2048xi32> - %8 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> - %9 = tt.addptr %8, %4 : tensor<32x!tt.ptr>, tensor<32xi32> - %10 = tt.load %9 : tensor<32x!tt.ptr> - %11 = tt.reshape %10 allow_reorder : tensor<32xf32> -> tensor<1x32xf32> - %12 = tt.broadcast %11 : tensor<1x32xf32> -> tensor<64x32xf32> - %13 = tt.reshape %12 allow_reorder : tensor<64x32xf32> -> tensor<2048xf32> - %14 = tt.splat %arg1 : !tt.ptr -> tensor<2048x!tt.ptr> - %15 = tt.addptr %14, %7 : tensor<2048x!tt.ptr>, tensor<2048xi32> - tt.store %15, %13 : tensor<2048x!tt.ptr> - tt.return - } -} - - -// CHECK-LABEL: func.func @bcast_kernel_01( -// CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32 -// CHECK: %[[VAR_0:.*]] = arith.muli %arg5, %[[C32_I32]] : i32 -// CHECK: %[[VAR_1:.*]] = arith.index_cast %[[VAR_0]] : i32 to index -// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[VAR_1]]], sizes: [32], strides: [1] : memref<*xf32> to memref<32xf32, strided<[1], offset: ?>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32> -// CHECK: memref.copy %[[REINTERPRET_CAST:.*]], %[[ALLOC]] : memref<32xf32, strided<[1], offset: ?>> to memref<32xf32> -// CHECK: %[[VAR_2:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<32xf32> -// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[VAR_2]] {{.}}[0, 1]{{.}} output_shape [1, 32] : tensor<32xf32> into tensor<1x32xf32> -// CHECK: %[[VAR_3:.*]] = tensor.empty() : tensor<64x32xf32> -// CHECK: %[[VAR_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[EXPANDED]] : tensor<1x32xf32>) outs(%[[VAR_3:.*]] : tensor<64x32xf32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%in: f32, %out: f32): -// CHECK: linalg.yield %in : f32 -// CHECK: } -> tensor<64x32xf32> -// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAR_4]] {{.}}[0, 1]{{.}} : tensor<64x32xf32> into tensor<2048xf32> -// CHECK: %[[VAR_7:.*]] = arith.index_cast %[[VAR_0]] : i32 to index -// CHECK: %[[REINTERPRET_CAST_1:.*]] = memref.reinterpret_cast %arg1 to offset: [%[[VAR_7]]], sizes: [2048], strides: [1] : memref<*xf32> to memref<2048xf32, strided<[1], offset: ?>> -// CHECK: bufferization.materialize_in_destination %[[COLLAPSED]] in writable %[[REINTERPRET_CAST_1]] : (tensor<2048xf32>, memref<2048xf32, strided<[1], offset: ?>>) -> () -// CHECK: return diff --git a/test/Conversion/TritonToLinalg/cumsum.mlir b/test/Conversion/TritonToLinalg/cumsum.mlir deleted file mode 100644 index b579517a..00000000 --- a/test/Conversion/TritonToLinalg/cumsum.mlir +++ /dev/null @@ -1,68 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -// @triton.jit -// def test_cumsum_op( -// input_ptr, output_ptr, n_columns -// ): -// row = tl.program_id(axis=0) -// row_start = row * n_columns -// columns = tl.arange(0, 4096) -// offsets = row_start + columns -// data = tl.load(input_ptr + offsets) -// result = tl.cumsum(data, axis=0) -// tl.store(output_ptr + offsets, result) -// -// ret = triton.compiler.compile( -// test_cumsum_op, -// signature=" *fp32,*i32,i32", -// print_triton_ir_only=True, -// ) -// print(ret.asm["ttir"]) - -module { - tt.func public @test_cumsum_op_012(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 - %1 = arith.muli %0, %arg2 : i32 - %2 = tt.make_range {end = 4096 : i32, start = 0 : i32} : tensor<4096xi32> - %3 = tt.splat %1 : i32 -> tensor<4096xi32> - %4 = arith.addi %3, %2 : tensor<4096xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<4096x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %7 = tt.load %6 : tensor<4096x!tt.ptr> - %8 = "tt.scan"(%7) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg3: f32, %arg4: f32): - %12 = arith.addf %arg3, %arg4 : f32 - tt.scan.return %12 : f32 - }) : (tensor<4096xf32>) -> tensor<4096xf32> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4096x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4096x!tt.ptr>, tensor<4096xi32> - %11 = arith.fptosi %8 : tensor<4096xf32> to tensor<4096xi32> - tt.store %10, %11 : tensor<4096x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @test_cumsum_op_012 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_6_]], [[PARAM_2_]] : i32 -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_1_]]{{.}}, sizes: [4096], strides: [1] : memref<*xf32> to memref<4096xf32, strided<[1], offset: ?>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4096xf32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4096xf32, strided<[1], offset: ?>> to memref<4096xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4096xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tensor.empty() : tensor<4096xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = ttx.cumsum {axis = 0 : ui32, operandSegmentSizes = array} ins([[VAR_2_]] : tensor<4096xf32>) outs([[VAR_3_]] : tensor<4096xf32>) -> tensor<4096xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_0_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: [4096], strides: [1] : memref<*xi32> to memref<4096xi32, strided<[1], offset: ?>> -// CHECK-DAG: [[VAR_6_:%.+]] = tensor.empty() : tensor<4096xi32> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]] : tensor<4096xf32>) outs([[VAR_6_]] : tensor<4096xi32>) { -// CHECK: ^bb0([[in_:.+]]: f32, [[out_:.+]]: i32): -// CHECK: [[VAR_8_:%.+]] = arith.fptosi [[in_]] : f32 to i32 -// CHECK: linalg.yield [[VAR_8_]] : i32 -// CHECK: } -> tensor<4096xi32> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/dot.mlir b/test/Conversion/TritonToLinalg/dot.mlir deleted file mode 100644 index 95cb91b7..00000000 --- a/test/Conversion/TritonToLinalg/dot.mlir +++ /dev/null @@ -1,84 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %c64 = arith.constant 128 : i32 - %1 = tt.splat %c64 : i32 -> tensor<128xi32> - %2 = arith.muli %0, %1 : tensor<128xi32> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> - %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> - %8 = arith.addi %4, %7 : tensor<128x64xi32> - %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - %12 = tt.broadcast %11 : tensor<256x1xi32> -> tensor<256x64xi32> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %c256 = arith.constant 256 : i32 - %14 = tt.splat %c256 : i32 -> tensor<64xi32> - %15 = arith.muli %13, %14 : tensor<64xi32> - %16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %17 = tt.broadcast %16 : tensor<1x64xi32> -> tensor<256x64xi32> - %18 = arith.addi %12, %17 : tensor<256x64xi32> - %20 = tt.splat %c256 : i32 -> tensor<128xi32> - %21 = arith.muli %0, %20 : tensor<128xi32> - %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> - %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> - %26 = arith.addi %23, %25 : tensor<128x256xi32> - %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> - %40 = tt.splat %arg1 : !tt.ptr -> tensor<256x64x!tt.ptr> - %41 = tt.addptr %40, %18 : tensor<256x64x!tt.ptr>, tensor<256x64xi32> - %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x64x!tt.ptr> - %43 = tt.trans %42 {order = array} : tensor<256x64xbf16> -> tensor<64x256xbf16> - %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> - %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %52 = tt.load %51 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x256x!tt.ptr> - %60 = tt.dot %32, %43, %52 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> - tt.store %51, %60 : tensor<128x256x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [256, 64], strides: [1, [[CST_256_]]{{.}} : memref<*xbf16> to memref<256x64xbf16, strided<[1, ?]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<256x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<256x64xbf16, strided<[1, ?]>> to memref<256x64xbf16> -// CHECK-DAG: [[VAR_1_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<256x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<64x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_transposed_:%.+]] = linalg.transpose ins([[VAR_1_]] : tensor<256x64xbf16>) outs([[VAR_2_]] : tensor<64x256xbf16>) permutation = [1, 0] -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref<128x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_2_]], [[RES_2_]] : memref<128x256xbf16, strided<[?, 1]>> to memref<128x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_2_]] restrict writable : memref<128x256xbf16> -// CHECK: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_0_]], [[VAR_transposed_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_7_:%.+]] = linalg.generic {indexing_maps = [[[MAP_]], [[MAP_]], [[MAP_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_3_]], [[VAR_6_]] : tensor<128x256xbf16>, tensor<128x256xbf16>) outs([[VAR_3_]] : tensor<128x256xbf16>) { -// CHECK: ^bb0([[VAR_in_1:%.+]]: bf16, [[VAR_in_2:%.+]]: bf16, {{%.+}}: bf16): -// CHECK: [[VAR_8_:%.+]] = arith.addf [[VAR_in_1]], [[VAR_in_2]] : bf16 -// CHECK: linalg.yield [[VAR_8_:%.+]] : bf16 -// CHECK: } -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_7_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/get_num_programs.mlir b/test/Conversion/TritonToLinalg/get_num_programs.mlir deleted file mode 100644 index afda2996..00000000 --- a/test/Conversion/TritonToLinalg/get_num_programs.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @num_programs(%arg0: !tt.ptr) { - %0 = tt.get_num_programs x : i32 - %1 = tt.get_num_programs y : i32 - %2 = tt.get_num_programs z : i32 - %3 = tt.make_range {end = 1 : i32, start = 0 : i32} : tensor<1xi32> - %4 = tt.make_range {end = 2 : i32, start = 1 : i32} : tensor<1xi32> - %5 = tt.make_range {end = 3 : i32, start = 2 : i32} : tensor<1xi32> - %6 = tt.splat %arg0 : !tt.ptr -> tensor<1x!tt.ptr> - %7 = tt.addptr %6, %3 : tensor<1x!tt.ptr>, tensor<1xi32> - %8 = tt.splat %0 : i32 -> tensor<1xi32> - tt.store %7, %8 : tensor<1x!tt.ptr> - %9 = tt.addptr %6, %4 : tensor<1x!tt.ptr>, tensor<1xi32> - %10 = tt.splat %1 : i32 -> tensor<1xi32> - tt.store %9, %10 : tensor<1x!tt.ptr> - %11 = tt.addptr %6, %5 : tensor<1x!tt.ptr>, tensor<1xi32> - %12 = tt.splat %2 : i32 -> tensor<1xi32> - tt.store %11, %12 : tensor<1x!tt.ptr> - tt.return - } -} - -// CHECK: module { -// CHECK: func.func @num_programs(%arg0: memref<*xi32>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { -// CHECK: %c2 = arith.constant 2 : index -// CHECK: %c1 = arith.constant 1 : index -// CHECK: %c0 = arith.constant 0 : index -// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %0 = tensor.empty() : tensor<1xi32> -// CHECK: %1 = linalg.fill ins(%arg1 : i32) outs(%0 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %1 in writable %reinterpret_cast -// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%c1], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %2 = tensor.empty() : tensor<1xi32> -// CHECK: %3 = linalg.fill ins(%arg2 : i32) outs(%2 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %3 in writable %reinterpret_cast_0 -// CHECK: %reinterpret_cast_1 = memref.reinterpret_cast %arg0 to offset: [%c2], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1], offset: ?>> -// CHECK: %4 = tensor.empty() : tensor<1xi32> -// CHECK: %5 = linalg.fill ins(%arg3 : i32) outs(%4 : tensor<1xi32>) -> tensor<1xi32> -// CHECK: bufferization.materialize_in_destination %5 in writable %reinterpret_cast_1 -// CHECK: return -// CHECK: } -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir b/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir deleted file mode 100644 index 78afe418..00000000 --- a/test/Conversion/TritonToLinalg/reducemax_32_256_bf16.mlir +++ /dev/null @@ -1,58 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : tensor<256x16x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> - %ws = arith.muli %ct256, %0 : tensor<32xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> - %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> - %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> - %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> - %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> - %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> - %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> - %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> - %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<32x256x16x!tt.ptr> - %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, tensor<32x256x16xi32> - %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> - %6 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.cmpf ogt, %arg5, %arg6 : bf16 - %22 = arith.select %21, %arg5, %arg6 : bf16 - tt.reduce.return %22 : bf16 - }) {axis = 0 : i32} : (tensor<32x256x16xbf16>) -> tensor<256x16xbf16> - tt.store %res, %6 : tensor<256x16x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<256x16xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0xFF80 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_5]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_8]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.to_tensor %[[VAL_8]] restrict writable : memref<32x256x16xbf16> -// CHECK: %[[VAL_10:.*]] = tensor.empty() : tensor<256x16xbf16> -// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_10]] : tensor<256x16xbf16>) -> tensor<256x16xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.reduce ins(%[[VAL_9]] : tensor<32x256x16xbf16>) outs(%[[VAL_11]] : tensor<256x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_13:.*]]: bf16, %[[VAL_14:.*]]: bf16) { -// CHECK: %[[VAL_15:.*]] = arith.maximumf %[[VAL_13]], %[[VAL_14]] : bf16 -// CHECK: linalg.yield %[[VAL_15]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_12]] in writable %[[VAL_1]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir deleted file mode 100644 index 5726ea0c..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis0.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> - %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 0 : i32} : (tensor<512x256xbf16>) -> tensor<256xbf16> - tt.store %19, %5 : tensor<256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xbf16> to memref<256xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_11]] : tensor<256xbf16>) -> tensor<256xbf16> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_12]] : tensor<256xbf16>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: bf16, %[[VAL_15:.*]]: bf16) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : bf16 -// CHECK: linalg.yield %[[VAL_16]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir deleted file mode 100644 index 7f37a9f7..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_bf16_axis1.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 1 : i32} : (tensor<512x256xbf16>) -> tensor<512xbf16> - tt.store %19, %5 : tensor<512x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xbf16> to memref<512x256xbf16, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xbf16> to memref<512xbf16, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xbf16> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xbf16, strided<[?, 1]>> to memref<512x256xbf16> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xbf16> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xbf16> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xbf16>) outs(%[[VAL_11]] : tensor<256x512xbf16>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xbf16> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : bf16) outs(%[[VAL_13]] : tensor<512xbf16>) -> tensor<512xbf16> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xbf16>) outs(%[[VAL_14]] : tensor<512xbf16>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: bf16, %[[VAL_17:.*]]: bf16) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : bf16 -// CHECK: linalg.yield %[[VAL_18]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir deleted file mode 100644 index a63270ef..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis0.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<256x!tt.ptr> - %19 = tt.addptr %18, %3 : tensor<256x!tt.ptr>, tensor<256xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: f32, %arg6: f32): - %21 = arith.addf %arg5, %arg6 : f32 - tt.reduce.return %21 : f32 - }) {axis = 0 : i32} : (tensor<512x256xf32>) -> tensor<256xf32> - tt.store %19, %5 : tensor<256x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [256], strides: [1] : memref<*xf32> to memref<256xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256xf32> -// CHECK: %[[VAL_12:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_11]] : tensor<256xf32>) -> tensor<256xf32> -// CHECK: %[[VAL_13:.*]] = linalg.reduce ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_12]] : tensor<256xf32>) dimensions = [0] -// CHECK: (%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32) { -// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 -// CHECK: linalg.yield %[[VAL_16]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_13]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir b/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir deleted file mode 100644 index 175d33f6..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_512_256_f32_axis1.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr - ) -> () { - // offset calculations - %0 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<512xi32> - %ws = arith.muli %ct256, %0 : tensor<512xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<512xi32> -> tensor<512x1xi32> - %moff = tt.broadcast %1 : tensor<512x1xi32> -> tensor<512x256xi32> - %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %koff = tt.broadcast %4 : tensor<1x256xi32> -> tensor<512x256xi32> - %mkoff = arith.addi %moff, %koff : tensor<512x256xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<512x256x!tt.ptr> - %9 = tt.addptr %8, %mkoff : tensor<512x256x!tt.ptr>, tensor<512x256xi32> - // res pointer - %18 = tt.splat %res : !tt.ptr -> tensor<512x!tt.ptr> - %19 = tt.addptr %18, %0 : tensor<512x!tt.ptr>, tensor<512xi32> - %afm = tt.load %9 : tensor<512x256x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: f32, %arg6: f32): - %21 = arith.addf %arg5, %arg6 : f32 - tt.reduce.return %21 : f32 - }) {axis = 1 : i32} : (tensor<512x256xf32>) -> tensor<512xf32> - tt.store %19, %5 : tensor<512x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xf32>, %[[VAL_1:.*]]: memref<*xf32>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 256 : index -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_7:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [512, 256], strides: {{\[}}%[[VAL_5]], 1] : memref<*xf32> to memref<512x256xf32, strided<[?, 1]>> -// CHECK: %[[VAL_8:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [512], strides: [1] : memref<*xf32> to memref<512xf32, strided<[1]>> -// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<512x256xf32> -// CHECK: memref.copy %[[VAL_7]], %[[VAL_9]] : memref<512x256xf32, strided<[?, 1]>> to memref<512x256xf32> -// CHECK: %[[VAL_10:.*]] = bufferization.to_tensor %[[VAL_9]] restrict writable : memref<512x256xf32> -// CHECK: %[[VAL_11:.*]] = tensor.empty() : tensor<256x512xf32> -// CHECK: %[[VAL_12:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<512x256xf32>) outs(%[[VAL_11]] : tensor<256x512xf32>) permutation = [1, 0] -// CHECK: %[[VAL_13:.*]] = tensor.empty() : tensor<512xf32> -// CHECK: %[[VAL_14:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_13]] : tensor<512xf32>) -> tensor<512xf32> -// CHECK: %[[VAL_15:.*]] = linalg.reduce ins(%[[VAL_12]] : tensor<256x512xf32>) outs(%[[VAL_14]] : tensor<512xf32>) dimensions = [0] -// CHECK: (%[[VAL_16:.*]]: f32, %[[VAL_17:.*]]: f32) { -// CHECK: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_17]] : f32 -// CHECK: linalg.yield %[[VAL_18]] : f32 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_15]] in writable %[[VAL_8]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir b/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir deleted file mode 100644 index 33b9c7a1..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_middle_dim.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, - %res : !tt.ptr, - %out: tensor<32x16x!tt.ptr> - ) -> () { - // offset calculations - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %c256 = arith.constant 256 : i32 - %ct256 = tt.splat %c256 : i32 -> tensor<32xi32> - %ws = arith.muli %ct256, %0 : tensor<32xi32> - %1 = tt.expand_dims %ws {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> - %m2 = tt.broadcast %1 : tensor<32x1xi32> -> tensor<32x256xi32> - %100 = tt.expand_dims %m2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %moff = tt.broadcast %100 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %33 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %34 = tt.expand_dims %33 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %k2 = tt.broadcast %34 : tensor<1x256xi32> -> tensor<32x256xi32> - %200 = tt.expand_dims %k2 {axis = 2 : i32} : tensor<32x256xi32> -> tensor<32x256x1xi32> - %koff = tt.broadcast %200 : tensor<32x256x1xi32> -> tensor<32x256x16xi32> - %23 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %24 = tt.expand_dims %23 {axis = 0 : i32} : tensor<16xi32> -> tensor<1x16xi32> - %n2 = tt.broadcast %24 : tensor<1x16xi32> -> tensor<256x16xi32> - %300 = tt.expand_dims %n2 {axis = 0 : i32} : tensor<256x16xi32> -> tensor<1x256x16xi32> - %noff = tt.broadcast %300 : tensor<1x256x16xi32> -> tensor<32x256x16xi32> - %mkoff = arith.addi %moff, %koff : tensor<32x256x16xi32> - %mknoff = arith.addi %mkoff, %noff : tensor<32x256x16xi32> - // afloat pointer - %8 = tt.splat %afloat : !tt.ptr -> tensor<32x256x16x!tt.ptr> - %9 = tt.addptr %8, %mknoff : tensor<32x256x16x!tt.ptr>, tensor<32x256x16xi32> - %afm = tt.load %9 : tensor<32x256x16x!tt.ptr> - %5 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 1 : i32} : (tensor<32x256x16xbf16>) -> tensor<32x16xbf16> - tt.store %out, %5 : tensor<32x16x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[ARG0:.*]]: memref<*xbf16>, %[[ARG1:.*]]: memref<*xbf16>, %[[ARG2:.*]]: memref<32x16xbf16>, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: i32, %[[ARG7:.*]]: i32, %[[ARG8:.*]]: i32) { -// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 256 : index -// CHECK: %[[VAL_2:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [32, 256, 16], strides: {{\[}}%[[VAL_1]], 1, 1] : memref<*xbf16> to memref<32x256x16xbf16, strided<[?, 1, 1]>> -// CHECK: %[[VAL_3:.*]] = memref.alloc() : memref<32x256x16xbf16> -// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<32x256x16xbf16, strided<[?, 1, 1]>> to memref<32x256x16xbf16> -// CHECK: %[[VAL_4:.*]] = bufferization.to_tensor %[[VAL_3]] restrict writable : memref<32x256x16xbf16> to tensor<32x256x16xbf16> -// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<256x32x16xbf16> -// CHECK: %[[VAL_6:.*]] = linalg.transpose ins(%[[VAL_4]] : tensor<32x256x16xbf16>) outs(%[[VAL_5]] : tensor<256x32x16xbf16>) permutation = [1, 0, 2] -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x16xbf16> -// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_0]] : bf16) outs(%[[VAL_7]] : tensor<32x16xbf16>) -> tensor<32x16xbf16> -// CHECK: %[[VAL_9:.*]] = linalg.reduce ins(%[[VAL_6]] : tensor<256x32x16xbf16>) outs(%[[VAL_8]] : tensor<32x16xbf16>) dimensions = [0] -// CHECK: (%[[VAL_10:.*]]: bf16, %[[VAL_11:.*]]: bf16) { -// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : bf16 -// CHECK: linalg.yield %[[VAL_12]] : bf16 -// CHECK: } -// CHECK: bufferization.materialize_in_destination %[[VAL_9]] in writable %[[ARG2]] : (tensor<32x16xbf16>, memref<32x16xbf16>) -> () -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/reducesum_scalar.mlir b/test/Conversion/TritonToLinalg/reducesum_scalar.mlir deleted file mode 100644 index a5ca4d0a..00000000 --- a/test/Conversion/TritonToLinalg/reducesum_scalar.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel(%afloat : !tt.ptr, %res : !tt.ptr) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %1 = tt.splat %afloat : !tt.ptr -> tensor<128x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> - %afm = tt.load %2 : tensor<128x!tt.ptr> - %3 = "tt.reduce"(%afm) ({ - ^bb0(%arg5: bf16, %arg6: bf16): - %21 = arith.addf %arg5, %arg6 : bf16 - tt.reduce.return %21 : bf16 - }) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16 - tt.store %res, %3 : !tt.ptr - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_6:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [128], strides: [1] : memref<*xbf16> to memref<128xbf16, strided<[1]>> -// CHECK: %[[VAL_7:.*]] = memref.alloc() : memref<128xbf16> -// CHECK: memref.copy %[[VAL_6]], %[[VAL_7]] : memref<128xbf16, strided<[1]>> to memref<128xbf16> -// CHECK: %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_7]] restrict writable : memref<128xbf16> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() : tensor -// CHECK: %[[VAL_10:.*]] = tensor.insert %[[VAL_5]] into %[[VAL_9]][] : tensor -// CHECK: %[[VAL_11:.*]] = linalg.reduce ins(%[[VAL_8]] : tensor<128xbf16>) outs(%[[VAL_10]] : tensor) dimensions = [0] -// CHECK: (%[[VAL_12:.*]]: bf16, %[[VAL_13:.*]]: f32) { -// CHECK: %[[VAL_14:.*]] = arith.extf %[[VAL_12]] : bf16 to f32 -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_14]], %[[VAL_13]] : f32 -// CHECK: linalg.yield %[[VAL_15]] : f32 -// CHECK: } -// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_11]][] : tensor -// CHECK: %[[VAL_17:.*]] = arith.truncf %[[VAL_16]] : f32 to bf16 -// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [1], strides: [1] : memref<*xbf16> to memref<1xbf16, strided<[1]>> -// CHECK: affine.store %[[VAL_17]], %[[VAL_18]][0] : memref<1xbf16, strided<[1]>> -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/triton_assert.mlir b/test/Conversion/TritonToLinalg/triton_assert.mlir deleted file mode 100644 index d2ed1e8e..00000000 --- a/test/Conversion/TritonToLinalg/triton_assert.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> - -tt.func public @assert_tensor_1d() { - %0 = tensor.empty() : tensor<4xi1> - tt.assert %0, "message" : tensor<4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_1d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0 : tensor<4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert - -tt.func public @assert_tensor_2d() { - %0 = tensor.empty() : tensor<4x4xi1> - tt.assert %0, "message" : tensor<4x4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_2d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<4x4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert - -tt.func public @assert_tensor_3d() { - %0 = tensor.empty() : tensor<4x4x4xi1> - tt.assert %0, "message" : tensor<4x4x4xi1> - tt.return -} - -// CHECK-LABEL: func.func @assert_tensor_3d -// CHECK-NOT: tt.assert -// CHECK: linalg.generic {indexing_maps = [#map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<4x4x4xi1>) { -// CHECK: ^bb0(%in: i1): -// CHECK: cf.assert %in, "Assertion `message` failed" -// CHECK: linalg.yield -// CHECK: } -// CHECK-NOT: tt.assert diff --git a/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir b/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir deleted file mode 100644 index 11e588bc..00000000 --- a/test/Conversion/TritonToLinalg/unsupported_extern_elementwise.mlir +++ /dev/null @@ -1,35 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @rand(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.extern_elementwise %3, %0 {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> - %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func.func @rand -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK: [[VAR_0_:%.+]] = tensor.empty() : tensor<8xi32> -// CHECK: [[VAR_1_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_0_]] : tensor<8xi32>) { -// CHECK: ^bb0([[out:.+]]: i32): -// CHECK: [[VAR_4_:%.+]] = linalg.index 0 : index -// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i32 -// CHECK: linalg.yield [[VAR_5_]] : i32 -// CHECK: } -> tensor<8xi32> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<8xi32> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<8xi32, strided<[1]>> to memref<8xi32> -// CHECK: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<8xi32> -// CHECK-DAG: [[VAR_3_:%.+]] = tt.extern_elementwise [[VAR_2_]], [[VAR_1_]] {libname = "", libpath = "", pure = true, symbol = "some_symbol"} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [8], strides: [1] : memref<*xi32> to memref<8xi32, strided<[1]>> -// CHECK: bufferization.materialize_in_destination [[VAR_3_]] in writable [[VAR_reinterpret_cast_0_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_dot_opc.mlir b/test/Conversion/TritonToLinalg/use_dot_opc.mlir deleted file mode 100644 index df5f2140..00000000 --- a/test/Conversion/TritonToLinalg/use_dot_opc.mlir +++ /dev/null @@ -1,76 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %c64 = arith.constant 128 : i32 - %1 = tt.splat %c64 : i32 -> tensor<128xi32> - %2 = arith.muli %0, %1 : tensor<128xi32> - %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %4 = tt.broadcast %3 : tensor<128x1xi32> -> tensor<128x64xi32> - %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32> - %7 = tt.broadcast %6 : tensor<1x64xi32> -> tensor<128x64xi32> - %8 = arith.addi %4, %7 : tensor<128x64xi32> - %10 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> - %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %12 = tt.broadcast %11 : tensor<1x256xi32> -> tensor<64x256xi32> - %13 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %c256 = arith.constant 256 : i32 - %14 = tt.splat %c256 : i32 -> tensor<64xi32> - %15 = arith.muli %13, %14 : tensor<64xi32> - %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %17 = tt.broadcast %16 : tensor<64x1xi32> -> tensor<64x256xi32> - %18 = arith.addi %12, %17 : tensor<64x256xi32> - %20 = tt.splat %c256 : i32 -> tensor<128xi32> - %21 = arith.muli %0, %20 : tensor<128xi32> - %22 = tt.expand_dims %21 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32> - %23 = tt.broadcast %22 : tensor<128x1xi32> -> tensor<128x256xi32> - %24 = tt.expand_dims %10 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - %25 = tt.broadcast %24 {axis = 0 : i32} : tensor<1x256xi32> -> tensor<128x256xi32> - %26 = arith.addi %23, %25 : tensor<128x256xi32> - %30 = tt.splat %arg0 : !tt.ptr -> tensor<128x64x!tt.ptr> - %31 = tt.addptr %30, %8 : tensor<128x64x!tt.ptr>, tensor<128x64xi32> - %32 = tt.load %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<128x64x!tt.ptr> - %40 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr> - %41 = tt.addptr %40, %18 : tensor<64x256x!tt.ptr>, tensor<64x256xi32> - %42 = tt.load %41 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<64x256x!tt.ptr> - %50 = tt.splat %arg2 : !tt.ptr -> tensor<128x256x!tt.ptr> - %51 = tt.addptr %50, %26 : tensor<128x256x!tt.ptr>, tensor<128x256xi32> - %cf0 = arith.constant 0.0 : bf16 - %71 = tt.splat %cf0 : bf16 -> tensor<128x256xbf16> - %60 = tt.dot %32, %42, %71 {inputPrecision = 2 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xbf16> - tt.store %51, %60 : tensor<128x256x!tt.ptr> - tt.store %51, %71 : tensor<128x256x!tt.ptr> - tt.return - } -} - -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_0_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 64], strides: {{.}}[[CST_128_]], 1] : memref<*xbf16> to memref<128x64xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x64xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<128x64xbf16, strided<[?, 1]>> to memref<128x64xbf16> -// CHECK-DAG: [[VAR_2_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<128x64xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [64, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<64x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref<64x256xbf16> -// CHECK: memref.copy [[VAR_reinterpret_cast_0_]], [[RES_1_]] : memref<64x256xbf16, strided<[?, 1]>> to memref<64x256xbf16> -// CHECK-DAG: [[VAR_3_:%.+]] = bufferization.to_tensor [[RES_1_]] restrict writable : memref<64x256xbf16> -// CHECK-DAG: [[VAR_reinterpret_cast_2_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 256], strides: {{.}}[[CST_256_]], 1] : memref<*xbf16> to memref<128x256xbf16, strided<[?, 1]>> -// CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<128x256xbf16> -// CHECK: [[VAR_5_:%.+]] = linalg.fill ins([[CST_0_]] : bf16) outs([[VAR_4_:%.+]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: [[VAR_6_:%.+]] = linalg.matmul ins([[VAR_2_]], [[VAR_3_]] : tensor<128x64xbf16>, tensor<64x256xbf16>) outs([[VAR_5_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> -// CHECK: bufferization.materialize_in_destination [[VAR_6_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: bufferization.materialize_in_destination [[VAR_1_]] in writable [[VAR_reinterpret_cast_2_]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_end_chain.mlir b/test/Conversion/TritonToLinalg/use_end_chain.mlir deleted file mode 100644 index a66116d4..00000000 --- a/test/Conversion/TritonToLinalg/use_end_chain.mlir +++ /dev/null @@ -1,95 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - // mixed use - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - %20 = arith.sitofp %14 : tensor<256x128xi32> to tensor<256x128xbf16> - tt.store %18, %20 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32) { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : i32 -// CHECK-DAG: %[[CST_512_:.*]] = arith.constant 512 : i32 -// CHECK-DAG: %[[CST_1024_:.*]] = arith.constant 1024 : i32 -// CHECK: %[[VAL_30:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_31:.*]] = linalg.fill ins(%[[VAL_7]] : i32) outs(%[[VAL_30]] : tensor<256x128xi32>) -> tensor<256x128xi32> -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_12]], %[[CST_512_]] : i32 -// CHECK: linalg.yield %[[VAL_55]] : i32 -// CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = tensor.empty() : tensor<128xi32> -// CHECK: %[[VAL_20:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_19]] : tensor<128xi32>) { -// CHECK: ^bb0(%[[VAL_21:.*]]: i32): -// CHECK: %[[VAL_22:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : index to i32 -// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_23]], %[[CST_1024_]] : i32 -// CHECK: linalg.yield %[[VAL_56]] : i32 -// CHECK: } -> tensor<128xi32> -// CHECK: %[[VAL_24:.*]] = tensor.expand_shape %[[VAL_25:.*]] {{\[\[}}0, 1]] output_shape [1, 128] : tensor<128xi32> into tensor<1x128xi32> -// CHECK: %[[VAL_26:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_27:.*]] = linalg.generic {indexing_maps = [#map3, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_24]] : tensor<1x128xi32>) outs(%[[VAL_26]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_28:.*]]: i32, %[[VAL_29:.*]]: i32): -// CHECK: linalg.yield %[[VAL_28]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_32:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_33:.*]], %[[VAL_31]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_33]] : tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_34:.*]]: i32, %[[VAL_35:.*]]: i32, %[[VAL_36:.*]]: i32): -// CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_34]], %[[VAL_35]] : i32 -// CHECK: linalg.yield %[[VAL_37]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_38:.*]] = linalg.generic {indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_39:.*]], %[[VAL_40:.*]] : tensor<256x128xi32>, tensor<256x128xi32>) outs(%[[VAL_39]] : tensor<256x128xi32>) { -// CHECK: ^bb0(%[[VAL_41:.*]]: i32, %[[VAL_42:.*]]: i32, %[[VAL_43:.*]]: i32): -// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_41]], %[[VAL_42]] : i32 -// CHECK: linalg.yield %[[VAL_44]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_45:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_6]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_46:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_45]], %[[VAL_46]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_47:.*]] = bufferization.to_tensor %[[VAL_46]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_47]] in writable %[[VAL_45]] -// CHECK: %[[VAL_48:.*]] = tensor.empty() : tensor<256x128xbf16> -// CHECK: %[[VAL_49:.*]] = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_50:.*]] : tensor<256x128xi32>) outs(%[[VAL_48]] : tensor<256x128xbf16>) { -// CHECK: ^bb0(%[[VAL_51:.*]]: i32, %[[VAL_52:.*]]: bf16): -// CHECK: %[[VAL_53:.*]] = arith.sitofp %[[VAL_51]] : i32 to bf16 -// CHECK: linalg.yield %[[VAL_53]] : bf16 -// CHECK: } -> tensor<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_54:.*]] in writable %[[VAL_45]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/use_mid_chain.mlir b/test/Conversion/TritonToLinalg/use_mid_chain.mlir deleted file mode 100644 index f4a855aa..00000000 --- a/test/Conversion/TritonToLinalg/use_mid_chain.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -module { - tt.func @kernel( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr - ) - { - %0 = tt.make_range {end = 768 : i32, start = 512 : i32}:tensor<256xi32> - // offset = [512] size = 256, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32> - // offset = [512,0], size = [256,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<256x1xi32> -> tensor<256x128xi32> - // offset = [512,0], size = [256,128], stride = [1,0] - // mixed use - %5 = tt.make_range {end = 1152 : i32, start = 1024 : i32}:tensor<128xi32> - // offset = 1024, size = 128, stride = 1 - %6 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - // offset = [0,1024], size = [1,128], stride = [0,1] - %7 = tt.broadcast %6 : tensor<1x128xi32> -> tensor<256x128xi32> - // offset = [0,1024], size = [256,128], stride = [0,1] - %c6 = arith.constant 6 : i32 - %splat6 = tt.splat %c6 : i32 -> tensor<256x128xi32> - %scale7 = arith.muli %7, %splat6 : tensor<256x128xi32> - // offset = [0,6144], size = [256,128], stride = [0,6] - %14 = arith.addi %2, %scale7 : tensor<256x128xi32> - // offset = [512,6144], size = [256,128], stride = [1,6] - %17 = tt.splat %arg1 : !tt.ptr -> tensor<256x128x!tt.ptr> - %18 = tt.addptr %17, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - %19 = tt.load %18 : tensor<256x128x!tt.ptr> - tt.store %18, %19 : tensor<256x128x!tt.ptr> - %20 = tt.splat %arg2 : !tt.ptr -> tensor<256x128x!tt.ptr> - %21 = tt.addptr %20, %14 : tensor<256x128x!tt.ptr>, tensor<256x128xi32> - tt.store %21, %2 : tensor<256x128x!tt.ptr> - tt.return - } -} -// CHECK-LABEL: func.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: memref<*xbf16>, %[[VAL_1:.*]]: memref<*xbf16>, %[[VAL_2:.*]]: memref<*xi32>, %[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i32) { -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 6 : index -// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 512 : i32 -// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<256xi32> -// CHECK: %[[VAL_9:.*]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%[[VAL_8]] : tensor<256xi32>) { -// CHECK: ^bb0(%[[VAL_10:.*]]: i32): -// CHECK: %[[VAL_11:.*]] = linalg.index 0 : index -// CHECK: %[[VAL_12:.*]] = arith.index_cast %[[VAL_11]] : index to i32 -// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_12]], %[[VAL_25]] : i32 -// CHECK: linalg.yield %[[VAL_24]] : i32 -// CHECK: } -> tensor<256xi32> -// CHECK: %[[VAL_13:.*]] = tensor.expand_shape %[[VAL_14:.*]] {{\[\[}}0, 1]] output_shape [256, 1] : tensor<256xi32> into tensor<256x1xi32> -// CHECK: %[[VAL_15:.*]] = tensor.empty() : tensor<256x128xi32> -// CHECK: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_13]] : tensor<256x1xi32>) outs(%[[VAL_15]] : tensor<256x128xi32>) attrs = {broadcastDims = array} { -// CHECK: ^bb0(%[[VAL_17:.*]]: i32, %[[VAL_18:.*]]: i32): -// CHECK: linalg.yield %[[VAL_17]] : i32 -// CHECK: } -> tensor<256x128xi32> -// CHECK: %[[VAL_19:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xbf16> to memref<256x128xbf16, strided<[1, ?], offset: 6656>> -// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<256x128xbf16> -// CHECK: memref.copy %[[VAL_19]], %[[VAL_20]] : memref<256x128xbf16, strided<[1, ?], offset: 6656>> to memref<256x128xbf16> -// CHECK: %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_20]] restrict writable : memref<256x128xbf16> -// CHECK: bufferization.materialize_in_destination %[[VAL_21]] in writable %[[VAL_19]] -// CHECK: %[[VAL_22:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}6656], sizes: [256, 128], strides: [1, %[[VAL_7]]] : memref<*xi32> to memref<256x128xi32, strided<[1, ?], offset: 6656>> -// CHECK: bufferization.materialize_in_destination %[[VAL_23:.*]] in writable %[[VAL_22]] -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir b/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir deleted file mode 100644 index cb27d947..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_side_by_side.mlir +++ /dev/null @@ -1,133 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<2> : tensor<4x1xi32> - %cst_1 = arith.constant dense<6> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = arith.addi %0, %cst_1 : tensor<4xi32> - %3 = tt.splat %arg3 : i32 -> tensor<4xi32> - %4 = arith.remsi %2, %3 : tensor<4xi32> - %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> - %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg4, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31 = arith.muli %arg5, %c4_i32 : i32 - %32 = tt.splat %31 : i32 -> tensor<4x4xi32> - %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %34 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %34 : tensor<4x4x!tt.ptr> - %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} - -// CHECK-LABEL: func.func @wrap_side_by_side_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : index -// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_6_]] : index -// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_1_]], [[VAR_4_]] : index -// CHECK: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_2_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.subi [[VAR_5_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.addi [[VAR_6_]], [[CST_4_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.minsi [[VAR_8_]], [[VAR_2_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.subi [[VAR_9_]], [[VAR_6_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_10_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_7_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_11_]]{{.}}, strides: {{.}}[[VAR_0_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_1_]] : index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_19_]], [[CST_6_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_21_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>>) : i32 { -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_1_]] : memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK: [[VAR_22_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_23_:%.+]] = arith.subi [[CST_4_]], [[VAR_22_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] [2, [[VAR_22_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1]>> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]][0, [[VAR_22_]]{{.}} [2, [[VAR_23_]]{{.}} [1, 1] : memref<4x4xf32> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1]>> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref<2x?xf32, strided<[?, ?], offset: ?>> to memref<2x?xf32, strided<[4, 1], offset: ?>> -// CHECK: [[VAR_24_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_24_]] in writable [[VAR_arg16_]] -// CHECK: [[VAR_25_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_25_]] : index -// CHECK: [[VAR_27_:%.+]] = arith.addi [[VAR_26_]], [[VAR_20_]] : index -// CHECK: [[VAR_28_:%.+]] = arith.remsi [[VAR_27_]], [[VAR_18_]] : index -// CHECK-DAG: [[VAR_29_:%.+]] = arith.subi [[VAR_27_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_30_:%.+]] = arith.addi [[VAR_28_]], [[CST_4_]] : index -// CHECK: [[VAR_31_:%.+]] = arith.minsi [[VAR_30_]], [[VAR_18_]] : index -// CHECK: [[VAR_32_:%.+]] = arith.subi [[VAR_31_]], [[VAR_28_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_32_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_33_:%.+]] = arith.subi [[CST_4_]], [[VAR_32_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_29_]]{{.}}, sizes: {{.}}[[CST_4_]], [[VAR_33_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_19_]]{{.}} : memref<*xf32> to memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK-DAG: [[VAR_34_:%.+]] = arith.index_cast [[VAR_15_]] : i32 to index -// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_34_]] : index -// CHECK: [[VAR_36_:%.+]] = arith.addi [[VAR_35_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_36_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_26_]], [[VAR_36_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref<4x?xf32, strided<[?, ?], offset: ?>>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref<4x?xf32, strided<[?, ?], offset: ?>> -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir b/test/Conversion/TritonToLinalg/wraparound_stacked.mlir deleted file mode 100644 index f0e86002..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_stacked.mlir +++ /dev/null @@ -1,129 +0,0 @@ -// XFAIL: * -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s - -module { - tt.func public @wrap_stacked_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<3> : tensor<1x4xi32> - %cst_1 = arith.constant dense<3> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4xi32> - %3 = arith.remsi %1, %2 : tensor<4xi32> - %4 = arith.addi %0, %cst_1 : tensor<4xi32> - %5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32> - %28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg5, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %32 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %32 : tensor<4x4x!tt.ptr> - %33 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %34 = tt.addptr %arg10, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %33, %34 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} - -// CHECK-LABEL: func.func @wrap_stacked_masked_loop_01234567 -// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// CHECK-DAG: [[CST_minus_9_dot_900000_:%.+]] = arith.constant -9.900000e+01 : f32 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 -// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_1_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_4_1_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_1_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_4_:%.+]] = arith.muli [[VAR_3_]], [[CST_3_]] : index -// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_2_]], [[VAR_4_]] : index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.remsi [[VAR_5_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_0_]], [[VAR_1_]] : index -// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_7_]], [[VAR_6_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.subi [[VAR_8_]], [[VAR_5_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.divsi [[VAR_9_]], [[VAR_1_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: {{.}}[[VAR_10_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_11_:%.+]] = arith.subi [[CST_4_]], [[VAR_10_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_6_]]{{.}}, sizes: {{.}}[[VAR_11_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_1_]], [[VAR_3_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_1_]] : i32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.muli [[VAR_16_]], [[CST_2_]] : index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[VAR_18_]], [[CST_3_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]]:6 = scf.for [[VAR_arg14_:%.+]] = [[CST_0_1_]] to [[CST_2_1_]] step [[CST_1_]] iter_args([[VAR_arg15_:%.+]] = [[VAR_reinterpret_cast_]], [[VAR_arg16_:%.+]] = [[VAR_reinterpret_cast_]]_1, [[VAR_arg17_:%.+]] = [[VAR_17_]], [[VAR_arg18_:%.+]] = [[CST_0_]], [[VAR_arg19_:%.+]] = [[CST_0_]], [[VAR_arg20_:%.+]] = [[VAR_reinterpret_cast_]]_0) -> (memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref>) : i32 { -// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4x4xf32> -// CHECK: linalg.fill ins([[CST_minus_9_dot_900000_]] : f32) outs([[RES_]] : memref<4x4xf32>) -// CHECK: [[VAR_dim_:%.+]] = memref.dim [[VAR_arg15_]], [[CST_0_]] : memref> -// CHECK: [[VAR_21_:%.+]] = arith.minsi [[VAR_dim_]], [[CST_4_]] : index -// CHECK-DAG: [[VAR_22_:%.+]] = arith.subi [[CST_4_]], [[VAR_21_]] : index -// CHECK-DAG: [[VAR_subview_:%.+]] = memref.subview [[VAR_arg15_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref> to memref> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_subview_2_:%.+]] = memref.subview [[VAR_arg20_]][0, 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref> to memref> -// CHECK-DAG: [[VAR_subview_3_:%.+]] = memref.subview [[RES_]][0, 0] {{.}}[[VAR_21_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK-DAG: [[VAR_subview_4_:%.+]] = memref.subview [[RES_]]{{.}}[[VAR_21_]], 0] {{.}}[[VAR_22_]], 3] [1, 1] : memref<4x4xf32> to memref> -// CHECK: memref.copy [[VAR_subview_]], [[VAR_subview_]]_3 : memref> to memref> -// CHECK: memref.copy [[VAR_subview_2_]], [[VAR_subview_4_]] : memref> to memref> -// CHECK: [[VAR_23_:%.+]] = bufferization.to_tensor [[RES_]] restrict writable : memref<4x4xf32> -// CHECK: bufferization.materialize_in_destination [[VAR_23_]] in writable [[VAR_arg16_]] -// CHECK: [[VAR_24_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_25_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_24_]] : index -// CHECK: [[VAR_26_:%.+]] = arith.addi [[VAR_25_]], [[VAR_19_]] : index -// CHECK-DAG: [[VAR_27_:%.+]] = arith.remsi [[VAR_26_]], [[VAR_16_]] : index -// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[VAR_15_]], [[VAR_16_]] : index -// CHECK: [[VAR_29_:%.+]] = arith.addi [[VAR_28_]], [[VAR_27_]] : index -// CHECK: [[VAR_30_:%.+]] = arith.subi [[VAR_29_]], [[VAR_26_]] : index -// CHECK: [[VAR_31_:%.+]] = arith.divsi [[VAR_30_]], [[VAR_16_]] : index -// CHECK-DAG: [[VAR_reinterpret_cast_5_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_26_]]{{.}}, sizes: {{.}}[[VAR_31_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_32_:%.+]] = arith.subi [[CST_4_]], [[VAR_31_]] : index -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_reinterpret_cast_6_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_27_]]{{.}}, sizes: {{.}}[[VAR_32_]], [[CST_4_]]{{.}}, strides: {{.}}[[VAR_16_]], [[VAR_18_]]{{.}} : memref<*xf32> to memref> -// CHECK-DAG: [[VAR_33_:%.+]] = arith.index_cast [[VAR_14_]] : i32 to index -// CHECK: [[VAR_34_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_33_]] : index -// CHECK: [[VAR_35_:%.+]] = arith.addi [[VAR_34_]], [[VAR_arg19_]] : index -// CHECK: [[VAR_reinterpret_cast_7_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_35_]]{{.}}, sizes: [4, 4], strides: {{.}}[[VAR_12_]], [[VAR_13_]]{{.}} : memref<*xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> -// CHECK: scf.yield [[VAR_reinterpret_cast_5_]], [[VAR_reinterpret_cast_7_]], [[VAR_25_]], [[VAR_35_]], [[CST_0_]], [[VAR_reinterpret_cast_6_]] : memref>, memref<4x4xf32, strided<[?, ?], offset: ?>>, index, index, index, memref> -// CHECK: } -// CHECK: return -// CHECK: } diff --git a/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir b/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir deleted file mode 100644 index 9455f1e8..00000000 --- a/test/Conversion/TritonToLinalg/wraparound_unsupported_add_offset.mlir +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s -// XFAIL: * -// We currently do not support this kind of modulo pattern: -// (a + arrange(0, K)) % M -module { - tt.func public @wrap_side_by_side_masked_loop_01234567(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %cst = arith.constant dense<-9.900000e+01> : tensor<4x4xf32> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c2_i32 = arith.constant 2 : i32 - %cst_0 = arith.constant dense<2> : tensor<4x1xi32> - %cst_1 = arith.constant dense<6> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %c4_i32 = arith.constant 4 : i32 - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = tt.splat %arg3 : i32 -> tensor<4xi32> - %3 = arith.remsi %0, %2 : tensor<4xi32> - %4 = arith.addi %3, %cst_1 : tensor<4xi32> - %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg4 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg5 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg6 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg7 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32> - %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg4, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31 = arith.muli %arg5, %c4_i32 : i32 - %32 = tt.splat %31 : i32 -> tensor<4x4xi32> - %33:2 = scf.for %arg8 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg9 = %15, %arg10 = %26) -> (tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr>) : i32 { - %34 = tt.load %arg9, %28, %cst : tensor<4x4x!tt.ptr> - tt.store %arg10, %34 : tensor<4x4x!tt.ptr> - %35 = tt.addptr %arg9, %30 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %36 = tt.addptr %arg10, %32 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - scf.yield %35, %36 : tensor<4x4x!tt.ptr>, tensor<4x4x!tt.ptr> - } - tt.return - } -} diff --git a/tools/triton-shared-opt/RegisterTritonSharedDialects.h b/tools/triton-shared-opt/RegisterTritonSharedDialects.h index 6d92953f..d78f96be 100644 --- a/tools/triton-shared-opt/RegisterTritonSharedDialects.h +++ b/tools/triton-shared-opt/RegisterTritonSharedDialects.h @@ -13,7 +13,6 @@ #include "triton-shared/Conversion/StructuredToMemref/Passes.h" #include "triton-shared/Conversion/TritonArithToLinalg/Passes.h" #include "triton-shared/Conversion/TritonPtrToMemref/Passes.h" -#include "triton-shared/Conversion/TritonToLinalg/Passes.h" #include "triton-shared/Conversion/TritonToLinalgExperimental/Passes.h" #include "triton-shared/Conversion/TritonToStructured/Passes.h" #include "triton-shared/Conversion/TritonToUnstructured/Passes.h" @@ -29,7 +28,6 @@ inline void registerTritonSharedDialects(mlir::DialectRegistry ®istry) { mlir::registerAllPasses(); mlir::registerLinalgPasses(); mlir::triton::registerTritonPasses(); - mlir::triton::registerTritonToLinalgPass(); mlir::triton::registerTritonToLinalgExperimentalPasses(); mlir::triton::registerTritonToStructuredPass(); mlir::triton::registerTritonPtrToMemref();