From d66e96530192662cb6527d5774d55023bdae0cb6 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Wed, 22 Jan 2025 22:34:11 +0000
Subject: [PATCH 1/4] Add tests to ensure a scalar ptr init arg to scf.for is
 handled correctly

Signed-off-by: Tiotto, Ettore
---
 .../addptr_scalar_for.mlir                    |  56 ++
 .../addptr_scalar_for_2d.mlir                 |  77 ++
 test/Triton/raise-block-pointer.mlir          | 746 ------------------
 .../TritonRaiseBlockPointer.cpp               |  26 +-
 4 files changed, 151 insertions(+), 754 deletions(-)
 create mode 100644 test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for.mlir
 create mode 100644 test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for_2d.mlir
 delete mode 100644 test/Triton/raise-block-pointer.mlir

diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for.mlir
new file mode 100644
index 0000000000..c6beb87cb5
--- /dev/null
+++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for.mlir
@@ -0,0 +1,56 @@
+// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
+
+module {
+  tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %arg2 : i32
+    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32
+    // source = %arg1, offset = %1, size = 1, strides = 0
+    %cf0 = arith.constant 0.000000e+00 : f32
+    %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<1024xf32>
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %c3 = arith.constant 3 : index
+    %_ptr, %sum_out = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr_iter = %2, %sum_iter = %tensor_cf0) -> (!tt.ptr<f32>, tensor<1024xf32>) {
+      %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+      // offset = 0, size = 1024, strides = 1
+      %4 = tt.splat %ptr_iter : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+      // source = %arg1, offset = %1, size = 1024, strides = 0
+      %5 = tt.addptr %4, %3 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+      // source = %arg1, offset = %1, size = 1024, strides = 1
+      %8 = tt.load %5 : tensor<1024x!tt.ptr<f32>>
+      %9 = math.exp %8 : tensor<1024xf32>
+      %sum_next = arith.addf %sum_iter, %9 : tensor<1024xf32>
+      %cast_i = arith.index_cast %i : index to i32
+      %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr<f32>, i32
+      // source = %arg1, offset = %1 + %i, size = 1, strides = 0
+      scf.yield %ptr_next, %sum_next : !tt.ptr<f32>, tensor<1024xf32>
+    }
+    %10 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+    %18 = arith.muli %0, %arg3 : i32
+    %19 = tt.addptr %arg0, %18 : !tt.ptr<f32>, i32
+    %20 = tt.splat %19 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+    %21 = tt.addptr %20, %10 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+    tt.store %21, %sum_out : tensor<1024x!tt.ptr<f32>>
+    tt.return
+  }
+}
+
+// CHECK:   tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
+// CHECK-DAG:  [[CST_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<1024xf32>
+// CHECK-DAG:  [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG:  [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG:  [[VAR_0_:%.+]] = tt.get_program_id x : i32
+// CHECK:      [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32
+// CHECK:      [[VAR_2_:%.+]] = tt.addptr [[PARAM_1_]], [[VAR_1_]] : !tt.ptr<f32>, i32
+// CHECK-DAG:  [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[CST_0_]]) -> (!tt.ptr<f32>, tensor<1024xf32>) {
+// CHECK-NOT:    tt.make_tensor_ptr
+// CHECK-NOT:    tt.advance
+// CHECK-DAG:    [[VAR_13_:%.+]] = tt.addptr [[VAR_arg6_]], {{.*}} : !tt.ptr<f32>, i32
+// CHECK:        scf.yield [[VAR_13_]], {{.*}} : !tt.ptr<f32>, tensor<1024xf32>
+// CHECK:      }
+// CHECK:      [[VAR_4_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
+// CHECK:      [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_4_]]] {{.*}} : <tensor<1024xf32>>
+// CHECK:      tt.store [[VAR_5_]], [[VAR_3_]]#1 : !tt.ptr<tensor<1024xf32>>
+// CHECK:      tt.return
+// CHECK:   }

diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for_2d.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for_2d.mlir
new file mode 100644
index 0000000000..701045c961
--- /dev/null
+++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_scalar_for_2d.mlir
@@ -0,0 +1,77 @@
+// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
+
+module {
+  tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %arg2 : i32
+    %2 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32
+    %cf0 = arith.constant 0.000000e+00 : f32
+    %tensor_cf0 = tt.splat %cf0 : f32 -> tensor<128x128xf32>
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %c3 = arith.constant 3 : index
+    %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %tensor_cf0, %ptr_iter = %2) -> (tensor<128x128xf32>, !tt.ptr<f32>) {
+      %3 = tt.splat %ptr_iter : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>>
+      // source = %arg1, offset = [%1, 0], size = [128, 128], strides = [0, 0]
+      %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+      %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
+      %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32>
+      // offset = [0, 0], size = [128, 128], strides = [0, 1]
+      %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32>
+      %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
+      %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32>
+      // offset = [128, 0], size = [128, 128], strides = [1, 0]
+      %10 = arith.addi %6, %9 : tensor<128x128xi32>
+      // offset = [128, 0], size = [128, 128], strides = [1, 1]
+      %11 = tt.addptr %3, %10 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
+      // source = %arg1, offset = [%1 + 128, 0], size = [128, 128], strides = [1, 1]
+      %12 = tt.load %11 : tensor<128x128x!tt.ptr<f32>>
+      %17 = math.exp %12 : tensor<128x128xf32>
+      %sum_next = arith.addf %sum_iter, %17 : tensor<128x128xf32>
+      %cast_i = arith.index_cast %i : index to i32
+      %ptr_next = tt.addptr %ptr_iter, %cast_i : !tt.ptr<f32>, i32
+      // source = %arg1, offset = %1 + %i, size = 1, strides = 0
+      scf.yield %sum_next, %ptr_next : tensor<128x128xf32>, !tt.ptr<f32>
+    }
+    %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32>
+    %6 = tt.broadcast %5 : tensor<1x128xi32> -> tensor<128x128xi32>
+    // offset = [0, 0], size = [128, 128], strides = [0, 1]
+    %7 = tt.make_range {end = 256 : i32, start = 128 : i32} : tensor<128xi32>
+    %8 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi32> -> tensor<128x1xi32>
+    %9 = tt.broadcast %8 : tensor<128x1xi32> -> tensor<128x128xi32>
+    // offset = [128, 0], size = [128, 128], strides = [1, 0]
+    %10 = arith.addi %6, %9 : tensor<128x128xi32>
+    // offset = [128, 0], size = [128, 128], strides = [1, 1]
+    %18 = arith.muli %0, %arg3 : i32
+    %19 = tt.addptr %arg0, %18 : !tt.ptr<f32>, i32
+    // source = arg0, offset = %18, size = 1, strides = 0
+    %20 = tt.splat %19 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>>
+    // source = arg0, offset = [%18, 0], size = [128, 128], strides = [0, 0]
+    %21 = tt.addptr %20, %10 : tensor<128x128x!tt.ptr<f32>>, tensor<128x128xi32>
+    // source = %arg0, offset = [%18 + 128, 0], size = [128, 128], strides = [1, 1]
+    tt.store %21, %sum_out : tensor<128x128x!tt.ptr<f32>>
+    tt.return
+  }
+}
+// CHECK:   tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
+// CHECK-DAG:  [[CST_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK-DAG:  [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG:  [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG:  [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG:  [[CST_128_i32:%.+]] = arith.constant 128 : i32
+// CHECK-DAG:  [[VAR_0_:%.+]] = tt.get_program_id x : i32
+// CHECK:      [[VAR_1_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_2_]] : i32
+// CHECK:      [[VAR_2_:%.+]] = tt.addptr [[PARAM_1_]], [[VAR_1_]] : !tt.ptr<f32>, i32
+// CHECK-DAG:  [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[CST_0_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<128x128xf32>, !tt.ptr<f32>) {
+// CHECK-NOT:    tt.make_tensor_ptr
+// CHECK-NOT:    tt.advance
+// CHECK-DAG:    [[VAR_20_:%.+]] = tt.addptr [[VAR_arg7_]], {{.*}} : !tt.ptr<f32>, i32
+// CHECK:        scf.yield {{.*}}, [[VAR_20_]] : tensor<128x128xf32>, !tt.ptr<f32>
+// CHECK:      }
+// CHECK:      [[VAR_4_:%.+]] = arith.muli [[VAR_0_]], [[PARAM_3_]] : i32
+// CHECK:      [[VAR_5_:%.+]] = arith.addi [[VAR_4_]], [[CST_128_i32]] : i32
+// CHECK:      [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_1_i64]]], {{\[}}[[VAR_5_]], [[CST_0_i32]]] {{.*}} : <tensor<128x128xf32>>
+// CHECK:      tt.store [[VAR_6_]], [[VAR_3_]]#0 : !tt.ptr<tensor<128x128xf32>>
+// CHECK:      tt.return
+// CHECK:   }
diff --git a/test/Triton/raise-block-pointer.mlir b/test/Triton/raise-block-pointer.mlir
deleted file mode 100644
index 579aa09db1..0000000000
--- a/test/Triton/raise-block-pointer.mlir
+++ /dev/null
@@ -1,746 +0,0 @@
-// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
-
-// CHECK-LABEL:   tt.func @test_addptr_splat_make_range(
-// CHECK-SAME:    %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<128xf32> {
-// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 0 : i64
-// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 128 : i32
-// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : i64
-// CHECK:         %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_2]]] {order = array<i32: 0>} : <tensor<128xf32>>
-// CHECK:         %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<128xf32>>
-// CHECK:         tt.return %[[VAL_5]] : tensor<128xf32>
-tt.func @test_addptr_splat_make_range(%arg0 : !tt.ptr<f32>) -> tensor<128xf32> {
-  %0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
-  %1 = tt.make_range {start = 128 : i32, end = 256 : i32} : tensor<128xi32>
-  %2 = tt.addptr %0, %1 : tensor<128x!tt.ptr<f32>>, tensor<128xi32>
-  %3 = tt.load %2 : tensor<128x!tt.ptr<f32>>
-  tt.return %3 : tensor<128xf32>
-}
-
-// CHECK-LABEL:   tt.func @test_addptr_load_with_mask(
-// CHECK-SAME:    %[[VAL_0:.*]]: !tt.ptr<f32>,
-// CHECK-SAME:    %[[VAL_1:.*]]: i32,
-// CHECK-SAME:    %[[VAL_2:.*]]: tensor<128xi1>) -> tensor<128xf32> {
-// CHECK:         %[[VAL_3:.*]] = tt.addptr
-// CHECK:         %[[VAL_4:.*]] = tt.load %[[VAL_3]], %[[VAL_2]] 
cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : tensor<128x!tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<128xf32> -tt.func @test_addptr_load_with_mask(%arg0 : !tt.ptr, %arg1: i32, %arg2: tensor<128xi1>) -> tensor<128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.splat %arg1 : i32 -> tensor<128xi32> - %2 = tt.addptr %0, %1 : tensor<128x!tt.ptr>, tensor<128xi32> - %3 = tt.load %2, %arg2 cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : tensor<128x!tt.ptr> - tt.return %3 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_splat_splat_i32( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i32) -> tensor<128xf32> { -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_1]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<128xf32> -tt.func @test_addptr_splat_splat_i32(%arg0 : !tt.ptr, %arg1: i32) -> tensor<128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.splat %arg1 : i32 -> tensor<128xi32> - %2 = tt.addptr %0, %1 : tensor<128x!tt.ptr>, tensor<128xi32> - %3 = tt.load %2 cacheModifier = ca evictionPolicy = evict_first {isVolatile = true} : tensor<128x!tt.ptr> - tt.return %3 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_splat_splat_i64( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> tensor<128xf32> { -// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_3]] : index to i32 -// CHECK: %[[VAL_5:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_4]]] {order = array} : > -// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : !tt.ptr> -// CHECK: tt.return %[[VAL_6]] : tensor<128xf32> -tt.func @test_addptr_splat_splat_i64(%arg0 : !tt.ptr, %arg1: i64) -> tensor<128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.splat %arg1 : i64 -> tensor<128xi64> - %2 = tt.addptr %0, %1 : tensor<128x!tt.ptr>, tensor<128xi64> - %3 = tt.load %2 : tensor<128x!tt.ptr> - tt.return %3 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_splat_splat_2d( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> tensor<2x128xf32> { -// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_6]] : index to i32 -// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_7]], %[[VAL_5]]] {order = array} : > -// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr> -// CHECK: tt.return %[[VAL_9]] : tensor<2x128xf32> -tt.func @test_addptr_splat_splat_2d(%arg0 : !tt.ptr, %arg1: i64) -> tensor<2x128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> - %1 = tt.splat %arg1 : i64 -> tensor<2x128xi64> - %2 = tt.addptr %0, %1 : tensor<2x128x!tt.ptr>, tensor<2x128xi64> - %3 = tt.load %2 : tensor<2x128x!tt.ptr> - tt.return %3 : tensor<2x128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_splat_splat_2d_store( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64, -// 
CHECK-SAME: %[[VAL_2:.*]]: tensor<2x128xf32>) { -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32 -// CHECK: %[[VAL_7:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_6]], %[[VAL_4]]] {order = array} : > -// CHECK: tt.store %[[VAL_7]], %[[VAL_2]] : !tt.ptr> -tt.func @test_addptr_splat_splat_2d_store(%arg0 : !tt.ptr, %arg1: i64, %arg2: tensor<2x128xf32>) { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> - %1 = tt.splat %arg1 : i64 -> tensor<2x128xi64> - %2 = tt.addptr %0, %1 : tensor<2x128x!tt.ptr>, tensor<2x128xi64> - tt.store %2, %arg2 : tensor<2x128x!tt.ptr> - tt.return -} - -// CHECK-LABEL: tt.func @test_addptr_splat_make_range_add( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i64 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<128xf32> -tt.func @test_addptr_splat_make_range_add(%arg0 : !tt.ptr) -> tensor<128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> - %2 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> - %3 = arith.addi %1, %2 : tensor<128xi32> - %4 = tt.addptr %0, %3 : tensor<128x!tt.ptr>, tensor<128xi32> - %5 = tt.load %4 : tensor<128x!tt.ptr> - tt.return %5 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_splat_make_range_mul( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i32) -> tensor<128xf32> { -// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64 -// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_5]] : i64 to i32 -// CHECK: %[[VAL_7:.*]] = arith.divui %[[VAL_3]], %[[VAL_6]] : i32 -// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_7]]] {order = array} : > -// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr> -// CHECK: tt.return %[[VAL_9]] : tensor<128xf32> -tt.func @test_addptr_splat_make_range_mul(%arg0 : !tt.ptr, %arg1: i32) -> tensor<128xf32> { - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.splat %arg1 : i32 -> tensor<128xi32> - %2 = tt.make_range {start = 0 : i32, end = 128 : i32} : tensor<128xi32> - %3 = arith.muli %1, %2 : tensor<128xi32> - %4 = tt.addptr %0, %3 : tensor<128x!tt.ptr>, tensor<128xi32> - %5 = tt.load %4 : tensor<128x!tt.ptr> - tt.return %5 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_const_splat_addptr( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 512 : i32 -// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<128xf32> -tt.func 
@test_const_splat_addptr(%arg0 : !tt.ptr) -> tensor<128xf32> { - %cst = arith.constant dense<512> : tensor<128xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> - %1 = tt.addptr %0, %cst : tensor<128x!tt.ptr>, tensor<128xi32> - %2 = tt.load %1 : tensor<128x!tt.ptr> - tt.return %2 : tensor<128xf32> -} - -// CHECK-LABEL: tt.func @test_expand_dims( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<1x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 512 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_3]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<1x128xf32> -tt.func @test_expand_dims(%arg0 : !tt.ptr) -> tensor<1x128xf32> { - %cst = arith.constant dense<512> : tensor<128xi32> - %0 = tt.expand_dims %cst {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<1x128x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<1x128x!tt.ptr>, tensor<1x128xi32> - %3 = tt.load %2 : tensor<1x128x!tt.ptr> - tt.return %3 : tensor<1x128xf32> -} - -// CHECK-LABEL: tt.func @test_const_splat_addptr_2d( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 512 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_3]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> -tt.func @test_const_splat_addptr_2d(%arg0 : !tt.ptr) -> tensor<2x128xf32> { - %cst = arith.constant dense<512> : tensor<2x128xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> - %1 = tt.addptr %0, %cst : tensor<2x128x!tt.ptr>, tensor<2x128xi32> - %2 = tt.load %1 : tensor<2x128x!tt.ptr> - tt.return %2 : tensor<2x128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_broadcast( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> -tt.func @test_addptr_broadcast(%arg0 : !tt.ptr) -> tensor<2x128xf32> { - %cst = arith.constant dense<1> : tensor<1x128xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> - %1 = tt.broadcast %cst : tensor<1x128xi32> -> tensor<2x128xi32> - %2 = tt.addptr %0, %1 : tensor<2x128x!tt.ptr>, tensor<2x128xi32> - %3 = tt.load %2 : tensor<2x128x!tt.ptr> - tt.return %3 : tensor<2x128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_broadcast_rank( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] 
{order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> -tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr) -> tensor<2x128xf32> { - %cst = arith.constant dense<1> : tensor<1x128xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> - %1 = tt.broadcast %cst : tensor<1x128xi32> -> tensor<2x128xi32> - %2 = tt.addptr %0, %1 : tensor<2x128x!tt.ptr>, tensor<2x128xi32> - %3 = tt.load %2 : tensor<2x128x!tt.ptr> - tt.return %3 : tensor<2x128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_broadcast_rank_2( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128x2x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32> -tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr) -> tensor<128x2x128xf32> { - %cst = arith.constant dense<1> : tensor<128x1x128xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x2x128x!tt.ptr> - %1 = tt.broadcast %cst : tensor<128x1x128xi32> -> tensor<128x2x128xi32> - %2 = tt.addptr %0, %1 : tensor<128x2x128x!tt.ptr>, tensor<128x2x128xi32> - %3 = tt.load %2 : tensor<128x2x128x!tt.ptr> - tt.return %3 : tensor<128x2x128xf32> -} - -// CHECK-LABEL: tt.func @test_addptr_broadcast_rank_3( -// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128x2x128xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> -// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32> -tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128xf32> { - %cst = arith.constant dense<1> : tensor<128x1x1xi32> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x2x128x!tt.ptr> - %1 = tt.broadcast %cst : tensor<128x1x1xi32> -> tensor<128x2x128xi32> - %2 = tt.addptr %0, %1 : tensor<128x2x128x!tt.ptr>, tensor<128x2x128xi32> - %3 = tt.load %2 : tensor<128x2x128x!tt.ptr> - tt.return %3 : tensor<128x2x128xf32> -} - - -// CHECK: tt.func public @wrap_side_by_side_masked([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { -// CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32 -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 -// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64 -// CHECK: [[VAR_7_:%.+]] = arith.muli 
[[PARAM_4_]], [[CST_6_i32]] : i32 -// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 -// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64 -// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 -// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_3_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_6_]] : i64 to i32 -// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_7_]], [[VAR_12_]] : i32 -// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {order = array} : > -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64 -// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64 -// CHECK: [[VAR_19_:%.+]] = arith.trunci [[VAR_16_]] : i64 to i32 -// CHECK: [[VAR_20_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_19_]] : i32 -// CHECK: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32 -// CHECK: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32 -// CHECK: [[VAR_23:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_16_]], [[VAR_18_]]], {{\[}}[[VAR_20_]], [[VAR_22_]]] {order = array} : > -// CHECK: [[VAR_24:%.+]] = tt.load [[VAR_14_]] {boundaryCheck = array} : !tt.ptr> -// CHECK: tt.store [[VAR_23]], [[VAR_24]] : !tt.ptr> -// CHECK: tt.return -module { -tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c2_i32 = arith.constant 2 : i32 - %c4_i32 = arith.constant 4 : i32 - %cst_0 = arith.constant dense<2> : tensor<4x1xi32> - %cst_1 = arith.constant dense<6> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = arith.addi %0, %cst_1 : tensor<4xi32> - %3 = tt.splat %arg2 : i32 -> tensor<4xi32> - %4 = arith.remsi %2, %3 : tensor<4xi32> - %5 = tt.expand_dims %1 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %16, 
%cst_0 : tensor<4x1xi32> - %28 = tt.broadcast %27 : tensor<4x1xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg3, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %31 = arith.muli %arg4, %c4_i32 : i32 - %32 = tt.splat %31 : i32 -> tensor<4x4xi32> - %34 = tt.load %15 : tensor<4x4x!tt.ptr> - tt.store %26, %34 : tensor<4x4x!tt.ptr> - tt.return - } -} - - -// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { -// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64 -// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > -// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr> -// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > -// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr>) { -// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> -// CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16> -// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : > -// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr> -// CHECK: } -// CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > -// CHECK: tt.store [[VAR_5_]], [[VAR_4_]]#0 : !tt.ptr> -// CHECK: tt.return -// CHECK: } -module { - tt.func @test_addptr_for_accumulation( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr, - %arg2 : !tt.ptr, - %arg3 : i32, - %arg4 : i32 - ) - { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32> - // offset = 0, size = 4, stride = 1 - %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - // offset = [0,0], size = [4,1], stride = [1,0] - %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [1,0] - %arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32> - %offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32> - // offset = [%arg3,0], size = [4,256], stride = [1,0] - %3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32> - // offset = 0, size = 256, stride = 1 - %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32> - // offset = [0,0], size = [1,256], stride = [0,1] - %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32> - // offset = [0,0], size = [4,256], stride = [0,1] - %c5 = arith.constant 5 : i32 - %splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32> - // scalar = 5 - %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // Why we never called the conversion function for the inputs here? 
- // offset = [0,0], size = [4,256], stride = [0,5] - %7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // Why we never called the conversion function for the inputs here? - // offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %8 = tt.splat %arg0 : !tt.ptr -> tensor<4x256x!tt.ptr> // Why is the input unknown - %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %19 = tt.load %9 : tensor<4x256x!tt.ptr> // this will be replaced with a memref.copy - %11 = tt.splat %arg1 : !tt.ptr -> tensor<4x256x!tt.ptr> - %12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr>) { - %20 = tt.load %ptr_iter : tensor<4x256x!tt.ptr> - %sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16> - // pointer updates - %17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32> - // offset: [3, 0], size = [4, 256], stride [0, 0] - %ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5] - scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr> - } - %15 = tt.splat %arg2 : !tt.ptr -> tensor<4x256x!tt.ptr> - %16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr>, tensor<4x256xi32> - // source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5] - tt.store %16, %sum_out : tensor<4x256x!tt.ptr> - tt.return - } -} - - -// CHECK: tt.func public @wrap_stacked_masked_loop([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32) { -// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 -// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 -// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64 -// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64 -// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32 -// CHECK: [[VAR_9_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 -// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_3_]], [[VAR_9_]] : i32 -// CHECK: [[VAR_11_:%.+]] = arith.trunci [[VAR_7_]] : i64 to i32 -// CHECK: [[VAR_12_:%.+]] = arith.divui [[VAR_8_]], [[VAR_11_]] : i32 -// CHECK: [[VAR_13:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_10_]], [[VAR_12_]]] {order = array} : > -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : index to i64 -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_17_:%.+]] 
= arith.index_cast [[VAR_16_]] : index to i64 -// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_15_]] : i64 to i32 -// CHECK: [[VAR_19_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_18_]] : i32 -// CHECK: [[VAR_20_:%.+]] = arith.trunci [[VAR_17_]] : i64 to i32 -// CHECK: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32 -// CHECK: [[VAR_22:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_15_]], [[VAR_17_]]], {{\[}}[[VAR_19_]], [[VAR_21_]]] {order = array} : > -// CHECK: [[VAR_23:%.+]] = tt.load [[VAR_13]] {boundaryCheck = array} : !tt.ptr> -// CHECK: tt.store [[VAR_22]], [[VAR_23]] : !tt.ptr> -// CHECK: tt.return -module { - tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { - %c0_i32 = arith.constant 0 : i32 - %c1_i32 = arith.constant 1 : i32 - %c2_i32 = arith.constant 2 : i32 - %c4_i32 = arith.constant 4 : i32 - %cst_0 = arith.constant dense<3> : tensor<1x4xi32> - %cst_1 = arith.constant dense<3> : tensor<4xi32> - %cst_2 = arith.constant dense<2> : tensor<4xi32> - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %1 = arith.addi %0, %cst_2 : tensor<4xi32> - %2 = tt.splat %arg2 : i32 -> tensor<4xi32> - %3 = arith.remui %1, %2 : tensor<4xi32> - %4 = arith.addi %0, %cst_1 : tensor<4xi32> - %5 = tt.expand_dims %3 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %6 = tt.splat %arg3 : i32 -> tensor<4x1xi32> - %7 = arith.muli %5, %6 : tensor<4x1xi32> - %8 = tt.expand_dims %4 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %9 = tt.splat %arg4 : i32 -> tensor<1x4xi32> - %10 = arith.muli %8, %9 : tensor<1x4xi32> - %11 = tt.broadcast %7 : tensor<4x1xi32> -> tensor<4x4xi32> - %12 = tt.broadcast %10 : tensor<1x4xi32> -> tensor<4x4xi32> - %13 = arith.addi %11, %12 : tensor<4x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<4x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %16 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> - %17 = tt.splat %arg5 : i32 -> tensor<4x1xi32> - %18 = arith.muli %17, %16 : tensor<4x1xi32> - %19 = tt.splat %arg1 : !tt.ptr -> tensor<4x1x!tt.ptr> - %20 = tt.addptr %19, %18 : tensor<4x1x!tt.ptr>, tensor<4x1xi32> - %21 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %22 = tt.splat %arg6 : i32 -> tensor<1x4xi32> - %23 = arith.muli %22, %21 : tensor<1x4xi32> - %24 = tt.broadcast %20 : tensor<4x1x!tt.ptr> -> tensor<4x4x!tt.ptr> - %25 = tt.broadcast %23 : tensor<1x4xi32> -> tensor<4x4xi32> - %26 = tt.addptr %24, %25 : tensor<4x4x!tt.ptr>, tensor<4x4xi32> - %27 = arith.cmpi slt, %21, %cst_0 : tensor<1x4xi32> - %28 = tt.broadcast %27 : tensor<1x4xi1> -> tensor<4x4xi1> - %29 = arith.muli %arg4, %c4_i32 : i32 - %30 = tt.splat %29 : i32 -> tensor<4x4xi32> - %32 = tt.load %15 : tensor<4x4x!tt.ptr> - tt.store %26, %32 : tensor<4x4x!tt.ptr> - tt.return - } -} - - -// CHECK: tt.func @test_addptr_for_more_init_args([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) { -// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK-DAG: [[CST_1024:%.+]] = arith.constant 1024 : i32 -// CHECK-DAG: 
[[CST_1_i64:%.+]] = arith.constant 1 : i64 -// CHECK-DAG: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024]]] {order = array} : > -// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024]]] {order = array} : > -// CHECK: [[VAR_0_:%.+]]:5 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg3_:%.+]] = [[CST_1_]], [[VAR_arg4_:%.+]] = [[VAR_1_]], [[VAR_arg5_:%.+]] = [[CST_2_]], [[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[CST_3_]]) -> (index, !tt.ptr>, index, !tt.ptr>, index) { -// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_arg4_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg6_]], [[VAR_3_]] : !tt.ptr> -// CHECK: [[VAR_4_:%.+]] = tt.advance [[VAR_arg4_]], {{\[}}[[CST_3_i32]]{{\]}} : > -// CHECK: [[VAR_5_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_3_]] : index -// CHECK: [[VAR_6_:%.+]] = arith.addi [[VAR_arg5_]], [[CST_3_]] : index -// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_arg7_]], [[CST_3_]] : index -// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_5_]], [[VAR_6_]] : index -// CHECK: [[VAR_9_:%.+]] = arith.addi [[VAR_8_]], [[VAR_7_]] : index -// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[VAR_9_]] : index to i32 -// CHECK: [[VAR_11_:%.+]] = tt.advance [[VAR_arg6_]], {{\[}}[[VAR_10_]]{{\]}} : > -// CHECK: scf.yield [[VAR_5_]], [[VAR_4_]], [[VAR_6_]], [[VAR_11_]], [[VAR_7_]] : index, !tt.ptr>, index, !tt.ptr>, index -// CHECK: } -// CHECK: tt.return -// CHECK: } -module { - tt.func @test_addptr_for_more_init_args( - %arg0 : !tt.ptr, - %arg1 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c12 = arith.constant 12 : index - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - %3 = tt.splat %arg1 : !tt.ptr -> tensor<256x!tt.ptr> - %4 = tt.addptr %3, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - %_arg2, %_ptr_ld, %_arg3, %_ptr_st, %_arg4 = scf.for %i = %c0 to %c12 step %c3 iter_args(%arg2 = %c1, %ptr_ld = %2, %arg3 = %c2, %ptr_st = %4, %arg4 = %c3) -> (index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index) { - %5 = tt.load %ptr_ld {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_st, %5 : tensor<256x!tt.ptr> - %cast3 = arith.index_cast %c3 : index to i32 - %6 = tt.splat %cast3 : i32 -> tensor<256xi32> - %ptr_ld_iter = tt.addptr %ptr_ld, %6 : tensor<256x!tt.ptr>, tensor<256xi32> - %arg2_iter = arith.addi %arg2, %c3 : index - %arg3_iter = arith.addi %arg3, %c3 : index - %arg4_iter = arith.addi %arg4, %c3 : index - %7 = arith.addi %arg2_iter, %arg3_iter : index - %8 = arith.addi %7, %arg4_iter : index - %cast8 = arith.index_cast %8 : index to i32 - %9 = tt.splat %cast8 : i32 -> tensor<256xi32> - %ptr_st_iter = tt.addptr %ptr_st, %9 : tensor<256x!tt.ptr>, tensor<256xi32> - scf.yield %arg2_iter, %ptr_ld_iter, %arg3_iter, %ptr_st_iter, %arg4_iter : index, tensor<256x!tt.ptr>, index, tensor<256x!tt.ptr>, index - } - tt.return - } -} - - -// CHECK: tt.func @test_addptr_for_used_after_update([[PARAM_0_:%.+]]: !tt.ptr) { -// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// 
CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32 -// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 -// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024_i32]]] {order = array} : > -// CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr>) { -// CHECK: [[VAR_2_:%.+]] = tt.advance [[VAR_arg2_]], {{\[}}[[CST_3_i32]]{{\]}} : > -// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr> -// CHECK: tt.store [[VAR_2_]], [[VAR_3_]] : !tt.ptr> -// CHECK: scf.yield [[VAR_2_]] : !tt.ptr> -// CHECK: } -// CHECK: tt.return -// CHECK: } -module { - tt.func @test_addptr_for_used_after_update( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - %3 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr_iter, %3 : tensor<256x!tt.ptr> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - tt.return - } -} - - -// CHECK: tt.func @test_addptr_for_used_before_update([[PARAM_0_:%.+]]: !tt.ptr) { -// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK-DAG: [[CST_1024_i32:%.+]] = arith.constant 1024 : i32 -// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 -// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[CST_1024_i32]]] {order = array} : > -// CHECK: [[VAR_1_:%.+]] = scf.for [[VAR_arg1_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg2_:%.+]] = [[VAR_0_]]) -> (!tt.ptr>) { -// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_arg2_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg2_]], [[VAR_2_]] : !tt.ptr> -// CHECK: [[VAR_3_:%.+]] = tt.advance [[VAR_arg2_]], {{\[}}[[CST_3_i32]]{{\]}} : > -// CHECK: scf.yield [[VAR_3_]] : !tt.ptr> -// CHECK: } -// CHECK: tt.return -// CHECK: } -module { - tt.func @test_addptr_for_used_before_update( - %arg0 : !tt.ptr - ) - { - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c3 = arith.constant 3 : index - %i_c3 = arith.constant 3 : i32 - %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr> - %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32> - %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr>, tensor<256xi32> - %_ptr2 = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr>) { - %3 = tt.load %ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x!tt.ptr> - tt.store %ptr, %3 : tensor<256x!tt.ptr> - %4 = tt.splat %i_c3 : i32 -> tensor<256xi32> - %ptr_iter = tt.addptr %ptr, %4 : tensor<256x!tt.ptr>, tensor<256xi32> - scf.yield %ptr_iter : tensor<256x!tt.ptr> - } - tt.return - } -} - -// CHECK: tt.func @matmul_kernel -// CHECK: tt.make_tensor_ptr %arg0 -// CHECK: 
tt.make_tensor_ptr %arg1 -// CHECK: scf.for -// CHECK: [[LOAD1:%.*]] = tt.load [[ARG10:%.*]], {{.*}}, {{.*}} : !tt.ptr> -// CHECK: [[LOAD2:%.*]] = tt.load [[ARG11:%.*]], {{.*}}, {{.*}} : !tt.ptr> -// CHECK: [[DOT:%.*]] = tt.dot [[LOAD1]], [[LOAD2]] -// CHECK: [[ADV1:%.*]] = tt.advance [[ARG10]], {{.*}} : > -// CHECK: [[ADV2:%.*]] = tt.advance [[ARG11]], {{.*}} : > -// CHECK: scf.yield [[DOT]], [[ADV1]], [[ADV2]] -module { - tt.func @matmul_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr , %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<64x128xf16> { - %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x128xf16> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf16> - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %cst_2 = arith.constant dense<32> : tensor<64x32xi32> - %c32_i32 = arith.constant 32 : i32 - %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %18 = tt.splat %arg3 : i32 -> tensor<64xi32> - %21 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - %22 = tt.splat %arg3 : i32 -> tensor<128xi32> - %23 = arith.addi %22, %21 : tensor<128xi32> - %24 = tt.splat %arg4 : i32 -> tensor<128xi32> - %25 = arith.remsi %23, %24 : tensor<128xi32> - %26 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %27 = tt.expand_dims %18 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32> - %28 = tt.splat %arg6 : i32 -> tensor<64x1xi32> - %30 = tt.expand_dims %26 {axis = 0 : i32} : tensor<32xi32> -> tensor<1x32xi32> - %31 = tt.broadcast %28 : tensor<64x1xi32> -> tensor<64x32xi32> - %32 = tt.broadcast %30 : tensor<1x32xi32> -> tensor<64x32xi32> - %33 = arith.addi %31, %32 : tensor<64x32xi32> - %34 = tt.splat %arg0 : !tt.ptr -> tensor<64x32x!tt.ptr> - %35 = tt.addptr %34, %33 : tensor<64x32x!tt.ptr>, tensor<64x32xi32> - %36 = tt.expand_dims %26 {axis = 1 : i32} : tensor<32xi32> -> tensor<32x1xi32> - %37 = tt.splat %arg7 : i32 -> tensor<32x1xi32> - %38 = arith.muli %36, %37 : tensor<32x1xi32> - %39 = tt.expand_dims %25 {axis = 0 : i32} : tensor<128xi32> -> tensor<1x128xi32> - %40 = tt.broadcast %38 : tensor<32x1xi32> -> tensor<32x128xi32> - %41 = tt.broadcast %39 : tensor<1x128xi32> -> tensor<32x128xi32> - %42 = arith.addi %40, %41 : tensor<32x128xi32> - %43 = tt.splat %arg1 : !tt.ptr -> tensor<32x128x!tt.ptr> - %44 = tt.addptr %43, %42 : tensor<32x128x!tt.ptr>, tensor<32x128xi32> - %47 = arith.muli %arg7, %c32_i32 : i32 - %48 = tt.splat %47 : i32 -> tensor<32x128xi32> - %49:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %35, %arg12 = %44) -> (tensor<64x128xf32>, tensor<64x32x!tt.ptr>, tensor<32x128x!tt.ptr>) : i32 { - %67 = arith.muli %arg9, %c32_i32 : i32 - %68 = arith.subi %arg5, %67 : i32 - %69 = tt.splat %68 : i32 -> tensor<1x32xi32> - %70 = arith.cmpi slt, %30, %69 : tensor<1x32xi32> - %71 = tt.broadcast %70 : tensor<1x32xi1> -> tensor<64x32xi1> - %72 = tt.load %arg11, %71, %cst_1 : tensor<64x32x!tt.ptr> - %73 = tt.splat %68 : i32 -> tensor<32x1xi32> - %74 = arith.cmpi slt, %36, %73 : tensor<32x1xi32> - %75 = tt.broadcast %74 : tensor<32x1xi1> -> tensor<32x128xi1> - %76 = tt.load %arg12, %75, %cst_0 : tensor<32x128x!tt.ptr> - %77 = tt.dot %72, %76, %arg10, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x128xf16> -> tensor<64x128xf32> - %78 = tt.addptr %arg11, %cst_2 : tensor<64x32x!tt.ptr>, tensor<64x32xi32> - %79 = tt.addptr %arg12, %48 : tensor<32x128x!tt.ptr>, 
tensor<32x128xi32>
-      scf.yield %77, %78, %79 : tensor<64x128xf32>, tensor<64x32x!tt.ptr<f16>>, tensor<32x128x!tt.ptr<f16>>
-    }
-    %50 = arith.truncf %49#0 : tensor<64x128xf32> to tensor<64x128xf16>
-    tt.return %50 : tensor<64x128xf16>
-  }
-}
-
-// `triton::ExpandDims` ops on tensor of pointers are currently not supported in for loops.
-// Consequently, the pass should fail cleanly.
-// CHECK:   tt.func @test_fail_addptr_for_expand_ptr([[PARAM_0_:%.+]]: !tt.ptr<bf16>) {
-// CHECK-NOT:  tt.make_tensor_ptr
-module {
-  tt.func @test_fail_addptr_for_expand_ptr(
-    %arg0 : !tt.ptr<bf16>
-  )
-  {
-    %c0 = arith.constant 0 : index
-    %c12 = arith.constant 12 : index
-    %c3 = arith.constant 3 : index
-    %i_c3 = arith.constant 3 : i32
-    %0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>>
-    %1 = tt.make_range {end = 1280 : i32, start = 1024 : i32}:tensor<256xi32>
-    %2 = tt.addptr %0, %1 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-    %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%ptr = %2) -> (tensor<256x!tt.ptr<bf16>>) {
-      %6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
-      %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
-      %8 = tt.broadcast %7 : tensor<256x1xi32> -> tensor<256x256xi32>
-      %9 = tt.make_range {end = 512 : i32, start = 256 : i32} : tensor<256xi32>
-      %10 = tt.expand_dims %9 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
-      %11 = tt.broadcast %10 : tensor<1x256xi32> -> tensor<256x256xi32>
-      %12 = arith.addi %8, %11 : tensor<256x256xi32>
-      %13 = tt.expand_dims %ptr {axis = 1 : i32} : tensor<256x!tt.ptr<bf16>> -> tensor<256x1x!tt.ptr<bf16>>
-      %14 = tt.broadcast %13 : tensor<256x1x!tt.ptr<bf16>> -> tensor<256x256x!tt.ptr<bf16>>
-      %15 = tt.addptr %14, %12 : tensor<256x256x!tt.ptr<bf16>>, tensor<256x256xi32>
-      %16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<256x256x!tt.ptr<bf16>>
-      tt.store %15, %16 : tensor<256x256x!tt.ptr<bf16>>
-      %17 = tt.splat %i_c3 : i32 -> tensor<256xi32>
-      %ptr_iter = tt.addptr %ptr, %17 : tensor<256x!tt.ptr<bf16>>, tensor<256xi32>
-      scf.yield %ptr_iter : tensor<256x!tt.ptr<bf16>>
-    }
-    tt.return
-  }
-}
diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
index 63aba540b0..e78cecbdf7 100644
--- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
+++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
@@ -379,6 +379,11 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                      const PtrState &state) {
+  if (state.source)
+    os << " ";
+  if (state.scalar)
+    os << " ";
+
   return os << " ";
@@ -732,22 +737,23 @@ struct TritonRaiseBlockPointer
       }
     }
 
+    // If the addptr operation increments a scalar pointer, give up.
+    Value result = op.getResult();
+    if (!isa<RankedTensorType>(result.getType()))
+      return failure();
+
+    // Otherwise, rewrite the AddPtrOp using PtrState.
     PtrState state;
     if (failed(visitOperandAddptr(op, state, loc, builder)))
       return failure();
 
-    Value result = op.getResult();
     knownPtrs[result] = state;
 
-    Value mapped = result;
-    if (isa<RankedTensorType>(result.getType())) {
-      Value makePtrOp = state.createTTMakeTensorPtrOp(builder, loc);
-      knownPtrs[makePtrOp] = std::move(state);
-      mapped = makePtrOp;
-    }
+    assert(isa<RankedTensorType>(result.getType()));
+    Value makePtrOp = state.createTTMakeTensorPtrOp(builder, loc);
+    knownPtrs[makePtrOp] = std::move(state);
 
-    ptrMap.map(result, mapped);
+    ptrMap.map(result, makePtrOp);
 
     // AddPtrOps that have been rewritten and no longer used in the code must
     // be removed in the pass to avoid type matching issue.
@@ -889,6 +895,10 @@ struct TritonRaiseBlockPointer
           "Unexpected operand defining operation tt.make_tensor_ptr");
       llvm_unreachable("Unexpected operand defining operation");
     } else {
+      // If the operand is an iter-arg of a for loop, give up.
+      if (isa<scf::ForOp>(operand.getParentBlock()->getParentOp()))
+        return failure();
+
       state.source = operand;
       return success();
     }

From 1277edffd347659545cefbe943deb1e838c51e04 Mon Sep 17 00:00:00 2001
From: "Tiotto, Ettore"
Date: Thu, 23 Jan 2025 20:31:19 +0000
Subject: [PATCH 2/4] Fix issue #3245

Signed-off-by: Tiotto, Ettore
---
 .../RaiseToBlockPointers/addptr_dim1.mlir     | 100 ++++++++++++++++++
 .../TritonRaiseBlockPointer.cpp               |  38 +++++--
 2 files changed, 127 insertions(+), 11 deletions(-)
 create mode 100644 test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir

diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir
new file mode 100644
index 0000000000..43181d0d9b
--- /dev/null
+++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir
@@ -0,0 +1,100 @@
+// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s
+
+module {
+  tt.func @kernel(
+    %arg0 : !tt.ptr<bf16>,
+    %arg1 : i32
+  )
+  {
+    %0 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32>
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
+
+    %splat_arg0 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x256x!tt.ptr<bf16>>
+    %2 = tt.addptr %splat_arg0, %1 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+    // source = %arg0, offset = [0, 0], size = [1, 256], stride = [0, 1]
+
+    // 1x256 pointer should have meaningful stride in outer dimension
+    %3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<1x256x!tt.ptr<bf16>>
+
+    %4 = tt.splat %arg1 : i32 -> tensor<1x256xi32>
+    // 1x256 pointer should have meaningful stride in outer dimension
+    %5 = tt.addptr %2, %4 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+    // source = %arg0, offset = [%arg1, 0], size = [1, 256], stride = [0, 1]
+
+    tt.store %5, %3 : tensor<1x256x!tt.ptr<bf16>>
+
+    %10 = arith.constant 0.0 : bf16
+    %11 = tt.splat %10 : bf16 -> tensor<4x256xbf16>
+
+    %c0 = arith.constant 0 : index
+    %c12 = arith.constant 12 : index
+    %c3 = arith.constant 3 : index
+    %i_c3 = arith.constant 3 : i32
+    %c256 = arith.constant 256 : i32
+    %sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %11, %ptr = %2) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
+      %bptr = tt.broadcast %ptr : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+      // source = %arg0, offset = [0, 0], size = [4, 256], stride = [0, 1]
+
+      %20 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
+      %i_i32 = arith.index_cast %i : index to i32
+      %21 = arith.muli %c256, %i_i32 : i32
+      %22 = tt.splat %21 : i32 -> tensor<4xi32>
+      %23 = arith.muli %20, %22 : tensor<4xi32>
+      %24 = tt.expand_dims %23 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+      %25 = tt.broadcast %24 : tensor<4x1xi32> -> tensor<4x256xi32>
+      // offset = [0, 0], size = [4, 256], stride = [i*256, 1]
+
+      // %bptr should have zero stride and %30 should have correct stride
+      %30 = tt.addptr %bptr, %25 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
+      // source = %arg0, offset = [0, 0], size = [4, 256], stride = [i*256, 1]
+
+      %31 = tt.load %30 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr<bf16>>
+      %32 = arith.addf %sum_iter, %31 : tensor<4x256xbf16>
+
+      %40 = tt.splat %c256 : i32 -> tensor<1x256xi32>
+      %41 = tt.addptr %ptr, %40 : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+      // source = %arg0, offset = [i*256, 0], size = [4, 256], stride = [i*256, 1]
+
+      scf.yield %32, %41 : tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>
+    }
+
+    %31 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
+    %splat_c256 = tt.splat %c256 : i32 -> tensor<4xi32>
+    %32 = arith.muli %31, %splat_c256 : tensor<4xi32>
+    %33 = tt.expand_dims %32 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
+    %34 = tt.broadcast %33 : tensor<4x1xi32> -> tensor<4x256xi32>
+    %35 = tt.broadcast %2 : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+    %36 = tt.addptr %35, %34 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
+    tt.store %36, %sum_out : tensor<4x256x!tt.ptr<bf16>>
+    tt.return
+  }
+}
+
+// CHECK:   tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: i32) {
+// CHECK-DAG:  [[CST_:%.+]] = arith.constant dense<256> : tensor<1x256xi32>
+// CHECK-DAG:  [[CST_0_:%.+]] = arith.constant dense<0.000000e+00> : tensor<4x256xbf16>
+// CHECK-DAG:  [[CST_0_i32:%.+]] = arith.constant 0 : i32
+// CHECK-DAG:  [[CST_0_i64:%.+]] = arith.constant 0 : i64
+// CHECK-DAG:  [[CST_1_i64:%.+]] = arith.constant 1 : i64
+// CHECK-DAG:  [[CST_256_i64:%.+]] = arith.constant 256 : i64
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG:  [[VAR_0_:%.+]] = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
+// CHECK-DAG:  [[VAR_1_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
+// CHECK-DAG:  [[VAR_2_:%.+]] = tt.splat [[PARAM_0_]] : !tt.ptr<bf16> -> tensor<1x256x!tt.ptr<bf16>>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG:  [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<1x256xbf16>>
+// CHECK-DAG:  [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+// CHECK-DAG:  [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr<tensor<1x256xbf16>>
+// CHECK-DAG:  [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : <tensor<1x256xbf16>>
+// CHECK:      tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr<tensor<1x256xbf16>>
+// CHECK:      [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>) {
+// CHECK:        [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr<bf16>> -> tensor<4x256x!tt.ptr<bf16>>
+// CHECK-NOT:    tt.make_tensor_ptr
+// CHECK-NOT:    tt.advance
+// CHECK:        [[VAR_20_:%.+]] = tt.addptr [[VAR_arg4_]], [[CST_]] : tensor<1x256x!tt.ptr<bf16>>, tensor<1x256xi32>
+// CHECK:        scf.yield {{.*}}, [[VAR_20_]] : tensor<4x256xbf16>, tensor<1x256x!tt.ptr<bf16>>
+// CHECK:      }
+// CHECK:      [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_256_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
+// CHECK:      tt.store [[VAR_8_]], [[VAR_7_]]#0 : !tt.ptr<tensor<4x256xbf16>>
+// CHECK:      tt.return
+// 
CHECK: } diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index e78cecbdf7..0b4c4ff5c5 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Verifier.h" +#include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -435,7 +436,7 @@ struct TritonRaiseBlockPointer if (failed(rewriteForOp(forOp))) { forOp->emitRemark( "TritonRaiseToBlockPointer: Failed to rewrite ForOp"); - return WalkResult::interrupt(); + return WalkResult::advance(); } return WalkResult::skip(); }) @@ -452,17 +453,24 @@ struct TritonRaiseBlockPointer SmallVector> initArgIndex; OpBuilder builder(op); + auto canBeRewrittenUsingBlockPtr = [&](Operation *op) { + return TypeSwitch(op) + .Case( + [](auto) { return true; }) + .Default([](auto) { return false; }); + }; + // Create a new list of init args for (auto [i, arg] : llvm::enumerate(op.getInitArgs())) { if (Value mappedV = ptrMap.lookupOrNull(arg)) { if (auto makeTensorPtrOp = mappedV.getDefiningOp()) { if (llvm::any_of(op.getRegionIterArgs()[i].getUsers(), - [](Operation *user) { - return isa(user); + [&](Operation *user) { + return !canBeRewrittenUsingBlockPtr(user); })) { - op->emitRemark("TritonRaiseToBlockPointer: ExpandDims Ops in loops " - "are currently not supported"); + op->emitRemark("TritonRaiseToBlockPointer: Loop contains ops that " + "cannot be rewritten using a block ptr"); return failure(); } @@ -668,7 +676,7 @@ struct TritonRaiseBlockPointer OpBuilder builder(op); Location loc = op.getLoc(); - auto ptr = op.getPtr(); + Value ptr = op.getPtr(); auto fillOffsets = [&](Value offset, unsigned rank, SmallVector &offsets) { @@ -726,11 +734,16 @@ struct TritonRaiseBlockPointer assert(!offsets.empty() && offsets.size() == rank && "unexpected number of offsets"); - auto advanceOp = builder.createOrFold(loc, ptr.getType(), - ptr, offsets); - cleanUp.push_back(op); + + Value basePtr = tt::isTensorPointerType(ptr.getType()) ? ptr : mappedV; + auto advanceOp = builder.createOrFold( + loc, basePtr.getType(), basePtr, offsets); + + cleanUp.insert(op); ptrMap.map(op.getResult(), advanceOp); + LLVM_DEBUG(llvm::dbgs() + << "Rewrote:\n\t" << op << "to:\n\t" << advanceOp << "\n"); return success(); } else { llvm_unreachable("Did not find tt::MakeTensorPtrOp"); @@ -755,9 +768,12 @@ struct TritonRaiseBlockPointer ptrMap.map(result, makePtrOp); + LLVM_DEBUG(llvm::dbgs() + << "Rewrote:\n\t" << op << "\nto:\n\t" << makePtrOp << "\n"); + // AddPtrOps that have been rewritten and no longer used in the code must // be removed in the pass to avoid type matching issue. 
- cleanUp.push_back(op); + cleanUp.insert(op); LLVM_DEBUG({ auto modOp = @@ -1039,7 +1055,7 @@ struct TritonRaiseBlockPointer } private: - SmallVector cleanUp; + SmallPtrSet cleanUp; llvm::SmallDenseMap knownPtrs; IRMapping ptrMap; }; From d9fdda8d620ddf8fd3841570fc3d8dfda9826b43 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 27 Jan 2025 17:19:08 +0000 Subject: [PATCH 3/4] Improve getFinalValue and remove no longer needed folding helper functions Signed-off-by: Tiotto, Ettore --- .../addptr_mul_value_const.mlir | 21 +- .../kernel-03-matrix-multiplication.mlir | 40 ++-- .../raise-block-pointer.mlir | 102 ++++------ .../wraparound_side_by_side.mlir | 61 +++--- .../wraparound_stacked.mlir | 50 ++--- .../wraparound_unsupported_add_offset.mlir | 12 +- .../TritonRaiseBlockPointer.cpp | 183 +++++++++--------- 7 files changed, 209 insertions(+), 260 deletions(-) diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir index 96303f29c1..6a2d61931b 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_mul_value_const.mlir @@ -38,16 +38,15 @@ module { // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[VAR_0_:%.+]] = tt.get_program_id x : i32 -// CHECK-DAG: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32 -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_0_]], [[VAR_2_]] : i32 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_3_]], [[CST_1_i64]] : i64 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i64 to i32 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.divui [[VAR_4_]], [[VAR_6_]] : i32 -// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_5_]]], {{\[}}[[VAR_7_]]] {{.*}} : > -// CHECK-DAG: [[VAR_9_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : > -// CHECK: [[VAR_10_:%.+]] = tt.load [[VAR_8_]] : !tt.ptr> -// CHECK: tt.store [[VAR_9_]], [[VAR_10_]] : !tt.ptr> +// CHECK: [[VAR_1_:%.+]] = arith.muli [[PARAM_2_]], [[CST_2048_i32]] : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_2_]] : i32 to i64 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.addi [[VAR_0_]], [[VAR_1_]] : i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.addi [[VAR_2_]], [[CST_1_i64]] : i64 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.trunci [[VAR_4_]] : i64 to i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.divui [[VAR_3_]], [[VAR_5_]] : i32 +// CHECK-DAG: [[VAR_7_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]]], {{\[}}[[VAR_4_]]], {{\[}}[[VAR_6_]]] {{.*}} : > +// CHECK-DAG: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]]], {{\[}}[[CST_1_i64]]], {{\[}}[[VAR_0_]]] {{.*}} : > +// CHECK: [[VAR_9_:%.+]] = tt.load [[VAR_7_]] : !tt.ptr> +// CHECK: tt.store [[VAR_8_]], [[VAR_9_]] : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir index d54084829e..c83eaf9fbc 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir @@ -106,24 +106,28 @@ 
module { // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32 -// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > -// CHECK: [[VAR_31_:%.+]] = arith.index_cast [[PARAM_9_]] : i32 to index -// CHECK: [[VAR_32_:%.+]] = arith.index_cast [[VAR_31_]] : index to i64 -// CHECK: [[VAR_38_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > -// CHECK-DAG: [[VAR_39_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32 -// CHECK-DAG: [[VAR_40_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32 -// CHECK: [[VAR_41_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_28_]], [[VAR_arg15_:%.+]] = [[VAR_38_]]) -> (tensor<128x256xf32>, !tt.ptr>, !tt.ptr>) : i32 { -// CHECK-DAG: [[VAR_54_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr> -// CHECK-DAG: [[VAR_55_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr> -// CHECK: [[VAR_56_:%.+]] = tt.dot [[VAR_54_]], [[VAR_55_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> -// CHECK-DAG: [[VAR_57_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_56_]] : tensor<128x256xf32> -// CHECK-DAG: [[VAR_58_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_39_]]] : > -// CHECK-DAG: [[VAR_59_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_40_]]] : > -// CHECK: scf.yield [[VAR_57_]], [[VAR_58_]], [[VAR_59_]] : tensor<128x256xf32>, !tt.ptr>, !tt.ptr> +// CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 +// CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32 +// CHECK: [[VAR_23_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_20_]], [[VAR_21_]]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_24_:%.+]] = arith.extsi [[PARAM_8_]] : i32 to i64 +// CHECK: [[VAR_25_:%.+]] = arith.muli {{.*}}, [[PARAM_9_]] : i32 +// CHECK: [[VAR_26_:%.+]] = arith.extsi [[PARAM_9_]] : i32 to i64 +// CHECK: [[VAR_27_:%.+]] = arith.divui [[VAR_25_]], [[PARAM_9_]] : i32 +// CHECK: [[VAR_28_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_24_]], [[VAR_26_]]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] {{.*}} : > +// CHECK-DAG: [[VAR_29_:%.+]] = arith.muli [[PARAM_7_]], [[CST_64_i32]] : i32 +// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[PARAM_8_]], [[CST_64_i32]] : i32 +// CHECK: [[VAR_31_:%.+]]:3 = scf.for {{.*}} iter_args([[VAR_arg13_:%.+]] = [[VAR_cst_]], [[VAR_arg14_:%.+]] = [[VAR_23_]], [[VAR_arg15_:%.+]] = [[VAR_28_]]) -> (tensor<128x256xf32>, !tt.ptr>, !tt.ptr>) : i32 { +// CHECK-DAG: [[VAR_40_:%.+]] = tt.load [[VAR_arg14_]] : !tt.ptr> +// CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr> +// CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32> +// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : > +// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : > +// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr>, 
!tt.ptr> // CHECK: } -// CHECK-DAG: [[VAR_42_:%.+]] = arith.truncf [[VAR_41_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> -// CHECK-DAG: [[VAR_43_:%.+]] = arith.index_cast [[PARAM_10_]] : i32 to index -// CHECK: [[VAR_53_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > -// CHECK: tt.store [[VAR_53_]], [[VAR_42_]] : !tt.ptr> +// CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> +// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > +// CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir index dea636345c..3f6a7ae0fb 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir @@ -54,9 +54,8 @@ tt.func @test_addptr_splat_splat_i32(%arg0 : !tt.ptr, %arg1: i32) -> tensor // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> tensor<128xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_3]] : index to i32 -// CHECK: %[[VAL_5:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_4]]] {order = array} : > +// CHECK: %[[VAL_3:.*]] = arith.trunci %[[VAL_1]] : i64 to i32 +// CHECK: %[[VAL_5:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_3]]] {order = array} : > // CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : !tt.ptr> // CHECK: tt.return %[[VAL_6]] : tensor<128xf32> tt.func @test_addptr_splat_splat_i64(%arg0 : !tt.ptr, %arg1: i64) -> tensor<128xf32> { @@ -73,9 +72,8 @@ tt.func @test_addptr_splat_splat_i64(%arg0 : !tt.ptr, %arg1: i64) -> tensor // CHECK-SAME: %[[VAL_1:.*]]: i64) -> tensor<2x128xf32> { // CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_5:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_7:.*]] = arith.index_cast %[[VAL_6]] : index to i32 -// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_7]], %[[VAL_5]]] {order = array} : > +// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_1]] : i64 to i32 +// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_4]], %[[VAL_4]]], {{\[}}%[[VAL_6]], %[[VAL_5]]] {order = array} : > // CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr> // CHECK: tt.return %[[VAL_9]] : tensor<2x128xf32> tt.func @test_addptr_splat_splat_2d(%arg0 : !tt.ptr, %arg1: i64) -> tensor<2x128xf32> { @@ -92,9 +90,8 @@ tt.func @test_addptr_splat_splat_2d(%arg0 : !tt.ptr, %arg1: i64) -> tensor< // CHECK-SAME: %[[VAL_2:.*]]: tensor<2x128xf32>) { // CHECK: %[[VAL_3:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_1]] : i64 to index -// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : index to i32 -// CHECK: %[[VAL_7:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_6]], %[[VAL_4]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = arith.trunci %[[VAL_1]] : i64 
to i32 +// CHECK: %[[VAL_7:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_3]], %[[VAL_3]]], {{\[}}%[[VAL_5]], %[[VAL_4]]] {order = array} : > // CHECK: tt.store %[[VAL_7]], %[[VAL_2]] : !tt.ptr> tt.func @test_addptr_splat_splat_2d_store(%arg0 : !tt.ptr, %arg1: i64, %arg2: tensor<2x128xf32>) { %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> @@ -127,11 +124,8 @@ tt.func @test_addptr_splat_make_range_add(%arg0 : !tt.ptr) -> tensor<128xf3 // CHECK-SAME: %[[VAL_1:.*]]: i32) -> tensor<128xf32> { // CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64 // CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32 -// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_1]] : i32 to index -// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64 -// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_5]] : i64 to i32 -// CHECK: %[[VAL_7:.*]] = arith.divui %[[VAL_3]], %[[VAL_6]] : i32 -// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_7]]] {order = array} : > +// CHECK: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : i32 to i64 +// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_4]]], {{\[}}%[[VAL_3]]] {order = array} : > // CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr> // CHECK: tt.return %[[VAL_9]] : tensor<128xf32> tt.func @test_addptr_splat_make_range_mul(%arg0 : !tt.ptr, %arg1: i32) -> tensor<128xf32> { @@ -266,31 +260,23 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128x // CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_2_i32:%.+]] = arith.constant 2 : i32 -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 -// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index -// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64 -// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32 -// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 -// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64 -// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 -// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_3_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_6_]] : i64 to i32 -// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_7_]], [[VAR_12_]] : i32 -// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {order = array} : > -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64 -// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64 -// CHECK: [[VAR_19_:%.+]] = arith.trunci [[VAR_16_]] : i64 to i32 -// CHECK: [[VAR_20_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_19_]] : i32 -// CHECK: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32 -// CHECK: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32 -// CHECK: [[VAR_23:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_16_]], [[VAR_18_]]], {{\[}}[[VAR_20_]], [[VAR_22_]]] {order = 
array} : > -// CHECK: [[VAR_24:%.+]] = tt.load [[VAR_14_]] {boundaryCheck = array} : !tt.ptr> -// CHECK: tt.store [[VAR_23]], [[VAR_24]] : !tt.ptr> +// CHECK: [[VAR_0_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 +// CHECK: [[VAR_1_:%.+]] = arith.extsi [[PARAM_3_]] : i32 to i64 +// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_3_]] : index to i64 +// CHECK: [[VAR_5_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32 +// CHECK: [[VAR_6_:%.+]] = arith.extsi [[PARAM_4_]] : i32 to i64 +// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_2_]] : index to i64 +// CHECK: [[VAR_8_:%.+]] = arith.muli [[VAR_7_]], [[VAR_4_]] : i64 +// CHECK: [[VAR_9_:%.+]] = arith.divui [[VAR_0_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_5_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_11_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_8_]]], {{\[}}[[VAR_1_]], [[VAR_6_]]], {{\[}}[[VAR_9_]], [[VAR_10_]]] {order = array} : > +// CHECK: [[VAR_12_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 +// CHECK: [[VAR_13_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > +// CHECK: [[VAR_15_:%.+]] = tt.load [[VAR_11_]] {boundaryCheck = array} : !tt.ptr> +// CHECK: tt.store [[VAR_14_]], [[VAR_15_]] : !tt.ptr> // CHECK: tt.return module { tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { @@ -431,29 +417,21 @@ module { // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK: [[VAR_0_:%.+]] = arith.index_cast [[PARAM_2_]] : i32 to index // CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK: [[VAR_3_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 -// CHECK: [[VAR_2_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 -// CHECK: [[VAR_5_:%.+]] = arith.muli [[VAR_4_]], [[VAR_2_]] : i64 -// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64 -// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32 -// CHECK: [[VAR_9_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 -// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_3_]], [[VAR_9_]] : i32 -// CHECK: [[VAR_11_:%.+]] = arith.trunci [[VAR_7_]] : i64 to i32 -// CHECK: [[VAR_12_:%.+]] = arith.divui [[VAR_8_]], [[VAR_11_]] : i32 -// CHECK: [[VAR_13:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_10_]], [[VAR_12_]]] {order = array} : > -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : index to i64 -// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[VAR_16_]] : index to i64 -// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_15_]] : i64 to i32 -// CHECK: [[VAR_19_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_18_]] : i32 -// CHECK: [[VAR_20_:%.+]] = arith.trunci [[VAR_17_]] : i64 to i32 -// CHECK: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32 -// CHECK: [[VAR_22:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], 
{{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_15_]], [[VAR_17_]]], {{\[}}[[VAR_19_]], [[VAR_21_]]] {order = array} : > -// CHECK: [[VAR_23:%.+]] = tt.load [[VAR_13]] {boundaryCheck = array} : !tt.ptr> -// CHECK: tt.store [[VAR_22]], [[VAR_23]] : !tt.ptr> +// CHECK: [[VAR_2_:%.+]] = arith.muli [[PARAM_3_]], [[CST_2_i32]] : i32 +// CHECK: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 +// CHECK: [[VAR_4_:%.+]] = arith.extsi [[PARAM_3_]] : i32 to i64 +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[VAR_0_]] : index to i64 +// CHECK: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[VAR_3_]] : i64 +// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32 +// CHECK: [[VAR_8_:%.+]] = arith.extsi [[PARAM_4_]] : i32 to i64 +// CHECK: [[VAR_9_:%.+]] = arith.divui [[VAR_2_]], [[PARAM_3_]] : i32 +// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_7_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_11:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_6_]], [[CST_0_i64]]], {{\[}}[[VAR_4_]], [[VAR_8_]]], {{\[}}[[VAR_9_]], [[VAR_10_]]] {order = array} : > +// CHECK: [[VAR_12_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 +// CHECK: [[VAR_13_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > +// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_11]] {boundaryCheck = array} : !tt.ptr> +// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr> // CHECK: tt.return module { tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir index e4bcc887c3..43d5a3964e 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir @@ -63,41 +63,32 @@ module { // CHECK-DAG: [[CST_4_i32:%.+]] = arith.constant 4 : i32 // CHECK-DAG: [[CST_6_i32:%.+]] = arith.constant 6 : i32 // CHECK: [[VAR_0_:%.+]] = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> -// CHECK: [[VAR_1_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index -// CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[PARAM_4_]], {{.*}} : i32 -// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : index to i64 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[PARAM_5_]], [[CST_6_i32]] : i32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64 -// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_3_]] : i64 to i32 -// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_2_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_6_]] : i64 to i32 -// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_7_]], [[VAR_12_]] : i32 -// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_3_]], [[VAR_6_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {{.*}} : > -// CHECK-DAG: [[VAR_15_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to 
index -// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[VAR_16_]] : index to i64 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[VAR_18_]] : index to i64 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.trunci [[VAR_17_]] : i64 to i32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32 -// CHECK-DAG: [[VAR_22_:%.+]] = arith.trunci [[VAR_19_]] : i64 to i32 -// CHECK-DAG: [[VAR_23_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_22_]] : i32 -// CHECK: [[VAR_24_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_19_]]], {{\[}}[[VAR_21_]], [[VAR_23_]]] {{.*}} : > -// CHECK: [[VAR_25_:%.+]] = arith.cmpi slt, [[VAR_15_]], {{.*}} : tensor<4x1xi32> -// CHECK: [[VAR_26_:%.+]] = tt.broadcast [[VAR_25_]] : tensor<4x1xi1> -> tensor<4x4xi1> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_12_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_14_]], [[VAR_arg10_:%.+]] = [[VAR_24_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { -// CHECK: [[VAR_30_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_26_]], [[CST_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg10_]], [[VAR_30_]] : !tt.ptr> -// CHECK-DAG: [[VAR_31_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] : > -// CHECK-DAG: [[VAR_32_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_28_]]] : > -// CHECK: scf.yield [[VAR_31_]], [[VAR_32_]] : !tt.ptr>, !tt.ptr> +// CHECK-DAG: [[VAR_1_:%.+]] = arith.muli [[PARAM_4_]], {{.*}} : i32 +// CHECK-DAG: [[VAR_2_:%.+]] = arith.extsi [[PARAM_4_]] : i32 to i64 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.index_cast [[PARAM_3_]] : i32 to index +// CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[PARAM_5_]], [[CST_6_i32]] : i32 +// CHECK: [[VAR_7_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[VAR_3_]] : index to i64 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_5_]] : i64 +// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_1_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_6_]], [[PARAM_5_]] : i32 +// CHECK: [[VAR_12_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_10_]], [[VAR_11_]]] {{.*}} : > +// CHECK-DAG: [[VAR_13_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> +// CHECK-DAG: [[VAR_14_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 +// CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_14_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_21_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32> +// CHECK: [[VAR_22_:%.+]] = tt.broadcast [[VAR_21_]] : tensor<4x1xi1> -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> 
(!tt.ptr>, !tt.ptr>) : i32 { +// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_22_]], [[CST_]] : !tt.ptr> +// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr> +// CHECK-DAG: [[VAR_27_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_23_]]] : > +// CHECK-DAG: [[VAR_28_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_24_]]] : > +// CHECK: scf.yield [[VAR_27_]], [[VAR_28_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir index d5ec41d3da..5c210f3ac6 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir @@ -65,35 +65,27 @@ module { // CHECK: [[VAR_2_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK-DAG: [[VAR_3_:%.+]] = arith.muli [[PARAM_4_]], {{.*}} : i32 // CHECK-DAG: [[VAR_4_:%.+]] = arith.index_cast [[VAR_2_]] : index to i64 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.muli [[VAR_5_]], [[VAR_4_]] : i64 -// CHECK-DAG: [[VAR_7_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK-DAG: [[VAR_8_:%.+]] = arith.index_cast [[VAR_7_]] : index to i64 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.muli [[PARAM_5_]], {{.*}} : i32 -// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_4_]] : i64 to i32 -// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_3_]], [[VAR_10_]] : i32 -// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_8_]] : i64 to i32 -// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_9_]], [[VAR_12_]] : i32 -// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_6_]], [[CST_0_i64]]], {{\[}}[[VAR_4_]], [[VAR_8_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {{.*}} : > -// CHECK-DAG: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK-DAG: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64 -// CHECK-DAG: [[VAR_17_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[VAR_18_]] : index to i64 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.trunci [[VAR_16_]] : i64 to i32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32 -// CHECK-DAG: [[VAR_22_:%.+]] = arith.trunci [[VAR_19_]] : i64 to i32 -// CHECK-DAG: [[VAR_23_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_22_]] : i32 -// CHECK: [[VAR_24_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_16_]], [[VAR_19_]]], {{\[}}[[VAR_21_]], [[VAR_23_]]] {{.*}} : > -// CHECK: [[VAR_25_:%.+]] = arith.cmpi slt, [[VAR_17_]], {{.*}} : tensor<1x4xi32> -// CHECK: [[VAR_26_:%.+]] = tt.broadcast [[VAR_25_]] : tensor<1x4xi1> -> tensor<4x4xi1> -// CHECK: [[VAR_27_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 -// CHECK: [[VAR_28_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_14_]], [[VAR_arg10_:%.+]] = [[VAR_24_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { -// CHECK: [[VAR_29_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_26_]], [[CST_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg10_]], [[VAR_29_]] : !tt.ptr> -// CHECK-DAG: [[VAR_30_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] : > -// CHECK-DAG: [[VAR_31_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_27_]]] : > -// CHECK: scf.yield [[VAR_30_]], [[VAR_31_]] : 
!tt.ptr>, !tt.ptr> +// CHECK-DAG: [[VAR_5_:%.+]] = arith.extsi [[PARAM_4_]] : i32 to i64 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.index_cast [[VAR_1_]] : index to i64 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.muli [[VAR_6_]], [[VAR_4_]] : i64 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.muli [[PARAM_5_]], [[CST_3_i32]] : i32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.extsi [[PARAM_5_]] : i32 to i64 +// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_3_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_8_]], [[PARAM_5_]] : i32 +// CHECK: [[VAR_12_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_7_]], [[CST_0_i64]]], {{\[}}[[VAR_5_]], [[VAR_9_]]], {{\[}}[[VAR_10_]], [[VAR_11_]]] {{.*}} : > +// CHECK-DAG: [[VAR_13_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 +// CHECK-DAG: [[VAR_14_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 +// CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_13_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_17_:%.+]] = arith.cmpi slt, [[VAR_14_]], {{.*}} : tensor<1x4xi32> +// CHECK: [[VAR_18_:%.+]] = tt.broadcast [[VAR_17_]] : tensor<1x4xi1> -> tensor<4x4xi1> +// CHECK: [[VAR_19_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 +// CHECK: [[VAR_20_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { +// CHECK: [[VAR_21_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr> +// CHECK: tt.store [[VAR_arg10_]], [[VAR_21_]] : !tt.ptr> +// CHECK-DAG: [[VAR_22_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > +// CHECK-DAG: [[VAR_23_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > +// CHECK: scf.yield [[VAR_22_]], [[VAR_23_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir index 5fc68bd6cf..805bbb9ccb 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir @@ -89,15 +89,9 @@ module { // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_15_:%.+]] = tt.addptr [[VAR_14_]], [[VAR_13_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32> // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> -// CHECK-DAG: [[VAR_17_:%.+]] = arith.index_cast [[arg6_]] : i32 to index -// CHECK-DAG: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.index_cast [[arg7_]] : i32 to index -// CHECK-DAG: [[VAR_20_:%.+]] = arith.index_cast [[VAR_19_]] : index to i64 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32 -// CHECK-DAG: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32 -// CHECK-DAG: [[VAR_23_:%.+]] = arith.trunci [[VAR_20_]] : i64 to i32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_23_]] : i32 -// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_18_]], [[VAR_20_]]], {{\[}}[[VAR_22_]], [[VAR_24_]]] {{.*}} : > +// CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi 
[[arg7_]] : i32 to i64 +// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > // CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> // CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1> // CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32 diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index 0b4c4ff5c5..fcf5e14c84 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -105,48 +105,49 @@ Value findOrCreateMakeTensorPtr(Location loc, Value source, ValueRange shape, loc, source, shape, strides, offsets, sizes, order); } -Value addOrFold(Value lhs, Value rhs, ArithBuilder &abuilder) { - return ttgi::isConstant(lhs, 0) - ? rhs - : (ttgi::isConstant(rhs, 0) ? lhs : abuilder.add(lhs, rhs)); -} +Value getFinalValue(Value value) { + Operation *defOp = value.getDefiningOp(); + if (!defOp) { + // Look at init values defined outside the loop. + BlockArgument blockArg = dyn_cast(value); + Operation *parentOp = blockArg.getOwner()->getParentOp(); + if (scf::ForOp forOp = dyn_cast(parentOp)) + return getFinalValue(forOp.getInitArgs()[blockArg.getArgNumber() - 1]); -Value mulOrFold(Value lhs, Value rhs, ArithBuilder &abuilder) { - if (ttgi::isConstant(lhs, 0) || ttgi::isConstant(rhs, 1)) - return lhs; - if (ttgi::isConstant(rhs, 0) || ttgi::isConstant(lhs, 1)) - return rhs; - return abuilder.mul(lhs, rhs); -} + return value; + } -Value divOrFold(Location loc, Type type, Value num, Value den, - OpBuilder &builder) { - // If the denominator has value one, return the numerator. - if (Operation *defOp = den.getDefiningOp()) { - if (auto truncOp = dyn_cast(defOp)) { - if (ttgi::isConstant(truncOp.getOperand(), 1)) - return num; - } - if (auto truncOp = dyn_cast(defOp)) { - if (ttgi::isConstant(truncOp.getOperand(), 1.0)) - return num; - } + if (isa( + defOp)) + return getFinalValue(defOp->getOperand(0)); + + if (auto addOp = dyn_cast(defOp)) { + if (ttgi::isConstant(addOp.getLhs(), 0)) + return getFinalValue(addOp.getRhs()); + if (ttgi::isConstant(addOp.getRhs(), 0)) + return getFinalValue(addOp.getLhs()); + return addOp.getResult(); } - // If the numerator has value zero, return it. - if (Operation *defOp = num.getDefiningOp()) { - if (auto truncOp = dyn_cast(defOp)) { - if (ttgi::isConstant(truncOp.getOperand(), 0)) - return num; - } - if (auto truncOp = dyn_cast(defOp)) { - if (ttgi::isConstant(truncOp.getOperand(), 0.0)) - return num; - } + if (auto mulOp = dyn_cast(defOp)) { + if (ttgi::isConstant(mulOp.getLhs(), 1) || + ttgi::isConstant(mulOp.getRhs(), 0)) + return getFinalValue(mulOp.getRhs()); + if (ttgi::isConstant(mulOp.getRhs(), 1) || + ttgi::isConstant(mulOp.getLhs(), 0)) + return getFinalValue(mulOp.getLhs()); + return mulOp.getResult(); } - return builder.createOrFold(loc, type, num, den); -}; + if (auto divOp = dyn_cast(defOp)) { + if (ttgi::isConstant(divOp.getRhs(), 1) || + ttgi::isConstant(divOp.getLhs(), 0)) + return getFinalValue(divOp.getLhs()); + return divOp.getResult(); + } + + return value; +} // Data structure used to decode pointer arithmetics. 
Offsets, sizes, and // strides are in unit of elements in a linearly laid-out memory, which is the @@ -218,19 +219,27 @@ struct PtrState { Location loc = op->getLoc(); ArithBuilder abuilder(builder, loc); - if (lhsState.scalar && rhsState.scalar) - scalar = addOrFold(lhsState.scalar, rhsState.scalar, abuilder); - else if (lhsState.getRank() == 0) + if (lhsState.scalar && rhsState.scalar) { + scalar = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = findOrCreateCast(loc, getFinalValue(scalar), + lhsState.scalar.getType(), builder); + + } else if (lhsState.getRank() == 0) scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; for (unsigned i = 0; i < lhsState.getRank(); ++i) { - Value newOffset = - addOrFold(lhsState.offsets[i], rhsState.offsets[i], abuilder); - offsets.push_back(newOffset); - - Value newStride = - addOrFold(lhsState.strides[i], rhsState.strides[i], abuilder); - strides.push_back(newStride); + Value newOffset = builder.create(loc, lhsState.offsets[i], + rhsState.offsets[i]); + offsets.push_back(findOrCreateCast(loc, getFinalValue(newOffset), + lhsState.offsets[i].getType(), + builder)); + + Value newStride = builder.create(loc, lhsState.strides[i], + rhsState.strides[i]); + strides.push_back(findOrCreateCast(loc, getFinalValue(newStride), + lhsState.strides[i].getType(), + builder)); sizes.push_back(lhsState.sizes[i]); } @@ -304,28 +313,40 @@ struct PtrState { for (const auto &[offset, stride, dim, size] : llvm::zip(lhs->offsets, lhs->strides, lhs->shape, lhs->sizes)) { - Value newOffset = mulOrFold( + Value newOffset = builder.create( + loc, findOrCreateCast(loc, offset, builder.getIntegerType(offsetBitwidth), builder), findOrCreateCast(loc, rhs->scalar, - builder.getIntegerType(offsetBitwidth), builder), - abuilder); - Value newStride = mulOrFold( + builder.getIntegerType(offsetBitwidth), builder)); + newOffset = + findOrCreateCast(loc, getFinalValue(newOffset), + builder.getIntegerType(offsetBitwidth), builder); + + Value newStride = builder.create( + loc, findOrCreateCast(loc, stride, builder.getIntegerType(shapeAndStridesBitwidth), builder), findOrCreateCast(loc, rhs->scalar, builder.getIntegerType(shapeAndStridesBitwidth), - builder), - abuilder); - Value newDim = mulOrFold( + builder)); + newStride = findOrCreateCast( + loc, getFinalValue(newStride), + builder.getIntegerType(shapeAndStridesBitwidth), builder); + + Value newDim = builder.create( + loc, findOrCreateCast(loc, dim, builder.getIntegerType(shapeAndStridesBitwidth), builder), findOrCreateCast(loc, rhs->scalar, builder.getIntegerType(shapeAndStridesBitwidth), - builder), - abuilder); + builder)); + newDim = findOrCreateCast(loc, getFinalValue(newDim), + builder.getIntegerType(shapeAndStridesBitwidth), + builder); + offsets.push_back(newOffset); strides.push_back(newStride); shape.push_back(newDim); @@ -341,17 +362,18 @@ struct PtrState { for (const auto &[offset, stride, dim] : llvm::zip(offsets, strides, shape)) { if (ttgi::isConstant(stride, 0)) { - newOffsets.push_back( - findOrCreateCast(loc, offset, builder.getI32Type(), builder)); + newOffsets.push_back(findOrCreateCast( + loc, offset, builder.getIntegerType(offsetBitwidth), builder)); } else { - auto divOffset = divOrFold( - loc, builder.getI32Type(), + Value divOffset = builder.create( + loc, builder.getIntegerType(offsetBitwidth), findOrCreateCast(loc, offset, builder.getIntegerType(offsetBitwidth), builder), findOrCreateCast(loc, stride, - builder.getIntegerType(offsetBitwidth), builder), - builder); - 
newOffsets.push_back(divOffset); + builder.getIntegerType(offsetBitwidth), builder)); + newOffsets.push_back( + findOrCreateCast(loc, getFinalValue(divOffset), + builder.getIntegerType(offsetBitwidth), builder)); } newStrides.push_back(findOrCreateCast( loc, stride, builder.getIntegerType(shapeAndStridesBitwidth), @@ -559,37 +581,6 @@ struct TritonRaiseBlockPointer return success(); } - Value getFinalValue(Value value) const { - Operation *defOp = value.getDefiningOp(); - if (!defOp) { - // look init values outside the loop - BlockArgument blockArg = dyn_cast(value); - Operation *parentOp = blockArg.getOwner()->getParentOp(); - scf::ForOp forOp = dyn_cast(parentOp); - return forOp ? getFinalValue( - forOp.getInitArgs()[blockArg.getArgNumber() - 1]) - : value; - } - - if (isa(defOp) || isa(defOp) || - isa(defOp) || isa(defOp)) - return getFinalValue(defOp->getOperand(0)); - if (auto addOp = dyn_cast(defOp)) { - if (ttgi::isConstant(addOp.getLhs(), 0)) - return getFinalValue(addOp.getRhs()); - if (ttgi::isConstant(addOp.getRhs(), 0)) - return getFinalValue(addOp.getLhs()); - return addOp.getResult(); - } else if (auto mulOp = dyn_cast(defOp)) { - if (ttgi::isConstant(mulOp.getLhs(), 1)) - return getFinalValue(mulOp.getRhs()); - if (ttgi::isConstant(mulOp.getRhs(), 1)) - return getFinalValue(mulOp.getLhs()); - return mulOp.getResult(); - } - return value; - } - bool lookForMultiplyingValueInDefiningPath(Value &val, Value &ref) const { if (Operation *defOp = getFinalValue(val).getDefiningOp()) { if (auto mulOp = dyn_cast(defOp)) { @@ -711,7 +702,6 @@ struct TritonRaiseBlockPointer // rewrite the AddPtrOp using and AdvanceOp. if (Value mappedV = ptrMap.lookupOrNull(ptr)) { if (auto makeTPtrOp = mappedV.getDefiningOp()) { - Value finalVal = getFinalValue(op.getOffset()); auto offsetType = cast(op.getOffset().getType()); unsigned rank = offsetType.getRank(); @@ -743,7 +733,7 @@ struct TritonRaiseBlockPointer ptrMap.map(op.getResult(), advanceOp); LLVM_DEBUG(llvm::dbgs() - << "Rewrote:\n\t" << op << "to:\n\t" << advanceOp << "\n"); + << "Rewrote:\n\t" << op << "\nto:\n\t" << advanceOp << "\n"); return success(); } else { llvm_unreachable("Did not find tt::MakeTensorPtrOp"); } @@ -810,7 +800,8 @@ struct TritonRaiseBlockPointer loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); auto offsetCst = builder.createOrFold( loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); - auto scaledOffset = mulOrFold(offsetCst, strideCst, abuilder); + auto scaledOffset = + builder.createOrFold(loc, offsetCst, strideCst); state.offsets.push_back(findOrCreateCast( loc, scaledOffset, builder.getIntegerType(offsetBitwidth), builder)); } From 05e794c2fb106a59c82bc32a536943572f3d22a9 Mon Sep 17 00:00:00 2001 From: "Tiotto, Ettore" Date: Mon, 27 Jan 2025 22:22:02 +0000 Subject: [PATCH 4/4] Fix offset computation for the tt.advance operation Signed-off-by: Tiotto, Ettore --- .../RaiseToBlockPointers/addptr_dim1.mlir | 2 +- .../addptr_for_accumulation.mlir | 22 +++--- .../kernel-03-matrix-multiplication.mlir | 18 ++++- .../raise-block-pointer.mlir | 2 +- .../wraparound_side_by_side.mlir | 22 +++--- .../wraparound_stacked.mlir | 8 +- .../wraparound_unsupported_add_offset.mlir | 25 +++--- .../TritonRaiseBlockPointer.cpp | 76 +++++++++---------- 8 files changed, 92 insertions(+), 83 deletions(-) diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir index a9549c583f..6258fcb741 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir +++ 
b/test/Triton/Intel/RaiseToBlockPointers/addptr_dim1.mlir @@ -85,7 +85,7 @@ module { // CHECK-DAG: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_0_i64]], [[CST_1_i64]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > // CHECK-DAG: [[VAR_4_:%.+]] = tt.addptr [[VAR_2_]], [[VAR_1_]] : tensor<1x256x!tt.ptr>, tensor<1x256xi32> // CHECK-DAG: [[VAR_5_:%.+]] = tt.load [[VAR_3_]] : !tt.ptr> -// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[CST_0_i32]], [[PARAM_1_]]] : > +// CHECK-DAG: [[VAR_6_:%.+]] = tt.advance [[VAR_3_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : > // CHECK: tt.store [[VAR_6_]], [[VAR_5_]] : !tt.ptr> // CHECK: [[VAR_7_:%.+]]:2 = scf.for [[VAR_arg2_:%.+]] = {{.*}} iter_args([[VAR_arg3_:%.+]] = [[CST_0_]], [[VAR_arg4_:%.+]] = [[VAR_4_]]) -> (tensor<4x256xbf16>, tensor<1x256x!tt.ptr>) { // CHECK: [[VAR_9_:%.+]] = tt.broadcast [[VAR_arg4_]] : tensor<1x256x!tt.ptr> -> tensor<4x256x!tt.ptr> diff --git a/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir b/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir index 92b893129f..3322a787e4 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/addptr_for_accumulation.mlir @@ -64,20 +64,18 @@ module { // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32 // CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64 -// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr> -// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr>) { -// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> -// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16> -// CHECK-DAG: [[VAR_10_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_]]] : > -// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, !tt.ptr> +// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_1_:%.+]] = tt.load [[VAR_0_]] : !tt.ptr> +// CHECK: [[VAR_2_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_3_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_1_]], [[VAR_arg7_:%.+]] = [[VAR_2_]]) -> (tensor<4x256xbf16>, !tt.ptr>) { +// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> +// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16> +// CHECK-DAG: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr> // CHECK: } -// COM: to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], 
shape: [0, 0], order: [] : to tensor<4x256x!tt.ptr> -// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > -// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr> +// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : > +// CHECK: tt.store [[VAR_4_]], [[VAR_3_]]#0 : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir index c83eaf9fbc..0b0c52f3fb 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/kernel-03-matrix-multiplication.mlir @@ -106,6 +106,8 @@ module { // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32 // CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 // CHECK-DAG: [[CST_1_i32:%.+]] = arith.constant 1 : i32 +// CHECK: [[VAR_17_:%.+]] = arith.muli {{.*}}, [[CST_128_i32]] : i32 +// CHECK: [[VAR_18_:%.+]] = arith.muli {{.*}}, [[CST_256_i32]] : i32 // CHECK: [[VAR_20_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 // CHECK: [[VAR_21_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 // CHECK: [[VAR_22_:%.+]] = arith.divui {{.*}}, [[PARAM_6_]] : i32 @@ -122,12 +124,20 @@ module { // CHECK-DAG: [[VAR_41_:%.+]] = tt.load [[VAR_arg15_]] : !tt.ptr> // CHECK: [[VAR_42_:%.+]] = tt.dot [[VAR_40_]], [[VAR_41_]], [[VAR_cst_]], inputPrecision = tf32 : tensor<128x64xbf16> * tensor<64x256xbf16> -> tensor<128x256xf32> // CHECK-DAG: [[VAR_43_:%.+]] = arith.addf [[VAR_arg13_]], [[VAR_42_]] : tensor<128x256xf32> -// CHECK-DAG: [[VAR_44_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[CST_0_i32]], [[VAR_29_]]] : > -// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : > -// CHECK: scf.yield [[VAR_43_]], [[VAR_44_]], [[VAR_45_]] : tensor<128x256xf32>, !tt.ptr>, !tt.ptr> +// CHECK-DAG: [[VAR_44_:%.+]] = arith.divui [[VAR_29_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_45_:%.+]] = tt.advance [[VAR_arg14_]], {{\[}}[[VAR_44_]], [[CST_0_i32]]] : > +// CHECK-DAG: [[VAR_46_:%.+]] = arith.divui [[VAR_30_]], [[PARAM_8_]] : i32 +// CHECK-DAG: [[VAR_47_:%.+]] = tt.advance [[VAR_arg15_]], {{\[}}[[VAR_46_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_43_]], [[VAR_45_]], [[VAR_47_]] : tensor<128x256xf32>, !tt.ptr>, !tt.ptr> // CHECK: } // CHECK-DAG: [[VAR_32_:%.+]] = arith.truncf [[VAR_31_]]#0 : tensor<128x256xf32> to tensor<128x256xbf16> -// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}{{.*}}], {{\[}}{{.*}}] {{.*}} : > +// CHECK-DAG: [[VAR_33_:%.+]] = arith.muli [[VAR_17_]], [[PARAM_10_]] : i32 +// CHECK-DAG: [[VAR_34_:%.+]] = arith.extsi [[PARAM_10_]] : i32 to i64 +// CHECK-DAG: [[VAR_35_:%.+]] = arith.muli [[VAR_18_]], [[PARAM_11_]] : i32 +// CHECK-DAG: [[VAR_36_:%.+]] = arith.extsi [[PARAM_11_]] : i32 to i64 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.divui [[VAR_33_]], [[PARAM_10_]] : i32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.divui [[VAR_35_]], [[PARAM_11_]] : i32 +// CHECK: [[VAR_39_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_34_]], [[VAR_36_]]], {{\[}}[[VAR_37_]], [[VAR_38_]]] {{.*}} : > // CHECK: tt.store [[VAR_39_]], [[VAR_32_]] : !tt.ptr> // CHECK: tt.return // CHECK: } diff --git 
a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir index 3f6a7ae0fb..f26c2df9c8 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/raise-block-pointer.mlir @@ -342,7 +342,7 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr (tensor<4x256xbf16>, !tt.ptr>) { // CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr> // CHECK: [[VAR_6_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_5_]] : tensor<4x256xbf16> -// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_i32]]{{\]}} : > +// CHECK: [[VAR_7_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_3_i32]], [[CST_0_i32]]{{\]}} : > // CHECK: scf.yield [[VAR_6_]], [[VAR_7_]] : tensor<4x256xbf16>, !tt.ptr> // CHECK: } // CHECK: [[VAR_5_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir index 43d5a3964e..d08a0981e2 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_side_by_side.mlir @@ -79,16 +79,18 @@ module { // CHECK-DAG: [[VAR_14_:%.+]] = arith.extsi [[PARAM_6_]] : i32 to i64 // CHECK-DAG: [[VAR_15_:%.+]] = arith.extsi [[PARAM_7_]] : i32 to i64 // CHECK: [[VAR_16_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_14_]], [[VAR_15_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_21_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32> -// CHECK: [[VAR_22_:%.+]] = tt.broadcast [[VAR_21_]] : tensor<4x1xi1> -> tensor<4x4xi1> -// CHECK-DAG: [[VAR_23_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 -// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { -// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_22_]], [[CST_]] : !tt.ptr> -// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr> -// CHECK-DAG: [[VAR_27_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_23_]]] : > -// CHECK-DAG: [[VAR_28_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_24_]]] : > -// CHECK: scf.yield [[VAR_27_]], [[VAR_28_]] : !tt.ptr>, !tt.ptr> +// CHECK: [[VAR_17_:%.+]] = arith.cmpi slt, [[VAR_13_]], {{.*}} : tensor<4x1xi32> +// CHECK: [[VAR_18_:%.+]] = tt.broadcast [[VAR_17_]] : tensor<4x1xi1> -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.muli [[PARAM_4_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.muli [[PARAM_5_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_21_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { +// CHECK: [[VAR_22_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr> +// CHECK: tt.store [[VAR_arg10_]], [[VAR_22_]] : !tt.ptr> +// CHECK: [[VAR_23_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_24_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_23_]], [[CST_0_i32]]] : > +// CHECK: [[VAR_25_:%.+]] = arith.divui [[VAR_20_]], [[PARAM_6_]] : i32 +// CHECK: 
[[VAR_26_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_25_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_24_]], [[VAR_26_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir index 5c210f3ac6..b8a5374fd7 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_stacked.mlir @@ -83,9 +83,11 @@ module { // CHECK: [[VAR_20_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = {{.*}} iter_args([[VAR_arg9_:%.+]] = [[VAR_12_]], [[VAR_arg10_:%.+]] = [[VAR_16_]]) -> (!tt.ptr>, !tt.ptr>) : i32 { // CHECK: [[VAR_21_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_18_]], [[CST_]] : !tt.ptr> // CHECK: tt.store [[VAR_arg10_]], [[VAR_21_]] : !tt.ptr> -// CHECK-DAG: [[VAR_22_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > -// CHECK-DAG: [[VAR_23_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_19_]]] : > -// CHECK: scf.yield [[VAR_22_]], [[VAR_23_]] : !tt.ptr>, !tt.ptr> +// CHECK: [[VAR_22_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_4_]] : i32 +// CHECK: [[VAR_23_:%.+]] = tt.advance [[VAR_arg9_]], {{\[}}[[VAR_22_]], [[CST_0_i32]]] : > +// CHECK: [[VAR_24_:%.+]] = arith.divui [[VAR_19_]], [[PARAM_6_]] : i32 +// CHECK-DAG: [[VAR_25_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_24_]], [[CST_0_i32]]] : > +// CHECK: scf.yield [[VAR_23_]], [[VAR_25_]] : !tt.ptr>, !tt.ptr> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir index 805bbb9ccb..9ce098f09a 100644 --- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir +++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir @@ -91,20 +91,21 @@ module { // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32> // CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64 // CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi [[arg7_]] : i32 to i64 -// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > -// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> -// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1> -// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32 +// CHECK: [[VAR_19_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : > +// CHECK: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32> +// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1> +// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32> -// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32 +// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32> +// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] 
diff --git a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir
index 805bbb9ccb..9ce098f09a 100644
--- a/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir
+++ b/test/Triton/Intel/RaiseToBlockPointers/wraparound_unsupported_add_offset.mlir
@@ -91,20 +91,21 @@ module {
 // CHECK-DAG: [[VAR_16_:%.+]] = tt.expand_dims [[VAR_0_]] {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
 // CHECK-DAG: [[VAR_17_:%.+]] = arith.extsi [[arg6_]] : i32 to i64
 // CHECK-DAG: [[VAR_18_:%.+]] = arith.extsi [[arg7_]] : i32 to i64
-// CHECK: [[VAR_25_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : >
-// CHECK: [[VAR_26_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
-// CHECK-DAG: [[VAR_27_:%.+]] = tt.broadcast [[VAR_26_]] : tensor<4x1xi1> -> tensor<4x4xi1>
-// CHECK-DAG: [[VAR_28_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
+// CHECK: [[VAR_19_:%.+]] = tt.make_tensor_ptr [[arg1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_17_]], [[VAR_18_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {{.*}} : >
+// CHECK: [[VAR_20_:%.+]] = arith.cmpi slt, [[VAR_16_]], [[VAR_cst_0_]] : tensor<4x1xi32>
+// CHECK-DAG: [[VAR_21_:%.+]] = tt.broadcast [[VAR_20_]] : tensor<4x1xi1> -> tensor<4x4xi1>
+// CHECK-DAG: [[VAR_22_:%.+]] = arith.muli [[arg4_]], [[CST_4_i32]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_29_:%.+]] = tt.splat [[VAR_28_]] : i32 -> tensor<4x4xi32>
-// CHECK-DAG: [[VAR_30_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
+// CHECK-DAG: [[VAR_23_:%.+]] = tt.splat [[VAR_22_]] : i32 -> tensor<4x4xi32>
+// CHECK-DAG: [[VAR_24_:%.+]] = arith.muli [[arg5_]], [[CST_4_i32]] : i32
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_31_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_25_]]) -> (tensor<4x4x!tt.ptr>, !tt.ptr>)
-// CHECK: [[VAR_32_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_27_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr>
-// CHECK: tt.store [[VAR_arg10_]], [[VAR_32_]] : !tt.ptr>
-// CHECK-DAG: [[VAR_33_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
-// CHECK-DAG: [[VAR_34_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[CST_0_i32]], [[VAR_30_]]] : >
-// CHECK: scf.yield [[VAR_33_]], [[VAR_34_]] : tensor<4x4x!tt.ptr>, !tt.ptr>
+// CHECK-DAG: [[VAR_25_:%.+]]:2 = scf.for [[VAR_arg8_:%.+]] = [[CST_0_i32]] to [[CST_2_i32]] step [[CST_1_i32]] iter_args([[VAR_arg9_:%.+]] = [[VAR_15_]], [[VAR_arg10_:%.+]] = [[VAR_19_]]) -> (tensor<4x4x!tt.ptr>, !tt.ptr>)
+// CHECK: [[VAR_26_:%.+]] = tt.load [[VAR_arg9_]], [[VAR_21_]], [[VAR_cst_]] : tensor<4x4x!tt.ptr>
+// CHECK: tt.store [[VAR_arg10_]], [[VAR_26_]] : !tt.ptr>
+// CHECK-DAG: [[VAR_27_:%.+]] = tt.addptr [[VAR_arg9_]], [[VAR_23_]] : tensor<4x4x!tt.ptr>, tensor<4x4xi32>
+// CHECK-DAG: [[VAR_28_:%.+]] = arith.divui [[VAR_24_]], [[arg6_]] : i32
+// CHECK-DAG: [[VAR_29_:%.+]] = tt.advance [[VAR_arg10_]], {{\[}}[[VAR_28_]], [[CST_0_i32]]] : >
+// CHECK: scf.yield [[VAR_27_]], [[VAR_29_]] : tensor<4x4x!tt.ptr>, !tt.ptr>
 // CHECK: }
 // CHECK: tt.return
 // CHECK: }
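Both test updates follow the same per-dimension rule, which the C++ hunks below centralize in a new PtrState::computeOffset helper: a dimension whose stride is a known zero keeps its offset as-is (only cast to the offset bitwidth), while any other dimension divides the accumulated offset by the stride. A self-contained sketch of that rule, with plain integers standing in for the MLIR Values the pass actually builds; the function name is invented:

  #include <cstdint>

  // Hypothetical scalar stand-in for PtrState::computeOffset.
  uint32_t computeElementOffset(uint32_t offset, uint32_t stride) {
    if (stride == 0)
      return offset;        // broadcast dimension: offset already in elements
    return offset / stride; // strided dimension: un-scale, as arith.divui does
  }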
diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
index fcf5e14c84..ea618c6659 100644
--- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
+++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp
@@ -361,20 +361,7 @@ struct PtrState {
     for (const auto &[offset, stride, dim] :
          llvm::zip(offsets, strides, shape)) {
-      if (ttgi::isConstant(stride, 0)) {
-        newOffsets.push_back(findOrCreateCast(
-            loc, offset, builder.getIntegerType(offsetBitwidth), builder));
-      } else {
-        Value divOffset = builder.create<arith::DivUIOp>(
-            loc, builder.getIntegerType(offsetBitwidth),
-            findOrCreateCast(loc, offset,
-                             builder.getIntegerType(offsetBitwidth), builder),
-            findOrCreateCast(loc, stride,
-                             builder.getIntegerType(offsetBitwidth), builder));
-        newOffsets.push_back(
-            findOrCreateCast(loc, getFinalValue(divOffset),
-                             builder.getIntegerType(offsetBitwidth), builder));
-      }
+      newOffsets.push_back(computeOffset(offset, stride, builder, loc));
       newStrides.push_back(findOrCreateCast(
           loc, stride, builder.getIntegerType(shapeAndStridesBitwidth),
           builder));
@@ -385,6 +372,34 @@ struct PtrState {
     return findOrCreateMakeTensorPtr(loc, source, newShape, newStrides,
                                      newOffsets, order, sizes, builder);
   }
+
+  Value createTTAdvanceOp(Value ptr, tt::MakeTensorPtrOp makeTPtrOp,
+                          OpBuilder &builder, Location loc) const {
+    SmallVector<Value> newOffsets;
+    for (const auto &[offset, stride] :
+         llvm::zip(offsets, makeTPtrOp.getStrides()))
+      newOffsets.push_back(computeOffset(offset, stride, builder, loc));
+
+    return builder.createOrFold<tt::AdvanceOp>(loc, ptr.getType(), ptr,
+                                               newOffsets);
+  }
+
+private:
+  Value computeOffset(Value offset, Value stride, OpBuilder &builder,
+                      Location loc) const {
+    if (ttgi::isConstant(stride, 0))
+      return findOrCreateCast(loc, offset,
+                              builder.getIntegerType(offsetBitwidth), builder);
+
+    Value divOffset = builder.create<arith::DivUIOp>(
+        loc, builder.getIntegerType(offsetBitwidth),
+        findOrCreateCast(loc, offset, builder.getIntegerType(offsetBitwidth),
+                         builder),
+        findOrCreateCast(loc, stride, builder.getIntegerType(offsetBitwidth),
+                         builder));
+    return findOrCreateCast(loc, getFinalValue(divOffset),
+                            builder.getIntegerType(offsetBitwidth), builder);
+  }
 };
 
 #ifndef NDEBUG
@@ -698,36 +713,17 @@ struct TritonRaiseBlockPointer
       return val;
     };
-    // If the ptr has already been mapped (i.e. rewritten into a block pointer),
-    // rewrite the AddPtrOp using and AdvanceOp.
+    // If the ptr has already been mapped (i.e. rewritten into a block
+    // pointer), rewrite the AddPtrOp using an AdvanceOp.
     if (Value mappedV = ptrMap.lookupOrNull(ptr)) {
       if (auto makeTPtrOp = mappedV.getDefiningOp<tt::MakeTensorPtrOp>()) {
-        auto offsetType = cast<RankedTensorType>(op.getOffset().getType());
-        unsigned rank = offsetType.getRank();
-
-        SmallVector<Value> offsets;
-        TypeSwitch<Operation *>(op.getOffset().getDefiningOp())
-            .Case<tt::SplatOp>([&](tt::SplatOp splatOp) {
-              fillOffsets(splatOp.getSrc(), rank, offsets);
-            })
-            .Case<arith::ConstantOp>([&](arith::ConstantOp cstOp) {
-              APInt val = getConstantValue(cstOp);
-
-              fillOffsets(findOrCreateConstant(loc, val.getZExtValue(),
-                                               offsetBitwidth, builder),
-                          rank, offsets);
-            })
-            .Default([](Operation *op) {
-              llvm::errs() << "Operation: " << *op << "\n";
-              llvm_unreachable("Unhandled operation");
-            });
-
-        assert(!offsets.empty() && offsets.size() == rank &&
-               "unexpected number of offsets");
+        PtrState state;
+        if (failed(visitOperand(op.getOffset(), state, loc, builder)))
+          return failure();
 
         Value basePtr =
             tt::isTensorPointerType(ptr.getType()) ? ptr : mappedV;
-        auto advanceOp = builder.createOrFold<tt::AdvanceOp>(
-            loc, basePtr.getType(), basePtr, offsets);
+        auto advanceOp =
+            state.createTTAdvanceOp(basePtr, makeTPtrOp, builder, loc);
 
         cleanUp.insert(op);
         ptrMap.map(op.getResult(), advanceOp);
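With these changes, the rewrite path for a tt.addptr whose base was already raised to a block pointer becomes: fold the addptr offset into a PtrState via visitOperand, then let createTTAdvanceOp pair the state's per-dimension offsets with the strides of the originating tt.make_tensor_ptr. A compact model of that pairing, using the same plain-integer stand-ins as above (names invented):

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Hypothetical model of createTTAdvanceOp's offset loop: one normalized
  // offset per dimension, zipped with the make_tensor_ptr strides rather
  // than the state's own strides.
  std::vector<uint32_t> advanceOffsets(const std::vector<uint32_t> &offsets,
                                       const std::vector<uint32_t> &strides) {
    assert(offsets.size() == strides.size() && "rank mismatch");
    std::vector<uint32_t> result;
    result.reserve(offsets.size());
    for (std::size_t i = 0; i < offsets.size(); ++i)
      result.push_back(strides[i] == 0 ? offsets[i] : offsets[i] / strides[i]);
    return result;
  }

A side effect of routing through visitOperand is that offset producers the pass cannot analyze now surface as a recoverable failure() instead of hitting the llvm_unreachable in the old TypeSwitch default case.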