diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 6f79665c2bb60..6e12d3604a262 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -65,6 +65,8 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
               [](arith::MinimumFOp) { return arith::AtomicRMWKind::minimumf; })
           .Case(
               [](arith::MaximumFOp) { return arith::AtomicRMWKind::maximumf; })
+          .Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
+          .Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
           .Case([](arith::MinSIOp) { return arith::AtomicRMWKind::mins; })
           .Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
           .Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 65f85444e70db..0c9dc8ad4e6bf 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3949,8 +3949,9 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
   case arith::AtomicRMWKind::muli:
     return isa<IntegerType>(resultType);
   case arith::AtomicRMWKind::maximumf:
-    return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::minimumf:
+  case arith::AtomicRMWKind::maxnumf:
+  case arith::AtomicRMWKind::minnumf:
     return isa<FloatType>(resultType);
   case arith::AtomicRMWKind::maxs: {
     auto intType = llvm::dyn_cast<IntegerType>(resultType);
@@ -3972,9 +3973,8 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType,
     return isa<IntegerType>(resultType);
   case arith::AtomicRMWKind::andi:
     return isa<IntegerType>(resultType);
-  default:
-    return false;
   }
+  llvm_unreachable("exhaustive switch");
 }
 
 LogicalResult AffineParallelOp::verify() {
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..74f6ae94c5cf4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -690,12 +690,19 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
   case arith::AtomicRMWKind::ori:
     return builder.create<vector::ReductionOp>(vector.getLoc(),
                                                CombiningKind::OR, vector);
-  // TODO: Add remaining reduction operations.
-  default:
-    (void)emitOptionalError(loc, "Reduction operation type not supported");
-    break;
+  case arith::AtomicRMWKind::maxnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MAXNUMF, vector);
+  case arith::AtomicRMWKind::minnumf:
+    return builder.create<vector::ReductionOp>(vector.getLoc(),
+                                               CombiningKind::MINNUMF, vector);
+  // `assign` is not a supported reduction: emit an error and return null.
+  case arith::AtomicRMWKind::assign:
+    (void)emitOptionalError(loc,
+                            "Reduction operation type not supported (assign)");
+    return nullptr;
   }
-  return nullptr;
+  llvm_unreachable("exhaustive switch");
 }
 
 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index b616632a6fe24..00323d2853997 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -83,6 +83,64 @@ func.func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>
 
 // -----
 
+// minnumf reduction whose iter_arg is seeded with 0x7FC00000, an f32 NaN bit
+// pattern; the vectorized loop is expected to start from the same value splat.
+func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0x7FC00000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %min = arith.minnumf %red_iter, %ld : f32
+     affine.yield %min : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: @vecdim_reduction_minnumf
+// CHECK:       affine.for %{{.*}} = 0 to 256 {
+// CHECK:         %[[vinit:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
+// CHECK:         %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vinit]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
+// CHECK:           %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:           affine.yield %[[min]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         %[[final_min:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK:         affine.store %[[final_min]], %{{.*}} : memref<256xf32>
+// CHECK:       }
+
+// -----
+
+// maxnumf reduction whose iter_arg is seeded with 0xFFC00000, an f32 NaN bit
+// pattern; the vectorized loop is expected to start from the same value splat.
+func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
+ %cst = arith.constant 0xFFC00000 : f32
+ affine.for %i = 0 to 256 {
+   %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
+     %ld = affine.load %in[%i, %j] : memref<256x512xf32>
+     %max = arith.maxnumf %red_iter, %ld : f32
+     affine.yield %max : f32
+   }
+   affine.store %final_red, %out[%i] : memref<256xf32>
+ }
+ return
+}
+
+// CHECK-LABEL: @vecdim_reduction_maxnumf
+// CHECK:       affine.for %{{.*}} = 0 to 256 {
+// CHECK:         %[[vinit:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
+// CHECK:         %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vinit]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
+// CHECK:           %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:           affine.yield %[[max]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         %[[final_max:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32
+// CHECK:         affine.store %[[final_max]], %{{.*}} : memref<256xf32>
+// CHECK:       }
+
+// -----
+
 func.func @vecdim_reduction_minsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
  %cst = arith.constant 2147483647 : i32
  affine.for %i = 0 to 256 {