
Commit 57a10af

Author: Maxime France-Pillois

[RAISE-BP] Add support for arith.remsi|remui as tt.addptr input (#1570)

- Add minimal support for handling `arith.remsi|remui` as `tt.addptr` input.
- Improve handling of unfolded arithmetic operations when evaluating the modulo property and constant values.

Closes issues #1436 and #1482

Signed-off-by: Maxime France-Pillois <[email protected]>

1 parent 75eee6f commit 57a10af
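
For context, wrap-around offsets are written with a plain `%` in the Triton Python frontend; on int32 offset tensors this lowers to `arith.remsi` (signed, or `arith.remui` for unsigned values), which then feeds the `tt.addptr` that the raise-block-pointer pass rewrites into `tt.make_tensor_ptr`. A minimal sketch of the triggering pattern (hypothetical kernel and names, not taken from this commit):

```python
import triton
import triton.language as tl

# Hypothetical kernel: the "%" on the offset tensor is what lowers to
# arith.remsi on the path into tt.addptr, the case this commit teaches
# the raise-block-pointer pass to handle.
@triton.jit
def copy_wrapped(in_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK) % n_cols       # -> arith.remsi feeding tt.addptr
    x = tl.load(in_ptr + offs)                # -> tt.addptr + tt.load
    tl.store(out_ptr + tl.arange(0, BLOCK), x)
```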

File tree

2 files changed: +333 -15 lines changed

test/Triton/raise-block-pointer.mlir

Lines changed: 144 additions & 1 deletion
@@ -237,6 +237,79 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
 tt.return %3 : tensor<128x2x128xf32>
 }
 
+
+// CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
+// CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
+// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
+// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
+// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
+// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
+// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
+// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64
+// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32
+// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64
+// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64
+// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array<i32>} : <tensor<4x4xf32>>
+// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
+// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64
+// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
+// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64
+// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
+// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr<tensor<4x4xf32>>
+// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr<tensor<4x4xf32>>
+// CHECK: tt.return
+module {
+  tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst_0 = arith.constant dense<2> : tensor<4x1xi32>
+    %cst_1 = arith.constant dense<6> : tensor<4xi32>
+    %cst_2 = arith.constant dense<2> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.addi %0, %cst_2 : tensor<4xi32>
+    %2 = arith.addi %0, %cst_1 : tensor<4xi32>
+    %3 = tt.splat %arg2 : i32 -> tensor<4xi32>
+    %4 = arith.remsi %2, %3 : tensor<4xi32>
+    %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
+    %7 = arith.muli %5, %6 : tensor<4x1xi32>
+    %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
+    %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
+    %10 = arith.muli %8, %9 : tensor<1x4xi32>
+    %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
+    %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
+    %13 = arith.addi %11, %12 : tensor<4x4xi32>
+    %14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
+    %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
+    %18 = arith.muli %17, %16 : tensor<4x1xi32>
+    %19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
+    %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
+    %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
+    %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
+    %23 = arith.muli %22, %21 : tensor<1x4xi32>
+    %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
+    %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
+    %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %27 = arith.cmpi slt, %16, %cst_0 : tensor<4x1xi32>
+    %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1>
+    %29 = arith.muli %arg3, %c4_i32 : i32
+    %30 = tt.splat %29 : i32 -> tensor<4x4xi32>
+    %31 = arith.muli %arg4, %c4_i32 : i32
+    %32 = tt.splat %31 : i32 -> tensor<4x4xi32>
+    %34 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
+    tt.store %26, %34 : tensor<4x4x!tt.ptr<f32>>
+    tt.return
+  }
+}
+
+
 // CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
 // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
 // CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
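
The first new test above, `wrap_side_by_side_masked`, covers a tile whose column offsets wrap around via `arith.remsi` before being scaled by the stride; the pass now folds the whole pointer computation into a single `tt.make_tensor_ptr`. A rough frontend-level equivalent (hypothetical names mirroring %arg0..%arg6; a sketch, not code from the repository):

```python
import triton
import triton.language as tl

# Sketch of the pattern the test exercises: the "% size" on the column
# offsets produces the arith.remsi that feeds tt.addptr.
@triton.jit
def wrap_side_by_side(in_ptr, out_ptr, size,        # %arg0, %arg1, %arg2
                      stride_m, stride_n,           # %arg3, %arg4
                      out_stride_m, out_stride_n):  # %arg5, %arg6
    rows = tl.arange(0, 4) + 2                 # %1
    cols = (tl.arange(0, 4) + 6) % size        # %2, %4 (arith.remsi)
    in_ptrs = in_ptr + rows[:, None] * stride_m + cols[None, :] * stride_n
    out_ptrs = (out_ptr
                + tl.arange(0, 4)[:, None] * out_stride_m
                + tl.arange(0, 4)[None, :] * out_stride_n)
    tl.store(out_ptrs, tl.load(in_ptrs))
```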
@@ -319,6 +392,77 @@ module {
 }
 }
 
+
+// CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
+// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
+// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
+// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
+// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
+// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
+// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
+// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64
+// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64
+// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
+// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64
+// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32
+// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array<i32>} : <tensor<4x4xf32>>
+// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
+// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64
+// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
+// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64
+// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
+// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr<tensor<4x4xf32>>
+// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr<tensor<4x4xf32>>
+// CHECK: tt.return
+module {
+  tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %cst_0 = arith.constant dense<3> : tensor<1x4xi32>
+    %cst_1 = arith.constant dense<3> : tensor<4xi32>
+    %cst_2 = arith.constant dense<2> : tensor<4xi32>
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
+    %1 = arith.addi %0, %cst_2 : tensor<4xi32>
+    %2 = tt.splat %arg2 : i32 -> tensor<4xi32>
+    %3 = arith.remui %1, %2 : tensor<4xi32>
+    %4 = arith.addi %0, %cst_1 : tensor<4xi32>
+    %5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32>
+    %7 = arith.muli %5, %6 : tensor<4x1xi32>
+    %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
+    %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32>
+    %10 = arith.muli %8, %9 : tensor<1x4xi32>
+    %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32>
+    %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32>
+    %13 = arith.addi %11, %12 : tensor<4x4xi32>
+    %14 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
+    %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32>
+    %18 = arith.muli %17, %16 : tensor<4x1xi32>
+    %19 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<4x1x!tt.ptr<f32>>
+    %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr<f32>>, tensor<4x1xi32>
+    %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
+    %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32>
+    %23 = arith.muli %22, %21 : tensor<1x4xi32>
+    %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr<f32>> -> tensor<4x4x!tt.ptr<f32>>
+    %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32>
+    %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
+    %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32>
+    %28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1>
+    %29 = arith.muli %arg4, %c4_i32 : i32
+    %30 = tt.splat %29 : i32 -> tensor<4x4xi32>
+    %32 = tt.load %15 : tensor<4x4x!tt.ptr<f32>>
+    tt.store %26, %32 : tensor<4x4x!tt.ptr<f32>>
+    tt.return
+  }
+}
+
+
 // CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>) {
 // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
 // CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
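
The second test, `wrap_stacked_masked_loop`, is the stacked variant: the wrap is on the row offsets and exercises the unsigned form, `arith.remui`. A rough frontend-level sketch (hypothetical names; note that a Python `%` on signed int32 would normally lower to `arith.remsi`, so this only illustrates the offset pattern, not the exact signedness):

```python
import triton
import triton.language as tl

# Sketch of the stacked-wrap pattern: the remainder is applied to the
# row offsets before they are scaled by the row stride.
@triton.jit
def wrap_stacked(in_ptr, out_ptr, size,           # %arg0, %arg1, %arg2
                 stride_m, stride_n,              # %arg3, %arg4
                 out_stride_m, out_stride_n):     # %arg5, %arg6
    rows = (tl.arange(0, 4) + 2) % size      # %1, %3 (remainder on rows)
    cols = tl.arange(0, 4) + 3               # %4
    in_ptrs = in_ptr + rows[:, None] * stride_m + cols[None, :] * stride_n
    out_ptrs = (out_ptr
                + tl.arange(0, 4)[:, None] * out_stride_m
                + tl.arange(0, 4)[None, :] * out_stride_n)
    tl.store(out_ptrs, tl.load(in_ptrs))
```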
@@ -423,7 +567,6 @@ module {
 }
 
 
-
 // CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
 // CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
 // CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
