
Commit f1efef1

Add LIT test
1 parent 713c836 commit f1efef1

1 file changed: +33 -1 lines changed


test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir

Lines changed: 33 additions & 1 deletion
@@ -123,4 +123,36 @@ module {
 // CHECK:           %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>>
 // CHECK:           affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>>
 // CHECK:           return
-// CHECK:         }
+// CHECK:         }
+
+// -----
+
+module {
+  tt.func public @nan_aware_max(%arg0: tensor<1024xf32>, %arg_out: !tt.ptr<f32>) {
+    %res = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%lhs: f32, %rhs: f32):
+      %cmp_gt = arith.cmpf ogt, %lhs, %rhs : f32
+      %lhs_nan = arith.cmpf une, %lhs, %lhs : f32
+      %pred = arith.ori %cmp_gt, %lhs_nan : i1
+      %sel = arith.select %pred, %lhs, %rhs : f32
+      tt.reduce.return %sel : f32
+    }) : (tensor<1024xf32>) -> f32
+    tt.store %arg_out, %res : !tt.ptr<f32>
+    tt.return
+  }
+}
+
+// CHECK-LABEL:   func.func @nan_aware_max
+// CHECK-SAME:    ([[PARAM_0_:%.+]]: tensor<1024xf32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
+// CHECK-DAG:       [[CST_nan_:%.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG:       [[VAR_0_:%.+]] = bufferization.alloc_tensor() : tensor<f32>
+// CHECK:           [[VAR_inserted_:%.+]] = tensor.insert [[CST_nan_]] into [[VAR_0_]][] : tensor<f32>
+// CHECK:           [[VAR_reduced_:%.+]] = linalg.reduce ins([[PARAM_0_]] : tensor<1024xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
+// CHECK:             ([[in_:%.+]]: f32, [[in_]]it: f32) {
+// CHECK:               [[CMP_gt_:%.+]] = arith.maximumf [[in_]], [[in_]]it : f32
+// CHECK:               linalg.yield [[CMP_gt_]] : f32
+// CHECK:             }
+// CHECK:           [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
+// CHECK:           tt.store [[PARAM_1_]], [[VAR_extracted_]] : !tt.ptr<f32>
+// CHECK:           return
+// CHECK:         }
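
A note on the expected output: the reduction body above returns %lhs whenever %lhs > %rhs or %lhs is NaN, which is exactly a NaN-propagating maximum; that is why the expected lowering is linalg.reduce over arith.maximumf, seeded with 0xFF800000 : f32, the bit pattern for -inf (the identity of max), despite the CST_nan_ capture name. The extra i32 arguments in CHECK-SAME (PARAM_2_ through PARAM_7_) are presumably the grid/program-id arguments the TritonToLinalg conversion appends to kernel signatures, and the odd [[in_]]it capture looks like an artifact of the check-generation script substituting %in inside %init (the pattern still matches %init). Below is a minimal Python sketch, not from the commit, checking the semantic equivalence between the select pattern and a NaN-propagating maximum, signed zeros aside:

import itertools
import math

def reduce_body(lhs: float, rhs: float) -> float:
    # Mirrors the tt.reduce region: ogt compare, une self-compare as the
    # NaN test on lhs, then select.
    return lhs if (lhs > rhs) or math.isnan(lhs) else rhs

def maximumf(lhs: float, rhs: float) -> float:
    # arith.maximumf semantics: NaN if either operand is NaN, else max.
    if math.isnan(lhs) or math.isnan(rhs):
        return float("nan")
    return max(lhs, rhs)

samples = [0.0, 1.0, -1.0, float("inf"), float("-inf"), float("nan")]
for lhs, rhs in itertools.product(samples, repeat=2):
    got, want = reduce_body(lhs, rhs), maximumf(lhs, rhs)
    assert got == want or (math.isnan(got) and math.isnan(want)), (lhs, rhs)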
