diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index 8c33148d1d2d7..fbaae515787d3 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+// Replaces sitofp or uitofp on src types no wider than the dst type mantissa
+// with a faster combination of bit ops and add/sub.
+template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
+struct ExpandIToFP : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct NVVMOptimizeForTarget
     : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
   void runOnOperation() override;
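Aside for reviewers (not part of the patch): the `intWidth > mantissaWidth` bail-out below is what makes the rewrite exact. A binary float with an n-bit mantissa (implicit integer bit included: 24 for f32, 11 for f16, 8 for bf16) represents every integer of up to n bits without rounding, so any source no wider than the mantissa survives the bit manipulation unchanged. A minimal standalone C++ check of that premise, with `exactInFloat` being an illustrative helper name of mine:

```cpp
#include <cstdint>

// f32 has a 24-bit mantissa (23 stored bits plus the implicit integer bit),
// so every u16 converts exactly, while u32 values generally round.
constexpr bool exactInFloat(uint32_t v) {
  return static_cast<double>(static_cast<float>(v)) == static_cast<double>(v);
}
static_assert(exactInFloat(0xFFFFu), "max u16 is exact in f32");
static_assert(!exactInFloat(0xFFFFFFFFu), "max u32 rounds, hence no rewrite");
```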
@@ -92,10 +104,95 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
   return success();
 }
 
+template <typename OpTy>
+LogicalResult
+ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
+  Type srcType = op.getOperand().getType();
+  auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
+  if (!intType)
+    return rewriter.notifyMatchFailure(op, "src type is not integer");
+  Type dstType = op.getType();
+  auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
+  if (!floatType)
+    return rewriter.notifyMatchFailure(op, "dst type is not float");
+
+  // Mantissa width includes the integer bit, e.g. 24 for fp32.
+  auto mantissaWidth = floatType.getFPMantissaWidth();
+  if (mantissaWidth < 2)
+    return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
+  auto intWidth = intType.getWidth();
+  if (intWidth > mantissaWidth)
+    return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");
+
+  Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
+                                  intType.getSignedness());
+  if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
+    extType = shapedType.clone(extType);
+  auto getAttr = [&](APInt value) -> TypedAttr {
+    if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
+      return DenseElementsAttr::get(shapedType, value);
+    return IntegerAttr::get(extType, value);
+  };
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+  if (intWidth == mantissaWidth) {
+    if (std::is_same_v<OpTy, LLVM::UIToFPOp>) {
+      return rewriter.notifyMatchFailure(
+          op, "unsigned src is as wide as dst mantissa");
+    }
+    // Create a float bit-pattern with zero biased-exponent and zero mantissa.
+    APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
+    APFloat floatBits(floatType.getFloatSemantics(), intPart);
+    if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
+      return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
+    TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+
+    // Combine zero-extended src and float bit-pattern. The msb of src becomes
+    // the lsb of the exponent.
+    Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
+    Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+    Value pattern = builder.create<LLVM::OrOp>(zext, intConst);
+
+    // Mask the exponent-lsb and the mantissa to get two separate values.
+    auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
+    Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
+    Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
+    Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
+    Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
+
+    // Bitcast these values to float and subtract or add them.
+    Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
+    Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
+    rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, mantissaCast, exponentCast);
+    return success();
+  }
+
+  // Create a float with zero biased-exponent and msb-set mantissa.
+  APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
+  APFloat floatBits(floatType.getFloatSemantics(), intPart);
+  TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+  TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
+  if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
+    floatAttr = DenseElementsAttr::get(shapedType, floatAttr);
+
+  // Add extended src and bit-pattern of float, then subtract float.
+  using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+                                   LLVM::SExtOp, LLVM::ZExtOp>;
+  Value ext = builder.create<ExtOp>(extType, op.getOperand());
+  Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+  Value add = builder.create<LLVM::AddOp>(ext, intConst);
+  Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
+  Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
+  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
+  return success();
+}
+
 void NVVMOptimizeForTarget::runOnOperation() {
   MLIRContext *ctx = getOperation()->getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<ExpandDivF16>(ctx);
+  patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
+               ExpandIToFP<LLVM::UIToFPOp>>(ctx);
+
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
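The general path above is the classic magic-number trick: extend the source into the mantissa field of a carefully chosen float bit pattern, integer-add, bitcast, then subtract the same constant as a float. A standalone sketch (my code, not part of the patch; requires C++20 for `std::bit_cast`) reproducing the `ui16_to_f32` case tested below, where `0x4B400000` is the bit pattern of `12582912.0f` (the mantissa value `3 << 22` placed at the mantissa msb):

```cpp
#include <bit>
#include <cassert>
#include <cstdint>

// Mirrors the emitted llvm.zext + llvm.add + llvm.bitcast + llvm.fsub chain.
float uitofpViaBitOps(uint16_t x) {
  // x lands in the low mantissa bits without carrying into the exponent.
  uint32_t bits = 0x4B400000u + x;
  // The bitcast yields exactly 12582912.0 + x; the fsub recovers x exactly.
  return std::bit_cast<float>(bits) - 12582912.0f;
}

int main() {
  for (uint32_t x = 0; x <= 0xFFFF; ++x) // exhaustive over u16
    assert(uitofpViaBitOps(static_cast<uint16_t>(x)) == static_cast<float>(x));
}
```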
diff --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
index b98d2e08b7548..a77d98a1b71a9 100644
--- a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -22,3 +22,88 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
   // CHECK: llvm.return %[[result]] : f16
   llvm.return %result : f16
 }
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+  %result = llvm.uitofp %arg0 : i16 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float_no_rewrite
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+  %result = llvm.sitofp %arg0 : i32 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+  // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+  %result = llvm.sitofp %arg0 : i8 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+  %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+  // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+  %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
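In the equal-width signed case exercised above, a plain add would carry into the exponent field, so the pattern instead ors the source msb into the exponent lsb and splits the result with two masks before a single fsub. A standalone replay of the scalar math behind `vec_si8_to_bf16` (my code and helper name, not part of the patch), emulating bf16 as the upper half of an f32 and using the tested constants (`0x4300` = 17152, `0xff80` = -128, `0xff7f` = -129 as i16):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// A bf16 value is the top 16 bits of the f32 with the same value.
float bf16BitsToFloat(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float result;
  std::memcpy(&result, &widened, sizeof(result));
  return result;
}

int main() {
  for (int x = -128; x <= 127; ++x) {
    uint16_t pattern = 0x4300u | static_cast<uint8_t>(x); // zext + or
    // The exponent part is 128.0 or 256.0 depending on the sign bit of x;
    // the mantissa part is 128.0 + (x & 0x7f). Their difference is exactly x.
    float exponent = bf16BitsToFloat(pattern & 0xff80u);
    float mantissa = bf16BitsToFloat(pattern & 0xff7fu);
    assert(mantissa - exponent == static_cast<float>(x));
  }
}
```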
+
+// Checks that expansion does not apply when unsigned integer width is equal to
+// mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+  // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i8 to bf16
+  %result = llvm.uitofp %arg0 : i8 to bf16
+  // CHECK: llvm.return %[[result]] : bf16
+  llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @si11_to_f16
+llvm.func @si11_to_f16(%arg0 : i11) -> f16 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i11 to f16
+  %result = llvm.sitofp %arg0 : i11 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
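Why the last case stays un-rewritten: the equal-width path requires the exponent lsb of the constructed bit pattern for `2^(mantissaWidth - 1)` to be clear, so that or-ing the source msb into it acts like an add. For f16 (`mantissaWidth` = 11) that bit is already set, as this standalone check (my code, not part of the patch) shows:

```cpp
#include <cstdint>

// f16 encodes 2^10 = 1024.0 as sign 0, biased exponent 10 + 15 = 25, zero
// mantissa: bits (25 << 10) = 0x6400. Bit 10 -- the exponent lsb the source
// msb would be or'ed into -- is already 1, so the pattern bails out.
constexpr uint16_t kF16BitsOf1024 = 25u << 10; // 0x6400
static_assert((kF16BitsOf1024 >> 10) & 1, "bias exponent lsb bit is set");
```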