Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[triton-raise-block-ptr]: Fix tt.advance offset computation #3287

Merged
merged 8 commits into from
Jan 29, 2025
2 changes: 1 addition & 1 deletion test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ module {
// CHECK-DAG: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<1x256xbf16>>
// CHECK-DAG: [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
// CHECK-DAG: [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr<tensor<1x256xbf16>>
// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : <tensor<1x256xbf16>>
// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : <tensor<1x256xbf16>>
// CHECK: tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr<tensor<1x256xbf16>>
// CHECK: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
// CHECK: [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
Expand Down
22 changes: 10 additions & 12 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,18 @@ module {
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
// CHECK-DAG: [[VAR_10_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_]]] : <tensor<4x256xbf16>>
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_1_:%.+]] = tt.load [[VAR_0_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16>
// CHECK-DAG: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_]], [[CST_0_i32]]] : <tensor<4x256xbf16>>
// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
// CHECK: }
// COM: to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], shape: [0, 0], order: [] : <bf16> to tensor<4x256x!tt.ptr<bf16>>
// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_4_]], [[VAR_3_]]#0 : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ module {
// CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32
// CHECK: [[VAR_17_:%.+]] = arith.muli {{.*}}, [[CST_128_i32]] : i32
// CHECK: [[VAR_18_:%.+]] = arith.muli {{.*}}, [[CST_256_i32]] : i32
// CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
// CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
// CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32
Expand All @@ -122,12 +124,20 @@ module {
// CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr<tensor<64x256xbf16>>
// CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32>
// CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32>
// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : <tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<64x256xbf16>>
// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
// CHECK-DAG: [[VAR_44_:%.+]] = arith.divui [[VAR_29_]], [[PARAM_6_]] : i32
// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[VAR_44_]], [[CST_0_i32]]] : <tensor<128x64xbf16>>
// CHECK-DAG: [[VAR_46_:%.+]] = arith.divui [[VAR_30_]], [[PARAM_8_]] : i32
// CHECK-DAG: [[VAR_47_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[VAR_46_]], [[CST_0_i32]]] : <tensor<64x256xbf16>>
// CHECK: scf.yield [[VAR_43_]], [[VAR_45_]], [[VAR_47_]] : tensor<128x256xf32>, !tt.ptr<tensor<128x64xbf16>>, !tt.ptr<tensor<64x256xbf16>>
// CHECK: }
// CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16>
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : <tensor<128x256xbf16>>
// CHECK-DAG: [[VAR_33_:%.+]] = arith.muli [[VAR_17_]], [[PARAM_10_]] : i32
// CHECK-DAG: [[VAR_34_:%.+]] = arith.extsi [[PARAM_10_]] : i32 to i64
// CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_18_]], [[PARAM_11_]] : i32
// CHECK-DAG: [[VAR_36_:%.+]] = arith.extsi [[PARAM_11_]] : i32 to i64
// CHECK-DAG: [[VAR_37_:%.+]] = arith.divui [[VAR_33_]], [[PARAM_10_]] : i32
// CHECK-DAG: [[VAR_38_:%.+]] = arith.divui [[VAR_35_]], [[PARAM_11_]] : i32
// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_34_]], [[VAR_36_]]], {{\[}}[[VAR_37_]], [[VAR_38_]]] {{.*}} : <tensor<128x256xbf16>>
// CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr<tensor<128x256xbf16>>
// CHECK: tt.return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16>
// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_i32]], [[CST_0_i32]]{{\]}} : <tensor<4x256xbf16>>
// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
// CHECK: }
// CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
Expand Down
22 changes: 12 additions & 10 deletions test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,18 @@ module {
// CHECK-DAG: [[VAR_14_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64
// CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64
// CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_14_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
// CHECK: [[VAR_21_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32>
// CHECK: [[VAR_22_:%.+]] = tt.broadcast [[VAR_21_]] : tensor<4x1xi1> -> tensor<4x4xi1>
// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32
// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32
// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_22_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK-DAG: [[VAR_27_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_23_]]] : <tensor<4x4xf32>>
// CHECK-DAG: [[VAR_28_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_24_]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_27_]], [[VAR_28_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK: [[VAR_17_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32>
// CHECK: [[VAR_18_:%.+]] = tt.broadcast [[VAR_17_]] : tensor<4x1xi1> -> tensor<4x4xi1>
// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32
// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32
// CHECK-DAG: [[VAR_21_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
// CHECK: [[VAR_22_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_arg10_]], [[VAR_22_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: [[VAR_23_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32
// CHECK: [[VAR_24_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_23_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
// CHECK: [[VAR_25_:%.+]] = arith.divui [[VAR_20_]], [[PARAM_6_]] : i32
// CHECK: [[VAR_26_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_25_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_24_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK: }
// CHECK: tt.return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ module {
// CHECK: [[VAR_20_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>) : i32 {
// CHECK: [[VAR_21_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_arg10_]], [[VAR_21_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK-DAG: [[VAR_22_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : <tensor<4x4xf32>>
// CHECK-DAG: [[VAR_23_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_22_]], [[VAR_23_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK: [[VAR_22_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32
// CHECK: [[VAR_23_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
// CHECK: [[VAR_24_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_6_]] : i32
// CHECK-DAG: [[VAR_25_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_24_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_23_]], [[VAR_25_]] : !tt.ptr<tensor<4x4xf32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK: }
// CHECK: tt.return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,21 @@ module {
// CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
// CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64
// CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi [[arg7_]] : i32 to i64
// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1>
// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
// CHECK: [[VAR_19_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x4xf32>>
// CHECK: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1>
// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32>
// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32>
// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
// CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
// CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
// CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>)
// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr<f32>>
// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr<tensor<4x4xf32>>
// CHECK-DAG: [[VAR_27_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
// CHECK-DAG: [[VAR_28_:%.+]] = arith.divui [[VAR_24_]], [[arg6_]] : i32
// CHECK-DAG: [[VAR_29_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_28_]], [[CST_0_i32]]] : <tensor<4x4xf32>>
// CHECK: scf.yield [[VAR_27_]], [[VAR_29_]] : tensor<4x4x!tt.ptr<f32>>, !tt.ptr<tensor<4x4xf32>>
// CHECK: }
// CHECK: tt.return
// CHECK: }
Loading