diff --git a/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td b/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td index 8333aa38..f02806e0 100644 --- a/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td +++ b/include/triton-shared/Dialect/TPtr/IR/TPtrDialect.td @@ -20,36 +20,9 @@ def TPtr_Dialect : Dialect { Typed Pointer Dialect. }]; - let extraClassDeclaration = [{ - void registerTypes(); - }]; - let dependentDialects = [ "mlir::ptr::PtrDialect" ]; - - let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 1; -} - -class TPtrTypeDef traits = []> - : TypeDef { - // Used by printer/parser - let mnemonic = _mnemonic; -} - -class TPtr_Attr traits = []> - : AttrDef { - let mnemonic = _mnemonic; -} - -// Memory space attr is required for building Ptr ops -// This acts as default memory space since there is -// no such default implemented upstream -def DefaultMemorySpaceAttr - : TPtr_Attr<"DefaultMemorySpace", "default_memory_space", - [DeclareAttrInterfaceMethods]> { } // diff --git a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp index e8a3b7f6..1c49a69f 100644 --- a/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" @@ -95,9 +96,8 @@ struct FromMemrefConverter input); auto memrefToPtr = rewriter.create( op->getLoc(), - ptr::PtrType::get( - rewriter.getContext(), - tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())), + ptr::PtrType::get(rewriter.getContext(), + ptr::GenericSpaceAttr::get(rewriter.getContext())), rankedMemref); rewriter.replaceAllUsesWith(output, memrefToPtr); diff --git a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp index 8234360d..1cf6894d 100644 --- a/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp +++ b/lib/Conversion/TritonToLinalgExperimental/TritonToPtrPass.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrAttrs.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -98,9 +99,8 @@ struct EmptyTensorConverter : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op.getType().getShape(), - ptr::PtrType::get( - rewriter.getContext(), - tptr::DefaultMemorySpaceAttr::get(rewriter.getContext()))); + ptr::PtrType::get(rewriter.getContext(), + ptr::GenericSpaceAttr::get(rewriter.getContext()))); return success(); } }; @@ -187,9 +187,8 @@ struct AddPtrConverter : public OpConversionPattern { rewriter.create(loc, op.getOffset(), pointeeSizeInBytes); rewriter.replaceOpWithNewOp( op, - ptr::PtrType::get( - rewriter.getContext(), - tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())), + ptr::PtrType::get(rewriter.getContext(), + ptr::GenericSpaceAttr::get(rewriter.getContext())), adaptor.getPtr(), scaledOffset); return success(); } @@ -325,9 +324,8 @@ struct IntToPtrConverter : public OpConversionPattern { } rewriter.replaceOpWithNewOp( op, - ptr::PtrType::get( - rewriter.getContext(), - tptr::DefaultMemorySpaceAttr::get(rewriter.getContext())), + ptr::PtrType::get(rewriter.getContext(), + ptr::GenericSpaceAttr::get(rewriter.getContext())), adaptor.getSrc()); return success(); } @@ -417,15 +415,13 @@ class TritonPtrTypeConverter : public TypeConverter { TritonPtrTypeConverter(MLIRContext *context) { addConversion([](Type type) { return type; }); addConversion([context](triton::PointerType ptrType) { - return ptr::PtrType::get(context, - tptr::DefaultMemorySpaceAttr::get(context)); + return ptr::PtrType::get(context, ptr::GenericSpaceAttr::get(context)); }); addConversion([context](RankedTensorType tensorType) { if (isa(tensorType.getElementType())) { return RankedTensorType::get( tensorType.getShape(), - ptr::PtrType::get(context, - tptr::DefaultMemorySpaceAttr::get(context))); + ptr::PtrType::get(context, ptr::GenericSpaceAttr::get(context))); } return tensorType; }); diff --git a/lib/Dialect/TPtr/IR/TPtrDialect.cpp b/lib/Dialect/TPtr/IR/TPtrDialect.cpp index 3235a1e0..5d49c452 100644 --- a/lib/Dialect/TPtr/IR/TPtrDialect.cpp +++ b/lib/Dialect/TPtr/IR/TPtrDialect.cpp @@ -29,64 +29,16 @@ void printIntType(OpAsmPrinter &p, Operation *op, Type ty) { //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// -void mlir::tptr::TPtrDialect::registerTypes() { - addTypes< -#define GET_TYPEDEF_LIST -#include "triton-shared/Dialect/TPtr/IR/TPtrTypes.cpp.inc" - >(); -} /// Dialect creation, the instance will be owned by the context. This is the /// point of registration of custom types and operations for the dialect. void mlir::tptr::TPtrDialect::initialize() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "triton-shared/Dialect/TPtr/IR/TPtrAttributes.cpp.inc" - >(); - registerTypes(); addOperations< #define GET_OP_LIST #include "triton-shared/Dialect/TPtr/IR/TPtrOps.cpp.inc" >(); } -bool tptr::DefaultMemorySpaceAttr::isValidLoad( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, - llvm::function_ref emitError) const { - return true; -} - -bool tptr::DefaultMemorySpaceAttr::isValidStore( - Type type, mlir::ptr::AtomicOrdering ordering, IntegerAttr alignment, - llvm::function_ref emitError) const { - return true; -} - -bool tptr::DefaultMemorySpaceAttr::isValidAtomicOp( - mlir::ptr::AtomicBinOp binOp, Type type, mlir::ptr::AtomicOrdering ordering, - IntegerAttr alignment, - llvm::function_ref emitError) const { - return true; -} - -bool tptr::DefaultMemorySpaceAttr::isValidAtomicXchg( - Type type, mlir::ptr::AtomicOrdering successOrdering, - mlir::ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, - llvm::function_ref emitError) const { - return true; -} - -bool tptr::DefaultMemorySpaceAttr::isValidAddrSpaceCast( - Type tgt, Type src, - llvm::function_ref emitError) const { - return true; -} - -bool tptr::DefaultMemorySpaceAttr::isValidPtrIntCast( - Type intLikeTy, Type ptrLikeTy, - llvm::function_ref emitError) const { - return true; -} //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir b/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir index decf2293..15a8652f 100644 --- a/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir +++ b/test/Conversion/ReconcilePtrCasts/ptr_into_memref.mlir @@ -6,19 +6,19 @@ module { %c1_i32 = arith.constant 1 : i32 %c2 = arith.constant 2 : index %1 = builtin.unrealized_conversion_cast %arg1 : memref<*xi32> to !tt.ptr - %2 = builtin.unrealized_conversion_cast %1 : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> + %2 = builtin.unrealized_conversion_cast %1 : !tt.ptr to !ptr.ptr<#ptr.generic_space> %3 = builtin.unrealized_conversion_cast %arg0 : memref<*xi32> to !tt.ptr - %4 = builtin.unrealized_conversion_cast %3 : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> + %4 = builtin.unrealized_conversion_cast %3 : !tt.ptr to !ptr.ptr<#ptr.generic_space> %5 = arith.muli %c1_i32, %0 : i32 - %6 = tptr.ptradd %4 %5 : !ptr.ptr<#tptr.default_memory_space>, i32 to !ptr.ptr<#tptr.default_memory_space> - %7 = builtin.unrealized_conversion_cast %6 : !ptr.ptr<#tptr.default_memory_space> to !tt.ptr + %6 = tptr.ptradd %4 %5 : !ptr.ptr<#ptr.generic_space>, i32 to !ptr.ptr<#ptr.generic_space> + %7 = builtin.unrealized_conversion_cast %6 : !ptr.ptr<#ptr.generic_space> to !tt.ptr %8 = builtin.unrealized_conversion_cast %7 : !tt.ptr to memref<*xi64> %reinterpret_cast = memref.reinterpret_cast %8 to offset: [%c2], sizes: [16], strides: [1] : memref<*xi64> to memref<16xi64, strided<[1], offset: ?>> %alloc = memref.alloc() : memref<16xi64> memref.copy %reinterpret_cast, %alloc : memref<16xi64, strided<[1], offset: ?>> to memref<16xi64> %9 = bufferization.to_tensor %alloc restrict writable : memref<16xi64> to tensor<16xi64> - %10 = tptr.ptradd %2 %5 : !ptr.ptr<#tptr.default_memory_space>, i32 to !ptr.ptr<#tptr.default_memory_space> - %11 = builtin.unrealized_conversion_cast %10 : !ptr.ptr<#tptr.default_memory_space> to !tt.ptr + %10 = tptr.ptradd %2 %5 : !ptr.ptr<#ptr.generic_space>, i32 to !ptr.ptr<#ptr.generic_space> + %11 = builtin.unrealized_conversion_cast %10 : !ptr.ptr<#ptr.generic_space> to !tt.ptr %12 = builtin.unrealized_conversion_cast %11 : !tt.ptr to memref<*xi64> %reinterpret_cast_0 = memref.reinterpret_cast %12 to offset: [%c2], sizes: [16], strides: [1] : memref<*xi64> to memref<16xi64, strided<[1], offset: ?>> bufferization.materialize_in_destination %9 in writable %reinterpret_cast_0 : (tensor<16xi64>, memref<16xi64, strided<[1], offset: ?>>) -> () diff --git a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir index d8c015c7..70cc270e 100644 --- a/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir +++ b/test/Conversion/TritonToLinalgExperimental/conditional_ptr_as_src.mlir @@ -34,20 +34,20 @@ module { // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 // CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_0_]] : memref<*xf32> to memref<1xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_1_:%.+]] = tptr.from_memref [[VAR_cast_]] : memref<1xf32> to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_1_:%.+]] = tptr.from_memref [[VAR_cast_]] : memref<1xf32> to <#ptr.generic_space> // CHECK-DAG: [[VAR_2_:%.+]] = arith.cmpi eq, [[PARAM_2_]], [[CST_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = scf.if [[VAR_2_]] -> (!ptr.ptr<#tptr.default_memory_space>) { +// CHECK-DAG: [[VAR_3_:%.+]] = scf.if [[VAR_2_]] -> (!ptr.ptr<#ptr.generic_space>) { // CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2_]] : i32 // CHECK: [[VAR_7_:%.+]] = arith.muli [[VAR_6_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_8_:%.+]] = tptr.ptradd [[VAR_1_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: scf.yield [[VAR_8_]] : !ptr.ptr<#tptr.default_memory_space> +// CHECK: [[VAR_8_:%.+]] = tptr.ptradd [[VAR_1_]] [[VAR_7_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: scf.yield [[VAR_8_]] : !ptr.ptr<#ptr.generic_space> // CHECK: } else { // CHECK: [[VAR_6_1_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_7_1_:%.+]] = tptr.ptradd [[VAR_1_]] [[VAR_6_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: scf.yield [[VAR_7_1_]] : !ptr.ptr<#tptr.default_memory_space> +// CHECK: [[VAR_7_1_:%.+]] = tptr.ptradd [[VAR_1_]] [[VAR_6_1_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: scf.yield [[VAR_7_1_]] : !ptr.ptr<#ptr.generic_space> // CHECK: } -// CHECK: [[VAR_4_:%.+]] = tptr.to_memref [[VAR_3_]] : <#tptr.default_memory_space> to memref<1xf32> +// CHECK: [[VAR_4_:%.+]] = tptr.to_memref [[VAR_3_]] : <#ptr.generic_space> to memref<1xf32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[VAR_4_]] to offset: {{.}}[[CST_6_]]{{.}}, sizes: [4], strides: [1] : memref<1xf32> to memref<4xf32, strided<[1], offset: ?>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<4xf32> // CHECK: memref.copy [[VAR_reinterpret_cast_]], [[RES_]] : memref<4xf32, strided<[1], offset: ?>> to memref<4xf32> diff --git a/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir b/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir index a64c185c..143cd618 100644 --- a/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir +++ b/test/Conversion/TritonToPtr/cast_with_int_ptr.mlir @@ -55,49 +55,49 @@ module { // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32 // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> -// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> +// CHECK-DAG: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> +// CHECK-DAG: [[VAR_6_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> // CHECK: [[VAR_7_:%.+]] = arith.muli [[CST_111_]], [[VAR_4_]] : i32 -// CHECK-DAG: [[VAR_8_:%.+]] = tptr.ptradd [[VAR_6_]] [[VAR_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_8_:%.+]] = tptr.ptradd [[VAR_6_]] [[VAR_7_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_9_:%.+]] = arith.muli [[CST_10_]], [[VAR_3_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = tptr.ptradd [[VAR_8_]] [[VAR_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK-DAG: [[VAR_11_:%.+]] = tptr.ptrtoint [[VAR_5_]] : <#tptr.default_memory_space> to i64 +// CHECK-DAG: [[VAR_10_:%.+]] = tptr.ptradd [[VAR_8_]] [[VAR_9_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK-DAG: [[VAR_11_:%.+]] = tptr.ptrtoint [[VAR_5_]] : <#ptr.generic_space> to i64 // CHECK: [[VAR_12_:%.+]] = arith.muli [[VAR_11_]], [[VAR_2_]] : i64 -// CHECK-DAG: [[VAR_13_:%.+]] = tptr.ptradd [[VAR_5_]] [[VAR_12_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_13_:%.+]] = tptr.ptradd [[VAR_5_]] [[VAR_12_]] : <#ptr.generic_space>, i64 to <#ptr.generic_space> // CHECK-DAG: [[VAR_14_:%.+]] = arith.muli [[CST_9_]], [[VAR_4_]] : i32 -// CHECK: [[VAR_15_:%.+]] = tptr.ptradd [[VAR_13_]] [[VAR_14_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: [[VAR_16_:%.+]] = tptr.ptrtoint [[VAR_15_]] : <#tptr.default_memory_space> to i64 +// CHECK: [[VAR_15_:%.+]] = tptr.ptradd [[VAR_13_]] [[VAR_14_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: [[VAR_16_:%.+]] = tptr.ptrtoint [[VAR_15_]] : <#ptr.generic_space> to i64 // CHECK-DAG: [[VAR_17_:%.+]] = arith.remsi [[VAR_16_]], [[CST_10_1_]] : i64 // CHECK-DAG: [[VAR_18_:%.+]] = arith.muli [[CST_1_]], [[VAR_4_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = tptr.ptradd [[VAR_10_]] [[VAR_18_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_19_:%.+]] = tptr.ptradd [[VAR_10_]] [[VAR_18_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[VAR_17_]], [[VAR_2_]] : i64 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_21_:%.+]] = tptr.ptradd [[VAR_19_]] [[VAR_20_]] : <#tptr.default_memory_space>, i64 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_21_:%.+]] = tptr.ptradd [[VAR_19_]] [[VAR_20_]] : <#ptr.generic_space>, i64 to <#ptr.generic_space> // CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[CST_2_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = tptr.ptradd [[VAR_21_]] [[VAR_22_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_23_:%.+]] = tptr.ptradd [[VAR_21_]] [[VAR_22_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_25_:%.+]] = tptr.ptradd [[VAR_23_]] [[VAR_24_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_25_:%.+]] = tptr.ptradd [[VAR_23_]] [[VAR_24_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_26_:%.+]] = arith.muli [[CST_3_]], [[VAR_1_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = tptr.ptradd [[VAR_25_]] [[VAR_26_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_27_:%.+]] = tptr.ptradd [[VAR_25_]] [[VAR_26_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[CST_4_]], [[VAR_0_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_29_:%.+]] = tptr.ptradd [[VAR_27_]] [[VAR_28_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_29_:%.+]] = tptr.ptradd [[VAR_27_]] [[VAR_28_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_2_]], [[VAR_0_]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_31_:%.+]] = tptr.ptradd [[VAR_29_]] [[VAR_30_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> +// CHECK-DAG: [[VAR_31_:%.+]] = tptr.ptradd [[VAR_29_]] [[VAR_30_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> // CHECK-DAG: [[VAR_32_:%.+]] = arith.muli [[CST_3_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_33_:%.+]] = tptr.ptradd [[VAR_31_]] [[VAR_32_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: [[VAR_34_:%.+]] = tptr.to_memref [[VAR_33_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_33_:%.+]] = tptr.ptradd [[VAR_31_]] [[VAR_32_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: [[VAR_34_:%.+]] = tptr.to_memref [[VAR_33_]] : <#ptr.generic_space> to memref<1xi32> // CHECK-DAG: [[LOAD_VAR_34_MEM_:%.+]] = memref.load [[VAR_34_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK-DAG: [[VAR_36_:%.+]] = arith.extsi [[PARAM_2_]] : i32 to i64 // CHECK: [[VAR_37_:%.+]] = arith.addi [[VAR_17_]], [[VAR_36_]] : i64 -// CHECK: [[VAR_38_:%.+]] = tptr.inttoptr [[VAR_37_]] : i64 to <#tptr.default_memory_space> -// CHECK: [[VAR_39_:%.+]] = tptr.to_memref [[VAR_38_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_38_:%.+]] = tptr.inttoptr [[VAR_37_]] : i64 to <#ptr.generic_space> +// CHECK: [[VAR_39_:%.+]] = tptr.to_memref [[VAR_38_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: memref.store [[LOAD_VAR_34_MEM_]], [[VAR_39_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: return // CHECK: } diff --git a/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir b/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir index b8a37cbc..f93e2327 100644 --- a/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir +++ b/test/Conversion/TritonToPtr/cat_and_where_on_ptrs.mlir @@ -56,9 +56,9 @@ module { // CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : i8 // CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i32 // CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i32 -// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_2_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> -// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> -// CHECK-DAG: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> +// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_2_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> +// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> +// CHECK-DAG: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> // CHECK-DAG: [[VAR_5_:%.+]] = tensor.empty() : tensor<16xi8> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_6_:%.+]] = linalg.fill ins([[CST_0_2_]] : i8) outs([[VAR_5_]] : tensor<16xi8>) -> tensor<16xi8> @@ -79,33 +79,33 @@ module { // CHECK: [[VAR_36_1_:%.+]] = arith.index_cast [[VAR_35_1_]] : index to i32 // CHECK: linalg.yield [[VAR_36_1_]] : i32 // CHECK: } -> tensor<8xi32> -// CHECK: [[VAR_13_:%.+]] = tensor.empty() : tensor<8x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_14_:%.+]] = linalg.fill ins([[VAR_4_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_13_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_14_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>, tensor<8xi32>) outs([[VAR_14_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_2_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_13_:%.+]] = tensor.empty() : tensor<8x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_14_:%.+]] = linalg.fill ins([[VAR_4_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_13_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_14_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi32>) outs([[VAR_14_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_2_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_2_:%.+]] = arith.muli [[IN_3_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_2_:%.+]] = tptr.ptradd [[IN_2_]] [[VAR_35_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_2_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_16_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_13_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>, tensor<8xi32>) outs([[VAR_16_]] : tensor<8x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_5_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_6_:%.+]]: i32, [[IN_7_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_36_2_:%.+]] = tptr.ptradd [[IN_2_]] [[VAR_35_2_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_2_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<8x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_16_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_13_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>) -> tensor<8x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_12_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>, tensor<8xi32>) outs([[VAR_16_]] : tensor<8x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_5_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_6_:%.+]]: i32, [[IN_7_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_3_:%.+]] = arith.muli [[IN_6_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_3_:%.+]] = tptr.ptradd [[IN_5_]] [[VAR_35_3_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_3_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<8x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_18_:%.+]] = tensor.empty() : tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_inserted_slice_:%.+]] = tensor.insert_slice [[VAR_15_]] into [[VAR_18_]][0] [8] [1] : tensor<8x!ptr.ptr<#tptr.default_memory_space>> into tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_inserted_slice_0_:%.+]] = tensor.insert_slice [[VAR_17_]] into [[VAR_inserted_slice_]][8] [8] [1] : tensor<8x!ptr.ptr<#tptr.default_memory_space>> into tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_19_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_8_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_9_:%.+]]: i32, [[IN_10_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_36_3_:%.+]] = tptr.ptradd [[IN_5_]] [[VAR_35_3_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_3_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<8x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_18_:%.+]] = tensor.empty() : tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_inserted_slice_:%.+]] = tensor.insert_slice [[VAR_15_]] into [[VAR_18_]][0] [8] [1] : tensor<8x!ptr.ptr<#ptr.generic_space>> into tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_inserted_slice_0_:%.+]] = tensor.insert_slice [[VAR_17_]] into [[VAR_inserted_slice_]][8] [8] [1] : tensor<8x!ptr.ptr<#ptr.generic_space>> into tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_19_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_8_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_9_:%.+]]: i32, [[IN_10_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_4_:%.+]] = arith.muli [[IN_9_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_4_:%.+]] = tptr.ptradd [[IN_8_]] [[VAR_35_4_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_4_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_19_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_7_]] : tensor<16xi32>) { -// CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_12_:%.+]]: i32): -// CHECK: [[VAR_35_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_36_4_:%.+]] = tptr.ptradd [[IN_8_]] [[VAR_35_4_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_4_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_19_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_7_]] : tensor<16xi32>) { +// CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_12_:%.+]]: i32): +// CHECK: [[VAR_35_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: [[VAR_36_4_:%.+]] = memref.load [[VAR_35_5_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_36_4_]] : i32 // CHECK: } -> tensor<16xi32> @@ -114,33 +114,33 @@ module { // CHECK: [[VAR_35_6_:%.+]] = arith.muli [[IN_13_]], [[IN_14_]] : i32 // CHECK: linalg.yield [[VAR_35_6_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: [[VAR_22_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_16_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_22_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_16_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_7_:%.+]] = arith.muli [[IN_17_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_5_:%.+]] = tptr.ptradd [[IN_16_]] [[VAR_35_7_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_5_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> +// CHECK: [[VAR_36_5_:%.+]] = tptr.ptradd [[IN_16_]] [[VAR_35_7_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_5_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> // CHECK: [[VAR_23_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_8_]] : tensor<16xi32>, tensor<16xi32>) outs([[VAR_20_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_19_:%.+]]: i32, [[IN_20_:%.+]]: i32, [[IN_21_:%.+]]: i32): // CHECK: [[VAR_35_8_:%.+]] = arith.muli [[IN_19_]], [[IN_20_]] : i32 // CHECK: linalg.yield [[VAR_35_8_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: [[VAR_24_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_23_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_24_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_inserted_slice_0_]], [[VAR_23_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_inserted_slice_0_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_9_:%.+]] = arith.muli [[IN_23_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_6_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_35_9_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_6_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_25_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_18_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_26_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_25_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_25_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_25_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_26_:%.+]]: i32, [[IN_27_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_36_6_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_35_9_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_6_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_25_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_18_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_26_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_25_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_25_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_25_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_26_:%.+]]: i32, [[IN_27_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_10_:%.+]] = arith.muli [[IN_26_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_36_7_:%.+]] = tptr.ptradd [[IN_25_]] [[VAR_35_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_7_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_27_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_26_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_5_]] : tensor<16xi8>) { -// CHECK: ^bb0([[IN_28_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_29_:%.+]]: i8): -// CHECK: [[VAR_35_11_:%.+]] = tptr.to_memref [[IN_28_]] : <#tptr.default_memory_space> to memref<1xi8> +// CHECK: [[VAR_36_7_:%.+]] = tptr.ptradd [[IN_25_]] [[VAR_35_10_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_7_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_27_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_26_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_5_]] : tensor<16xi8>) { +// CHECK: ^bb0([[IN_28_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_29_:%.+]]: i8): +// CHECK: [[VAR_35_11_:%.+]] = tptr.to_memref [[IN_28_]] : <#ptr.generic_space> to memref<1xi8> // CHECK: [[VAR_36_7_:%.+]] = memref.load [[VAR_35_11_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi8> // CHECK: linalg.yield [[VAR_36_7_]] : i8 // CHECK: } -> tensor<16xi8> @@ -150,20 +150,20 @@ module { // CHECK: [[VAR_35_12_:%.+]] = arith.cmpi ne, [[IN_30_]], [[IN_31_]] : i8 // CHECK: linalg.yield [[VAR_35_12_]] : i1 // CHECK: } -> tensor<16xi1> -// CHECK: [[VAR_30_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_29_]], [[VAR_22_]], [[VAR_24_]] : tensor<16xi1>, tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_22_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_33_:%.+]]: i1, [[IN_34_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_35_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_36_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): -// CHECK: [[VAR_35_13_:%.+]] = arith.select [[IN_33_]], [[IN_34_]], [[IN_35_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_35_13_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_30_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_30_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_37_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_38_:%.+]]: i32, [[IN_39_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_30_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_29_]], [[VAR_22_]], [[VAR_24_]] : tensor<16xi1>, tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_22_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_33_:%.+]]: i1, [[IN_34_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_35_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_36_:%.+]]: !ptr.ptr<#ptr.generic_space>): +// CHECK: [[VAR_35_13_:%.+]] = arith.select [[IN_33_]], [[IN_34_]], [[IN_35_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: linalg.yield [[VAR_35_13_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_31_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_30_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_30_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_37_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_38_:%.+]]: i32, [[IN_39_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_14_:%.+]] = arith.muli [[IN_38_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_8_:%.+]] = tptr.ptradd [[IN_37_]] [[VAR_35_14_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_8_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_32_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi1>) outs([[VAR_7_]] : tensor<16xi32>) { -// CHECK: ^bb0([[IN_40_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_41_:%.+]]: i1, [[IN_42_:%.+]]: i32): -// CHECK-DAG: [[VAR_35_15_:%.+]] = tptr.to_memref [[IN_40_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_36_8_:%.+]] = tptr.ptradd [[IN_37_]] [[VAR_35_14_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_8_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_32_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_31_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi1>) outs([[VAR_7_]] : tensor<16xi32>) { +// CHECK: ^bb0([[IN_40_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_41_:%.+]]: i1, [[IN_42_:%.+]]: i32): +// CHECK-DAG: [[VAR_35_15_:%.+]] = tptr.to_memref [[IN_40_]] : <#ptr.generic_space> to memref<1xi32> // CHECK-DAG: [[VAR_36_9_:%.+]] = scf.if [[IN_41_]] -> (i32) { // CHECK: [[LOAD_VAR_35_15_MEM_:%.+]] = memref.load [[VAR_35_15_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: scf.yield [[LOAD_VAR_35_15_MEM_]] : i32 @@ -172,17 +172,17 @@ module { // CHECK: } // CHECK: linalg.yield [[VAR_36_9_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: [[VAR_33_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_18_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_33_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_33_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_43_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_44_:%.+]]: i32, [[IN_45_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_33_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_18_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_34_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_33_]], [[VAR_10_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_33_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_43_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_44_:%.+]]: i32, [[IN_45_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_35_16_:%.+]] = arith.muli [[IN_44_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_36_10_:%.+]] = tptr.ptradd [[IN_43_]] [[VAR_35_16_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_36_10_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_34_]], [[VAR_32_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>, tensor<16xi1>) { -// CHECK: ^bb0([[IN_46_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_47_:%.+]]: i32, [[IN_48_:%.+]]: i1): +// CHECK: [[VAR_36_10_:%.+]] = tptr.ptradd [[IN_43_]] [[VAR_35_16_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_36_10_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_34_]], [[VAR_32_]], [[VAR_29_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>, tensor<16xi1>) { +// CHECK: ^bb0([[IN_46_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_47_:%.+]]: i32, [[IN_48_:%.+]]: i1): // CHECK: scf.if [[IN_48_]] { -// CHECK: [[VAR_35_17_:%.+]] = tptr.to_memref [[IN_46_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_35_17_:%.+]] = tptr.to_memref [[IN_46_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: memref.store [[IN_47_]], [[VAR_35_17_]]{{.}}[[CST_0_1_]]{{.}} : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir b/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir index 810056cc..eed6c612 100644 --- a/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir +++ b/test/Conversion/TritonToPtr/dynamic_masked_load_store.mlir @@ -40,9 +40,9 @@ module { // CHECK: linalg.yield [[YIELD]] : i1 // CHECK: } -> tensor<16xi1> -// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%{{.+}}, [[MASK]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi1>) -// CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i1, %out: i32): -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%{{.+}}, [[MASK]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi1>) +// CHECK: ^bb0(%in: !ptr.ptr<#ptr.generic_space>, %in_0: i1, %out: i32): +// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#ptr.generic_space> to memref<1xi32> // CHECK: [[SCF_IF:%.+]] = scf.if %in_0 -> (i32) { // CHECK: [[LOAD:%.+]] = memref.load [[MEMREF]][%c0] : memref<1xi32> // CHECK: scf.yield [[LOAD]] : i32 @@ -53,9 +53,9 @@ module { // CHECK: } -> tensor<16xi32> // CHECK: linalg.generic -// CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i32, %in_1: i1): +// CHECK: ^bb0(%in: !ptr.ptr<#ptr.generic_space>, %in_0: i32, %in_1: i1): // CHECK: scf.if %in_1 { -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#ptr.generic_space> to memref<1xi32> // CHECK: memref.store %in_0, [[MEMREF]][%c0] : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/masked_load_store.mlir b/test/Conversion/TritonToPtr/masked_load_store.mlir index 0180f4b7..099fb659 100644 --- a/test/Conversion/TritonToPtr/masked_load_store.mlir +++ b/test/Conversion/TritonToPtr/masked_load_store.mlir @@ -29,8 +29,8 @@ module { } // CHECK: linalg.generic -// CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i1, %out: i32): -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: ^bb0(%in: !ptr.ptr<#ptr.generic_space>, %in_0: i1, %out: i32): +// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#ptr.generic_space> to memref<1xi32> // CHECK: [[SCF_IF:%.+]] = scf.if %in_0 -> (i32) { // CHECK: [[LOAD:%.+]] = memref.load [[MEMREF]][%c0] : memref<1xi32> // CHECK: scf.yield [[LOAD]] : i32 @@ -41,9 +41,9 @@ module { // CHECK: } -> tensor<16xi32> // CHECK: linalg.generic -// CHECK: ^bb0(%in: !ptr.ptr<#tptr.default_memory_space>, %in_0: i32, %in_1: i1): +// CHECK: ^bb0(%in: !ptr.ptr<#ptr.generic_space>, %in_0: i32, %in_1: i1): // CHECK: scf.if %in_1 { -// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[MEMREF:%.+]] = tptr.to_memref %in : <#ptr.generic_space> to memref<1xi32> // CHECK: memref.store %in_0, [[MEMREF]][%c0] : memref<1xi32> // CHECK: } // CHECK: linalg.yield diff --git a/test/Conversion/TritonToPtr/regular_load_store.mlir b/test/Conversion/TritonToPtr/regular_load_store.mlir index d1612e92..04a901d4 100644 --- a/test/Conversion/TritonToPtr/regular_load_store.mlir +++ b/test/Conversion/TritonToPtr/regular_load_store.mlir @@ -27,8 +27,8 @@ module { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[VAR_0_:%.+]] = tptr.type_offset f32 : i32 // CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset i32 : i32 -// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> -// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> +// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> +// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> // CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<1024xi32> // CHECK: [[VAR_5_:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs([[VAR_4_]] : tensor<1024xi32>) { // CHECK: ^bb0([[IN_0_:%.+]]: i32): @@ -36,24 +36,24 @@ module { // CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : index to i32 // CHECK: linalg.yield [[VAR_15_]] : i32 // CHECK: } -> tensor<1024xi32> -// CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<1024x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_6_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xi32>) outs([[VAR_7_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_6_:%.+]] = tensor.empty() : tensor<1024x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_7_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_6_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>) -> tensor<1024x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_8_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_7_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>, tensor<1024xi32>) outs([[VAR_7_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_14_1_:%.+]] = arith.muli [[IN_2_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_15_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_14_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_15_1_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_9_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_6_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xi32>) outs([[VAR_9_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_15_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_14_1_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_15_1_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<1024x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_9_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_6_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>) -> tensor<1024x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]], [[VAR_5_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>, tensor<1024xi32>) outs([[VAR_9_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_5_:%.+]]: i32, [[IN_6_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_14_2_:%.+]] = arith.muli [[IN_5_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_15_2_:%.+]] = tptr.ptradd [[IN_4_]] [[VAR_14_2_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_15_2_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<1024x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<1024xi32>) { -// CHECK: ^bb0([[IN_7_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_8_:%.+]]: i32): -// CHECK: [[VAR_14_3_:%.+]] = tptr.to_memref [[IN_7_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_15_2_:%.+]] = tptr.ptradd [[IN_4_]] [[VAR_14_2_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_15_2_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<1024x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_11_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_4_]] : tensor<1024xi32>) { +// CHECK: ^bb0([[IN_7_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_8_:%.+]]: i32): +// CHECK: [[VAR_14_3_:%.+]] = tptr.to_memref [[IN_7_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: [[VAR_15_2_:%.+]] = memref.load [[VAR_14_3_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_15_2_]] : i32 // CHECK: } -> tensor<1024xi32> @@ -63,9 +63,9 @@ module { // CHECK: [[VAR_14_4_:%.+]] = arith.bitcast [[IN_9_]] : i32 to f32 // CHECK: linalg.yield [[VAR_14_4_]] : f32 // CHECK: } -> tensor<1024xf32> -// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_13_]] : tensor<1024x!ptr.ptr<#tptr.default_memory_space>>, tensor<1024xf32>) { -// CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_12_:%.+]]: f32): -// CHECK: [[VAR_14_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#tptr.default_memory_space> to memref<1xf32> +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_10_]], [[VAR_13_]] : tensor<1024x!ptr.ptr<#ptr.generic_space>>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_11_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_12_:%.+]]: f32): +// CHECK: [[VAR_14_5_:%.+]] = tptr.to_memref [[IN_11_]] : <#ptr.generic_space> to memref<1xf32> // CHECK: memref.store [[IN_12_]], [[VAR_14_5_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: linalg.yield // CHECK: } diff --git a/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir b/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir index dfcab2b2..5462b2bf 100644 --- a/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir +++ b/test/Conversion/TritonToPtr/triton_tensor_ptr_ops.mlir @@ -47,8 +47,8 @@ module { // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[VAR_1_:%.+]] = tptr.type_offset i32 : i32 // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> -// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#tptr.default_memory_space> +// CHECK-DAG: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> +// CHECK-DAG: [[VAR_3_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to !ptr.ptr<#ptr.generic_space> // CHECK-DAG: [[VAR_4_:%.+]] = tensor.empty() : tensor<16xi32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_5_:%.+]] = linalg.fill ins([[CST_2_]] : i32) outs([[VAR_4_]] : tensor<16xi32>) -> tensor<16xi32> @@ -58,17 +58,17 @@ module { // CHECK: [[VAR_23_:%.+]] = arith.index_cast [[VAR_22_]] : index to i32 // CHECK: linalg.yield [[VAR_23_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_8_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_7_:%.+]] = tensor.empty() : tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_8_:%.+]] = linalg.fill ins([[VAR_3_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_8_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_8_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_1_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_22_1_:%.+]] = arith.muli [[IN_2_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_23_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_22_1_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_23_1_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<16xi32>) { -// CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_5_:%.+]]: i32): -// CHECK: [[VAR_22_2_:%.+]] = tptr.to_memref [[IN_4_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_23_1_:%.+]] = tptr.ptradd [[IN_1_]] [[VAR_22_1_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_23_1_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_10_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_9_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_4_]] : tensor<16xi32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_5_:%.+]]: i32): +// CHECK: [[VAR_22_2_:%.+]] = tptr.to_memref [[IN_4_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: [[VAR_23_1_:%.+]] = memref.load [[VAR_22_2_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_23_1_]] : i32 // CHECK: } -> tensor<16xi32> @@ -78,27 +78,27 @@ module { // CHECK: [[VAR_22_3_:%.+]] = arith.extsi [[IN_6_]] : i32 to i64 // CHECK: linalg.yield [[VAR_22_3_]] : i64 // CHECK: } -> tensor<16xi64> -// CHECK: [[VAR_13_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_12_]] : tensor<16xi64>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_8_:%.+]]: i64, [[IN_9_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): -// CHECK: [[VAR_22_4_:%.+]] = tptr.inttoptr [[IN_8_]] : i64 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_22_4_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_13_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_4_]] : tensor<16xi32>) { -// CHECK: ^bb0([[IN_10_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_11_:%.+]]: i32): -// CHECK: [[VAR_22_5_:%.+]] = tptr.to_memref [[IN_10_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: [[VAR_13_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_12_]] : tensor<16xi64>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_8_:%.+]]: i64, [[IN_9_:%.+]]: !ptr.ptr<#ptr.generic_space>): +// CHECK: [[VAR_22_4_:%.+]] = tptr.inttoptr [[IN_8_]] : i64 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_22_4_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_14_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_13_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_4_]] : tensor<16xi32>) { +// CHECK: ^bb0([[IN_10_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_11_:%.+]]: i32): +// CHECK: [[VAR_22_5_:%.+]] = tptr.to_memref [[IN_10_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: [[VAR_23_1_1_:%.+]] = memref.load [[VAR_22_5_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield [[VAR_23_1_1_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: [[VAR_15_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#tptr.default_memory_space>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_16_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_15_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_15_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_12_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_15_:%.+]] = linalg.fill ins([[VAR_2_]] : !ptr.ptr<#ptr.generic_space>) outs([[VAR_7_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_16_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_15_]], [[VAR_6_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_15_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_12_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_22_6_:%.+]] = arith.muli [[IN_13_]], [[VAR_1_]] : i32 -// CHECK: [[VAR_23_2_:%.+]] = tptr.ptradd [[IN_12_]] [[VAR_22_6_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_23_2_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> -// CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) outs([[VAR_11_]] : tensor<16xi64>) { -// CHECK: ^bb0([[IN_15_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_16_:%.+]]: i64): -// CHECK: [[VAR_22_7_:%.+]] = tptr.ptrtoint [[IN_15_]] : <#tptr.default_memory_space> to i64 +// CHECK: [[VAR_23_2_:%.+]] = tptr.ptradd [[IN_12_]] [[VAR_22_6_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_23_2_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> +// CHECK: [[VAR_17_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) outs([[VAR_11_]] : tensor<16xi64>) { +// CHECK: ^bb0([[IN_15_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_16_:%.+]]: i64): +// CHECK: [[VAR_22_7_:%.+]] = tptr.ptrtoint [[IN_15_]] : <#ptr.generic_space> to i64 // CHECK: linalg.yield [[VAR_22_7_]] : i64 // CHECK: } -> tensor<16xi64> // CHECK: [[VAR_18_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_14_]] : tensor<16xi32>) outs([[VAR_11_]] : tensor<16xi64>) { @@ -111,20 +111,20 @@ module { // CHECK: [[VAR_22_9_:%.+]] = arith.addi [[IN_19_]], [[IN_20_]] : i64 // CHECK: linalg.yield [[VAR_22_9_]] : i64 // CHECK: } -> tensor<16xi64> -// CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_5_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) outs([[VAR_16_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>) { -// CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#tptr.default_memory_space>): +// CHECK: [[VAR_20_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_16_]], [[VAR_5_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) outs([[VAR_16_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>) { +// CHECK: ^bb0([[IN_22_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_23_:%.+]]: i32, [[IN_24_:%.+]]: !ptr.ptr<#ptr.generic_space>): // CHECK: [[VAR_22_10_:%.+]] = arith.muli [[IN_23_]], [[VAR_0_]] : i32 -// CHECK: [[VAR_23_3_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_22_10_]] : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> -// CHECK: linalg.yield [[VAR_23_3_]] : !ptr.ptr<#tptr.default_memory_space> -// CHECK: } -> tensor<16x!ptr.ptr<#tptr.default_memory_space>> +// CHECK: [[VAR_23_3_:%.+]] = tptr.ptradd [[IN_22_]] [[VAR_22_10_]] : <#ptr.generic_space>, i32 to <#ptr.generic_space> +// CHECK: linalg.yield [[VAR_23_3_]] : !ptr.ptr<#ptr.generic_space> +// CHECK: } -> tensor<16x!ptr.ptr<#ptr.generic_space>> // CHECK: [[VAR_21_:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_19_]] : tensor<16xi64>) outs([[VAR_4_]] : tensor<16xi32>) { // CHECK: ^bb0([[IN_25_:%.+]]: i64, [[IN_26_:%.+]]: i32): // CHECK: [[VAR_22_11_:%.+]] = arith.trunci [[IN_25_]] : i64 to i32 // CHECK: linalg.yield [[VAR_22_11_]] : i32 // CHECK: } -> tensor<16xi32> -// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#tptr.default_memory_space>>, tensor<16xi32>) { -// CHECK: ^bb0([[IN_27_:%.+]]: !ptr.ptr<#tptr.default_memory_space>, [[IN_28_:%.+]]: i32): -// CHECK: [[VAR_22_12_:%.+]] = tptr.to_memref [[IN_27_]] : <#tptr.default_memory_space> to memref<1xi32> +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_20_]], [[VAR_21_]] : tensor<16x!ptr.ptr<#ptr.generic_space>>, tensor<16xi32>) { +// CHECK: ^bb0([[IN_27_:%.+]]: !ptr.ptr<#ptr.generic_space>, [[IN_28_:%.+]]: i32): +// CHECK: [[VAR_22_12_:%.+]] = tptr.to_memref [[IN_27_]] : <#ptr.generic_space> to memref<1xi32> // CHECK: memref.store [[IN_28_]], [[VAR_22_12_]]{{.}}[[CST_0_]]{{.}} : memref<1xi32> // CHECK: linalg.yield // CHECK: }