diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir index a9549c583f..6258fcb741 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir @@ -85,7 +85,7 @@ module { // CHECK-DAG: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > // CHECK-DAG: [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr>, tensor<1x256xi32> // CHECK-DAG: [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr> -// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : > +// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : > // CHECK: tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr> // CHECK: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { // CHECK: [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir index 92b893129f..3322a787e4 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir @@ -64,20 +64,18 @@ module { // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32 // CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64 -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr> -// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr>) { -// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> -// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16> -// CHECK-DAG: [[VAR_10_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_]]] : > -// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, !tt.ptr> +// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_1_:%.+]] = tt.load [[VAR_0_]] : !tt.ptr> +// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<4x256xbf16>, !tt.ptr>) { +// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr> // CHECK: } -// COM: to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], shape: [0, 0], order: [] : to tensor<4x256x!tt.ptr> -// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr> +// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: tt.store [[VAR_4_]], [[VAR_3_]]#0 : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir index c83eaf9fbc..0b0c52f3fb 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir @@ -106,6 +106,8 @@ module { // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32 +// CHECK: [[VAR_17_:%.+]] = arith.muli {{.*}}, [[CST_128_i32]] : i32 +// CHECK: [[VAR_18_:%.+]] = arith.muli {{.*}}, [[CST_256_i32]] : i32 // CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 // CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 // CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32 @@ -122,12 +124,20 @@ module { // CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr> // CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> // CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32> -// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : > -// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : > -// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr>, !tt.ptr> +// CHECK-DAG: [[VAR_44_:%.+]] = arith.divui [[VAR_29_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[VAR_44_]], [[CST_0_i32]]] : > +// CHECK-DAG: [[VAR_46_:%.+]] = arith.divui [[VAR_30_]], [[PARAM_8_]] : i32 +// CHECK-DAG: [[VAR_47_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[VAR_46_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_43_]], [[VAR_45_]], [[VAR_47_]] : tensor<128x256xf32>, !tt.ptr>, !tt.ptr> // CHECK: } // CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> -// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > +// CHECK-DAG: [[VAR_33_:%.+]] = arith.muli [[VAR_17_]], [[PARAM_10_]] : i32 +// CHECK-DAG: [[VAR_34_:%.+]] = arith.extsi [[PARAM_10_]] : i32 to i64 +// CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_18_]], [[PARAM_11_]] : i32 +// CHECK-DAG: [[VAR_36_:%.+]] = arith.extsi [[PARAM_11_]] : i32 to i64 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.divui [[VAR_33_]], [[PARAM_10_]] : i32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.divui [[VAR_35_]], [[PARAM_11_]] : i32 +// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_34_]], [[VAR_36_]]], {{\[}}[[VAR_37_]], [[VAR_38_]]] {{.*}} : > // CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir index 3f6a7ae0fb..f26c2df9c8 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir @@ -342,7 +342,7 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr (tensor<4x256xbf16>, !tt.ptr>) { // CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> // CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16> -// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : > +// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_i32]], [[CST_0_i32]]{{\]}} : > // CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr> // CHECK: } // CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir index 43d5a3964e..d08a0981e2 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir @@ -79,16 +79,18 @@ module { // CHECK-DAG: [[VAR_14_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 // CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 // CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_14_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_21_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32> -// CHECK: [[VAR_22_:%.+]] = tt.broadcast [[VAR_21_]] : tensor<4x1xi1> -> tensor<4x4xi1> -// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { -// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_22_]], [[CST_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr> -// CHECK-DAG: [[VAR_27_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_23_]]] : > -// CHECK-DAG: [[VAR_28_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_24_]]] : > -// CHECK: scf.yield [[VAR_27_]], [[VAR_28_]] : !tt.ptr>, !tt.ptr> +// CHECK: [[VAR_17_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32> +// CHECK: [[VAR_18_:%.+]] = tt.broadcast [[VAR_17_]] : tensor<4x1xi1> -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_21_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { +// CHECK: [[VAR_22_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr> +// CHECK: tt.store [[VAR_arg10_]], [[VAR_22_]] : !tt.ptr> +// CHECK: [[VAR_23_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_24_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_23_]], [[CST_0_i32]]] : > +// CHECK: [[VAR_25_:%.+]] = arith.divui [[VAR_20_]], [[PARAM_6_]] : i32 +// CHECK: [[VAR_26_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_25_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_24_]], [[VAR_26_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir index 5c210f3ac6..b8a5374fd7 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir @@ -83,9 +83,11 @@ module { // CHECK: [[VAR_20_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { // CHECK: [[VAR_21_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr> // CHECK: tt.store [[VAR_arg10_]], [[VAR_21_]] : !tt.ptr> -// CHECK-DAG: [[VAR_22_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > -// CHECK-DAG: [[VAR_23_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > -// CHECK: scf.yield [[VAR_22_]], [[VAR_23_]] : !tt.ptr>, !tt.ptr> +// CHECK: [[VAR_22_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_23_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] : > +// CHECK: [[VAR_24_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_25_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_24_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_23_]], [[VAR_25_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir index 805bbb9ccb..9ce098f09a 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir @@ -91,20 +91,21 @@ module { // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> // CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64 // CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi [[arg7_]] : i32 to i64 -// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> -// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1> -// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32 +// CHECK: [[VAR_19_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> +// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32> -// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32> +// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr>, !tt.ptr>) -// CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr> -// CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr> -// CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32> -// CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : > -// CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr>, !tt.ptr> +// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr>, !tt.ptr>) +// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr> +// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr> +// CHECK-DAG: [[VAR_27_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.divui [[VAR_24_]], [[arg6_]] : i32 +// CHECK-DAG: [[VAR_29_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_28_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_27_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index 92913b44ce..b5d2935883 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -217,7 +217,6 @@ struct PtrState { source = lhsState.source ? lhsState.source : rhsState.source; Location loc = op->getLoc(); - ArithBuilder abuilder(builder, loc); if (lhsState.scalar && rhsState.scalar) { scalar = @@ -309,7 +308,6 @@ struct PtrState { std::swap(lhs, rhs); Location loc = op->getLoc(); - ArithBuilder abuilder(builder, loc); for (const auto &[offset, stride, dim, size] : llvm::zip(lhs->offsets, lhs->strides, lhs->shape, lhs->sizes)) { @@ -361,20 +359,7 @@ struct PtrState { for (const auto &[offset, stride, dim] : llvm::zip(offsets, strides, shape)) { - if (ttgi::isConstant(stride, 0)) { - newOffsets.push_back(findOrCreateCast( - loc, offset, builder.getIntegerType(offsetBitwidth), builder)); - } else { - Value divOffset = builder.create( - loc, builder.getIntegerType(offsetBitwidth), - findOrCreateCast(loc, offset, - builder.getIntegerType(offsetBitwidth), builder), - findOrCreateCast(loc, stride, - builder.getIntegerType(offsetBitwidth), builder)); - newOffsets.push_back( - findOrCreateCast(loc, getFinalValue(divOffset), - builder.getIntegerType(offsetBitwidth), builder)); - } + newOffsets.push_back(computeOffset(offset, stride, builder, loc)); newStrides.push_back(findOrCreateCast( loc, stride, builder.getIntegerType(shapeAndStridesBitwidth), builder)); @@ -385,6 +370,34 @@ struct PtrState { return findOrCreateMakeTensorPtr(loc, source, newShape, newStrides, newOffsets, order, sizes, builder); } + + Value createTTAdvanceOp(Value ptr, tt::MakeTensorPtrOp makeTPtrOp, + OpBuilder &builder, Location loc) const { + SmallVector newOffsets; + for (const auto &[offset, stride] : + llvm::zip(offsets, makeTPtrOp.getStrides())) + newOffsets.push_back(computeOffset(offset, stride, builder, loc)); + + return builder.createOrFold(loc, ptr.getType(), ptr, + newOffsets); + } + +private: + Value computeOffset(Value offset, Value stride, OpBuilder &builder, + Location loc) const { + if (ttgi::isConstant(stride, 0)) + return findOrCreateCast(loc, offset, + builder.getIntegerType(offsetBitwidth), builder); + + Value divOffset = builder.create( + loc, builder.getIntegerType(offsetBitwidth), + findOrCreateCast(loc, offset, builder.getIntegerType(offsetBitwidth), + builder), + findOrCreateCast(loc, stride, builder.getIntegerType(offsetBitwidth), + builder)); + return findOrCreateCast(loc, getFinalValue(divOffset), + builder.getIntegerType(offsetBitwidth), builder); + } }; #ifndef NDEBUG @@ -698,36 +711,17 @@ struct TritonRaiseBlockPointer return val; }; - // If the ptr has already been mapped (i.e. rewritten into a block pointer), - // rewrite the AddPtrOp using and AdvanceOp. + // If the ptr has already been mapped (i.e. rewritten into a block + // pointer), rewrite the AddPtrOp using and AdvanceOp. if (Value mappedV = ptrMap.lookupOrNull(ptr)) { if (auto makeTPtrOp = mappedV.getDefiningOp()) { - auto offsetType = cast(op.getOffset().getType()); - unsigned rank = offsetType.getRank(); - - SmallVector offsets; - TypeSwitch(op.getOffset().getDefiningOp()) - .Case([&](tt::SplatOp splatOp) { - fillOffsets(splatOp.getSrc(), rank, offsets); - }) - .Case([&](arith::ConstantOp cstOp) { - APInt val = getConstantValue(cstOp); - - fillOffsets(findOrCreateConstant(loc, val.getZExtValue(), - offsetBitwidth, builder), - rank, offsets); - }) - .Default([](Operation *op) { - llvm::errs() << "Operation: " << *op << "\n"; - llvm_unreachable("Unhandled operation"); - }); - - assert(!offsets.empty() && offsets.size() == rank && - "unexpected number of offsets"); + PtrState state; + if (failed(visitOperand(op.getOffset(), state, loc, builder))) + return failure(); Value basePtr = tt::isTensorPointerType(ptr.getType()) ? ptr : mappedV; - auto advanceOp = builder.createOrFold( - loc, basePtr.getType(), basePtr, offsets); + auto advanceOp = + state.createTTAdvanceOp(basePtr, makeTPtrOp, builder, loc); cleanUp.insert(op); ptrMap.map(op.getResult(), advanceOp); @@ -791,7 +785,6 @@ struct TritonRaiseBlockPointer auto resType = cast(makeTPtrOp.getResult().getType()); auto pointeeType = cast(resType.getPointeeType()); ArrayRef shape = pointeeType.getShape(); - ArithBuilder abuilder(builder, loc); for (int i = 0; i < pointeeType.getRank(); i++) { state.sizes.push_back(shape[i]);