@@ -123,4 +123,38 @@ module {
 // CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>>
 // CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>>
 // CHECK: return
-// CHECK: }
+// CHECK: }
+
+// -----
+
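+// NaN-aware max reduction: the combinator returns %lhs when %lhs > %rhs or when
+// %lhs is NaN, so NaNs win the comparison and propagate through the reduction.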
+module {
+  tt.func public @nan_aware_max(%arg0: tensor<1024xf32>, %arg_out: !tt.ptr<f32>) {
+    %res = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
+    ^bb0(%lhs: f32, %rhs: f32):
+      %cmp_gt = arith.cmpf ogt, %lhs, %rhs : f32
+      %lhs_nan = arith.cmpf une, %lhs, %lhs : f32
+      %pred = arith.ori %cmp_gt, %lhs_nan : i1
+      %sel = arith.select %pred, %lhs, %rhs : f32
+      tt.reduce.return %sel : f32
+    }) : (tensor<1024xf32>) -> f32
+    tt.store %arg_out, %res : !tt.ptr<f32>
+    tt.return
+  }
+}
+
+// CHECK-LABEL: func.func @nan_aware_max
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1024xf32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
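+// The select pattern is recognized and lowered to the NaN-propagating
+// arith.maximumf, with the accumulator seeded by -inf (0xFF800000), the
+// identity element for max.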
+// CHECK-DAG: [[CST_nan_:%.+]] = arith.constant 0xFF800000 : f32
+// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.alloc_tensor() : tensor<f32>
+// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_nan_]] into [[VAR_0_]][] : tensor<f32>
+// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[PARAM_0_]] : tensor<1024xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
+// CHECK: ([[in_:%.+]]: f32, [[in_]]it: f32) {
+// CHECK: [[CMP_gt_:%.+]] = arith.maximumf [[in_]], [[in_]]it : f32
+// CHECK: linalg.yield [[CMP_gt_]] : f32
+// CHECK: }
+// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
+// CHECK: tt.store [[PARAM_1_]], [[VAR_extracted_]] : !tt.ptr<f32>
+// CHECK: return
+// CHECK: }