
[RAISE-BP] Add support for arith.remsi|remui as tt.addptr input #1570

Merged
test/Triton/raise-block-pointer.mlir: 145 changes (144 additions, 1 deletion)
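
The tests added below exercise the raise-block-pointer pass when the offsets fed to tt.addptr are computed with arith.remsi or arith.remui. A reduced, hypothetical sketch of the input pattern under test (not taken from the diff; 1-D where the real tests below are 2-D, with illustrative names like %mod and %wrapped):

// Hypothetical 1-D reduction of the pattern: a remsi-wrapped index vector
// feeding tt.addptr, which the pass is expected to raise to a block pointer.
tt.func public @remsi_into_addptr(%arg0: !tt.ptr<f32>, %mod: i32) {
  %range = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
  %splat = tt.splat %mod : i32 -> tensor<4xi32>
  // Indices 0..3 are non-negative, so remsi wraps them into [0, %mod).
  %wrapped = arith.remsi %range, %splat : tensor<4xi32>
  %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x!tt.ptr<f32>>
  %ptrs = tt.addptr %base, %wrapped : tensor<4x!tt.ptr<f32>>, tensor<4xi32>
  %val = tt.load %ptrs : tensor<4x!tt.ptr<f32>>
  tt.return
}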
@@ -237,6 +237,79 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
tt.return %3 : tensor<128x2x128xf32>
}


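// The pass is expected to fold the arith.remsi on the column indices into the
// raised block pointer: per the CHECK lines below, the remsi modulus (%arg2)
// scaled by the column stride (%arg4) appears as the column shape operand of
// tt.make_tensor_ptr ([[VAR_9_]]), so the wrap-around is carried by the tensor
// pointer rather than by explicit modular arithmetic.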
// CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64
// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32
// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64
// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64
// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64
// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%cst_0 = arith.constant dense<2> : tensor<4x1xi32>
%cst_1 = arith.constant dense<6> : tensor<4xi32>
%cst_2 = arith.constant dense<2> : tensor<4xi32>
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%1 = arith.addi %0, %cst_2 : tensor<4xi32>
%2 = arith.addi %0, %cst_1 : tensor<4xi32>
%3 = tt.splat %arg2 : i32 -> tensor<4xi32>
%4 = arith.remsi %2, %3 : tensor<4xi32>
%5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
%7 = arith.muli %5, %6 : tensor<4x1xi32>
%8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
%10 = arith.muli %8, %9 : tensor<1x4xi32>
%11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
%12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
%13 = arith.addi %11, %12 : tensor<4x4xi32>
%14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
%15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
%18 = arith.muli %17, %16 : tensor<4x1xi32>
%19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
%20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
%21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
%23 = arith.muli %22, %21 : tensor<1x4xi32>
%24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
%25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
%26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32>
%28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1>
%29 = arith.muli %arg3, %c4_i32 : i32
%30 = tt.splat %29 : i32 -> tensor<4x4xi32>
%31 = arith.muli %arg4, %c4_i32 : i32
%32 = tt.splat %31 : i32 -> tensor<4x4xi32>
%34 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
tt.store %26, %34 : tensor<4x4x!tt.ptr<f32>>
tt.return
}
}


// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
@@ -319,6 +392,77 @@ module {
}
}


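// Same pattern with arith.remui, this time on the row indices: the CHECK lines
// below verify that the modulus (%arg2) scaled by the row stride (%arg3)
// becomes the row shape operand of tt.make_tensor_ptr ([[VAR_5_]]).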
// CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64
// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64
// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32
// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64
// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c4_i32 = arith.constant 4 : i32
%cst_0 = arith.constant dense<3> : tensor<1x4xi32>
%cst_1 = arith.constant dense<3> : tensor<4xi32>
%cst_2 = arith.constant dense<2> : tensor<4xi32>
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%1 = arith.addi %0, %cst_2 : tensor<4xi32>
%2 = tt.splat %arg2 : i32 -> tensor<4xi32>
%3 = arith.remui %1, %2 : tensor<4xi32>
%4 = arith.addi %0, %cst_1 : tensor<4xi32>
%5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
%7 = arith.muli %5, %6 : tensor<4x1xi32>
%8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
%10 = arith.muli %8, %9 : tensor<1x4xi32>
%11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
%12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
%13 = arith.addi %11, %12 : tensor<4x4xi32>
%14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
%15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
%18 = arith.muli %17, %16 : tensor<4x1xi32>
%19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
%20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
%21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
%23 = arith.muli %22, %21 : tensor<1x4xi32>
%24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
%25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
%26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32>
%28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1>
%29 = arith.muli %arg4, %c4_i32 : i32
%30 = tt.splat %29 : i32 -> tensor<4x4xi32>
%32 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
tt.store %26, %32 : tensor<4x4x!tt.ptr<f32>>
tt.return
}
}


// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
@@ -423,7 +567,6 @@ module {
}

// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64