@@ -237,6 +237,79 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
237237 tt.return %3 : tensor <128 x2 x128 xf32 >
238238}
239239
240+
241+ // CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
242+ // CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32
243+ // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
244+ // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
245+ // CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
246+ // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
247+ // CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
248+ // CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
249+ // CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
250+ // CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
251+ // CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64
252+ // CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32
253+ // CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64
254+ // CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64
255+ // CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array<i32>} : <tensor<4x4xf32>>
256+ // CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
257+ // CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64
258+ // CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
259+ // CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64
260+ // CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
261+ // CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr<tensor<4x4xf32>>
262+ // CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr<tensor<4x4xf32>>
263+ // CHECK: tt.return
264+ module {
265+ tt.func public @wrap_side_by_side_masked (%arg0: !tt.ptr <f32 >, %arg1: !tt.ptr <f32 >, %arg2: i32 , %arg3: i32 , %arg4: i32 , %arg5: i32 , %arg6: i32 ) {
266+ %c0_i32 = arith.constant 0 : i32
267+ %c1_i32 = arith.constant 1 : i32
268+ %c2_i32 = arith.constant 2 : i32
269+ %c4_i32 = arith.constant 4 : i32
270+ %cst_0 = arith.constant dense <2 > : tensor <4 x1 xi32 >
271+ %cst_1 = arith.constant dense <6 > : tensor <4 xi32 >
272+ %cst_2 = arith.constant dense <2 > : tensor <4 xi32 >
273+ %0 = tt.make_range {end = 4 : i32 , start = 0 : i32 } : tensor <4 xi32 >
274+ %1 = arith.addi %0 , %cst_2 : tensor <4 xi32 >
275+ %2 = arith.addi %0 , %cst_1 : tensor <4 xi32 >
276+ %3 = tt.splat %arg2 : i32 -> tensor <4 xi32 >
277+ %4 = arith.remsi %2 , %3 : tensor <4 xi32 >
278+ %5 = tt.expand_dims %1 {axis = 1 : i32 } : tensor <4 xi32 > -> tensor <4 x1 xi32 >
279+ %6 = tt.splat %arg3 : i32 -> tensor <4 x1 xi32 >
280+ %7 = arith.muli %5 , %6 : tensor <4 x1 xi32 >
281+ %8 = tt.expand_dims %4 {axis = 0 : i32 } : tensor <4 xi32 > -> tensor <1 x4 xi32 >
282+ %9 = tt.splat %arg4 : i32 -> tensor <1 x4 xi32 >
283+ %10 = arith.muli %8 , %9 : tensor <1 x4 xi32 >
284+ %11 = tt.broadcast %7 : tensor <4 x1 xi32 > -> tensor <4 x4 xi32 >
285+ %12 = tt.broadcast %10 : tensor <1 x4 xi32 > -> tensor <4 x4 xi32 >
286+ %13 = arith.addi %11 , %12 : tensor <4 x4 xi32 >
287+ %14 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <4 x4 x!tt.ptr <f32 >>
288+ %15 = tt.addptr %14 , %13 : tensor <4 x4 x!tt.ptr <f32 >>, tensor <4 x4 xi32 >
289+ %16 = tt.expand_dims %0 {axis = 1 : i32 } : tensor <4 xi32 > -> tensor <4 x1 xi32 >
290+ %17 = tt.splat %arg5 : i32 -> tensor <4 x1 xi32 >
291+ %18 = arith.muli %17 , %16 : tensor <4 x1 xi32 >
292+ %19 = tt.splat %arg1 : !tt.ptr <f32 > -> tensor <4 x1 x!tt.ptr <f32 >>
293+ %20 = tt.addptr %19 , %18 : tensor <4 x1 x!tt.ptr <f32 >>, tensor <4 x1 xi32 >
294+ %21 = tt.expand_dims %0 {axis = 0 : i32 } : tensor <4 xi32 > -> tensor <1 x4 xi32 >
295+ %22 = tt.splat %arg6 : i32 -> tensor <1 x4 xi32 >
296+ %23 = arith.muli %22 , %21 : tensor <1 x4 xi32 >
297+ %24 = tt.broadcast %20 : tensor <4 x1 x!tt.ptr <f32 >> -> tensor <4 x4 x!tt.ptr <f32 >>
298+ %25 = tt.broadcast %23 : tensor <1 x4 xi32 > -> tensor <4 x4 xi32 >
299+ %26 = tt.addptr %24 , %25 : tensor <4 x4 x!tt.ptr <f32 >>, tensor <4 x4 xi32 >
300+ %27 = arith.cmpi slt , %16 , %cst_0 : tensor <4 x1 xi32 >
301+ %28 = tt.broadcast %27 : tensor <4 x1 xi1 > -> tensor <4 x4 xi1 >
302+ %29 = arith.muli %arg3 , %c4_i32 : i32
303+ %30 = tt.splat %29 : i32 -> tensor <4 x4 xi32 >
304+ %31 = arith.muli %arg4 , %c4_i32 : i32
305+ %32 = tt.splat %31 : i32 -> tensor <4 x4 xi32 >
306+ %34 = tt.load %15 : tensor <4 x4 x!tt.ptr <f32 >>
307+ tt.store %26 , %34 : tensor <4 x4 x!tt.ptr <f32 >>
308+ tt.return
309+ }
310+ }
311+
312+
240313// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
241314// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
242315// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
@@ -319,6 +392,77 @@ module {
319392 }
320393}
321394
395+
396+ // CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr<f32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) {
397+ // CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
398+ // CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32
399+ // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
400+ // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
401+ // CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index
402+ // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index
403+ // CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64
404+ // CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32
405+ // CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64
406+ // CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64
407+ // CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
408+ // CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64
409+ // CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32
410+ // CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array<i32>} : <tensor<4x4xf32>>
411+ // CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
412+ // CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64
413+ // CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
414+ // CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64
415+ // CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
416+ // CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr<tensor<4x4xf32>>
417+ // CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr<tensor<4x4xf32>>
418+ // CHECK: tt.return
419+ module {
420+ tt.func public @wrap_stacked_masked_loop (%arg0: !tt.ptr <f32 >, %arg1: !tt.ptr <f32 >, %arg2: i32 , %arg3: i32 , %arg4: i32 , %arg5: i32 , %arg6: i32 ) {
421+ %c0_i32 = arith.constant 0 : i32
422+ %c1_i32 = arith.constant 1 : i32
423+ %c2_i32 = arith.constant 2 : i32
424+ %c4_i32 = arith.constant 4 : i32
425+ %cst_0 = arith.constant dense <3 > : tensor <1 x4 xi32 >
426+ %cst_1 = arith.constant dense <3 > : tensor <4 xi32 >
427+ %cst_2 = arith.constant dense <2 > : tensor <4 xi32 >
428+ %0 = tt.make_range {end = 4 : i32 , start = 0 : i32 } : tensor <4 xi32 >
429+ %1 = arith.addi %0 , %cst_2 : tensor <4 xi32 >
430+ %2 = tt.splat %arg2 : i32 -> tensor <4 xi32 >
431+ %3 = arith.remui %1 , %2 : tensor <4 xi32 >
432+ %4 = arith.addi %0 , %cst_1 : tensor <4 xi32 >
433+ %5 = tt.expand_dims %3 {axis = 1 : i32 } : tensor <4 xi32 > -> tensor <4 x1 xi32 >
434+ %6 = tt.splat %arg3 : i32 -> tensor <4 x1 xi32 >
435+ %7 = arith.muli %5 , %6 : tensor <4 x1 xi32 >
436+ %8 = tt.expand_dims %4 {axis = 0 : i32 } : tensor <4 xi32 > -> tensor <1 x4 xi32 >
437+ %9 = tt.splat %arg4 : i32 -> tensor <1 x4 xi32 >
438+ %10 = arith.muli %8 , %9 : tensor <1 x4 xi32 >
439+ %11 = tt.broadcast %7 : tensor <4 x1 xi32 > -> tensor <4 x4 xi32 >
440+ %12 = tt.broadcast %10 : tensor <1 x4 xi32 > -> tensor <4 x4 xi32 >
441+ %13 = arith.addi %11 , %12 : tensor <4 x4 xi32 >
442+ %14 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <4 x4 x!tt.ptr <f32 >>
443+ %15 = tt.addptr %14 , %13 : tensor <4 x4 x!tt.ptr <f32 >>, tensor <4 x4 xi32 >
444+ %16 = tt.expand_dims %0 {axis = 1 : i32 } : tensor <4 xi32 > -> tensor <4 x1 xi32 >
445+ %17 = tt.splat %arg5 : i32 -> tensor <4 x1 xi32 >
446+ %18 = arith.muli %17 , %16 : tensor <4 x1 xi32 >
447+ %19 = tt.splat %arg1 : !tt.ptr <f32 > -> tensor <4 x1 x!tt.ptr <f32 >>
448+ %20 = tt.addptr %19 , %18 : tensor <4 x1 x!tt.ptr <f32 >>, tensor <4 x1 xi32 >
449+ %21 = tt.expand_dims %0 {axis = 0 : i32 } : tensor <4 xi32 > -> tensor <1 x4 xi32 >
450+ %22 = tt.splat %arg6 : i32 -> tensor <1 x4 xi32 >
451+ %23 = arith.muli %22 , %21 : tensor <1 x4 xi32 >
452+ %24 = tt.broadcast %20 : tensor <4 x1 x!tt.ptr <f32 >> -> tensor <4 x4 x!tt.ptr <f32 >>
453+ %25 = tt.broadcast %23 : tensor <1 x4 xi32 > -> tensor <4 x4 xi32 >
454+ %26 = tt.addptr %24 , %25 : tensor <4 x4 x!tt.ptr <f32 >>, tensor <4 x4 xi32 >
455+ %27 = arith.cmpi slt , %21 , %cst_0 : tensor <1 x4 xi32 >
456+ %28 = tt.broadcast %27 : tensor <1 x4 xi1 > -> tensor <4 x4 xi1 >
457+ %29 = arith.muli %arg4 , %c4_i32 : i32
458+ %30 = tt.splat %29 : i32 -> tensor <4 x4 xi32 >
459+ %32 = tt.load %15 : tensor <4 x4 x!tt.ptr <f32 >>
460+ tt.store %26 , %32 : tensor <4 x4 x!tt.ptr <f32 >>
461+ tt.return
462+ }
463+ }
464+
465+
322466// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>) {
323467// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
324468// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
@@ -423,7 +567,6 @@ module {
423567}
424568
425569
426-
427570// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
428571// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
429572// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
0 commit comments