[mlir][nvvm] Expand sitofp/uitofp to faster ops #107001
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir

Author: Christian Sigg (chsigg)

Changes
`sitofp` and `uitofp` are lowered to `cvt.rn` PTX instructions by the LLVM-NVPTX backend, which have lower throughput than int and float arithmetic ops. Doing this optimization in LLVM would only work for i16->fp32 because the NVPTX backend has no i8 registers and promotes them to i16.

Full diff: https://github.com/llvm/llvm-project/pull/107001.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
index 8c33148d1d2d78..de3295ead2c3cd 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
PatternRewriter &rewriter) const override;
};
+// Replaces sitofp or uitofp on src types no wider than the dst type mantissa
+// with a faster combination of bit ops and add/sub.
+template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
+struct ExpandIToFP : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+private:
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const override;
+};
+
struct NVVMOptimizeForTarget
: public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
void runOnOperation() override;
@@ -92,10 +104,93 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
return success();
}
+template <typename OpTy>
+LogicalResult
+ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
+ Type srcType = op.getOperand().getType();
+ auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
+ if (!intType)
+ return rewriter.notifyMatchFailure(op, "src type is not integer");
+ Type dstType = op.getType();
+ auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
+ if (!floatType)
+ return rewriter.notifyMatchFailure(op, "dst type is not float");
+
+ // Mantissa width includes the integer bit, e.g. 24 for fp32.
+ auto mantissaWidth = floatType.getFPMantissaWidth();
+ if (mantissaWidth < 2)
+ return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
+ auto intWidth = intType.getWidth();
+ if (intWidth > mantissaWidth)
+ return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");
+
+ Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
+ intType.getSignedness());
+ if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
+ extType = shapedType.clone(extType);
+ auto getAttr = [&](APInt value) -> TypedAttr {
+ if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
+ return DenseElementsAttr::get(shapedType, value);
+ return IntegerAttr::get(extType, value);
+ };
+ ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+ if (intWidth == mantissaWidth) {
+ // Create a float bit-pattern with zero biased-exponent and zero mantissa.
+ APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
+ APFloat floatBits(floatType.getFloatSemantics(), intPart);
+ if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
+ return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
+ TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+
+ // Combine zero-extended src and float bit-pattern. The msb of src becomes
+ // the lsb of the exponent.
+ Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
+ Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+ Value pattern = builder.create<LLVM::OrOp>(zext, intConst);
+
+ // Mask the exponent-lsb and the mantissa to get two separate values.
+ auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
+ Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
+ Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
+ Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
+ Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
+
+ // Bitcast these values to float and subtract or add them.
+ Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
+ Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
+ using SubOrAddOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+ LLVM::FSubOp, LLVM::FAddOp>;
+ rewriter.replaceOpWithNewOp<SubOrAddOp>(op, mantissaCast, exponentCast);
+ return success();
+ }
+
+ // Create a float with zero biased-exponent and msb-set mantissa.
+ APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
+ APFloat floatBits(floatType.getFloatSemantics(), intPart);
+ TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+ TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
+ if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
+ floatAttr = DenseElementsAttr::get(shapedType, floatAttr);
+
+ // Add extended src and bit-pattern of float, then subtract float.
+ using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+ LLVM::SExtOp, LLVM::ZExtOp>;
+ Value ext = builder.create<ExtOp>(extType, op.getOperand());
+ Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+ Value add = builder.create<LLVM::AddOp>(ext, intConst);
+ Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
+ Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
+ rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
+ return success();
+}
+
void NVVMOptimizeForTarget::runOnOperation() {
MLIRContext *ctx = getOperation()->getContext();
RewritePatternSet patterns(ctx);
- patterns.add<ExpandDivF16>(ctx);
+ patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
+ ExpandIToFP<LLVM::UIToFPOp>>(ctx);
+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
index b98d2e08b75486..813f25535d3295 100644
--- a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -22,3 +22,96 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
// CHECK: llvm.return %[[result]] : f16
llvm.return %result : f16
}
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+ %result = llvm.uitofp %arg0 : i16 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+ // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+ %result = llvm.sitofp %arg0 : i32 to f32
+ // CHECK: llvm.return %[[result]] : f32
+ llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+ // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+ %result = llvm.sitofp %arg0 : i8 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+ // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+ %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+ // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+ // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+ // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+ // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+ %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+ // CHECK: llvm.return %[[result]] : vector<4xbf16>
+ llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+ // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i8 to i16
+ // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(17152 : i16) : i16
+ // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : i16
+ // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(-128 : i16) : i16
+ // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(-129 : i16) : i16
+ // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : i16
+ // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : i16
+ // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : i16 to bf16
+ // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : i16 to bf16
+ // CHECK-DAG: %[[result:.*]] = llvm.fadd %[[man_cast]], %[[exp_cast]] : bf16
+ %result = llvm.uitofp %arg0 : i8 to bf16
+ // CHECK: llvm.return %[[result]] : bf16
+ llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @ui11_to_f16
+llvm.func @ui11_to_f16(%arg0 : i11) -> f16 {
+ // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i11 to f16
+ %result = llvm.uitofp %arg0 : i11 to f16
+ // CHECK: llvm.return %[[result]] : f16
+ llvm.return %result : f16
+}
`sitofp` and `uitofp` are lowered to `cvt.rn` PTX instructions by the LLVM-NVPTX backend, which have lower throughput than int and float arithmetic ops. Doing this optimization in LLVM would only work for i16->fp32 because the NVPTX backend has no i8 registers and promotes them to i16.
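To make the trick concrete, here is a minimal standalone C++ demo of the narrow-integer path (not part of the PR; the helper names `bitcastToFloat` and `u16ToF32` are made up for illustration, and it assumes the host `float` is IEEE-754 binary32). It uses the same constants as the `@ui16_to_f32` test above: the integer bias `0x4B400000` is the bit pattern of `12582912.0f = 1.5 * 2^23`.

```cpp
// Hypothetical standalone demo (not from the PR) of the expansion used when
// the integer is narrower than the destination mantissa:
//   zext + add 0x4B400000 + bitcast + fsub 12582912.0
// Assumes the host `float` is IEEE-754 binary32.
#include <cassert>
#include <cstdint>
#include <cstring>

// Reinterpret integer bits as a float, like llvm.bitcast.
static float bitcastToFloat(uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Convert u16 -> f32 without an int-to-float conversion instruction.
static float u16ToF32(uint16_t x) {
  // 0x4B400000 is the bit pattern of 12582912.0f = 1.5 * 2^23. Its low 22
  // mantissa bits are zero, so adding x (< 2^16) cannot carry into the
  // exponent; the resulting float is exactly 12582912.0 + x.
  uint32_t bits = 0x4B400000u + x;
  // Subtracting the bias recovers x exactly (the difference is representable).
  return bitcastToFloat(bits) - 12582912.0f;
}

int main() {
  for (uint32_t i = 0; i <= 0xFFFF; ++i)
    assert(u16ToF32(static_cast<uint16_t>(i)) == static_cast<float>(i));
  return 0;
}
```

The signed variant differs only in using `sext` and the corresponding bias; the separate `intWidth == mantissaWidth` path in the patch combines the source into the bit pattern with `or`/`and` masks instead of a plain add, since there the source's msb overlaps the exponent's lsb.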
I didn't follow this explanation: can you elaborate a bit more on which of the cases in the tests couldn't be done in LLVM IR?
Conversion from integers narrower than 16 bits seems tricky because 8-bit registers have been removed from the NVPTX backend. I'm not very familiar with LLVM, so it's quite possible that it could be done. I've been experimenting with https://godbolt.org/z/cvsGTrM5j. I was considering matching
Let me ask the NVVM folks to chime in! Are you saying you can't do this inside the NVPTX backend? When I asked about doing this in LLVM, I meant adding something to the "codegen prepare" phase before the backend (an LLVM IR -> LLVM IR transformation).
Yes, it is possible to do this transformation at the LLVM IR level. It should also be possible to do this in the backend. However, I would lean towards doing it at the LLVM IR level as a "codegen prepare" pass, because there is more infrastructure for verifying the correctness of LLVM IR level transformations.
Thank you for the explanation. I will look into implementing this in LLVM. I'm not familiar with that code base, but I will give it a try when I have some spare time.
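Purely as an illustration of the "codegen prepare" idea discussed above, here is a rough sketch of what an LLVM IR level rewrite could look like for the simple unsigned i16 -> float case. This is a hypothetical pass, not existing LLVM code: the name `ExpandUIToFPPass`, the type restrictions, and the overall structure are assumptions, and it only implements the zext/add/bitcast/fsub expansion.

```cpp
// Hypothetical new-pass-manager function pass; illustrative only.
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

struct ExpandUIToFPPass : PassInfoMixin<ExpandUIToFPPass> {
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &) {
    // Collect `uitofp i16 ... to float` instructions first so the function is
    // not mutated while iterating over it.
    SmallVector<UIToFPInst *, 8> worklist;
    for (Instruction &I : instructions(F))
      if (auto *itofp = dyn_cast<UIToFPInst>(&I))
        if (itofp->getSrcTy()->isIntegerTy(16) &&
            itofp->getDestTy()->isFloatTy())
          worklist.push_back(itofp);

    for (UIToFPInst *itofp : worklist) {
      IRBuilder<> builder(itofp);
      // zext i16 %x to i32, then add the bit pattern of 12582912.0f.
      Value *zext =
          builder.CreateZExt(itofp->getOperand(0), builder.getInt32Ty());
      Value *add = builder.CreateAdd(zext, builder.getInt32(0x4B400000));
      // Reinterpret as float and subtract the bias to recover the value.
      Value *asFloat = builder.CreateBitCast(add, itofp->getDestTy());
      Value *result = builder.CreateFSub(
          asFloat, ConstantFP::get(itofp->getDestTy(), 12582912.0));
      itofp->replaceAllUsesWith(result);
      itofp->eraseFromParent();
    }
    return worklist.empty() ? PreservedAnalyses::all()
                            : PreservedAnalyses::none();
  }
};
```

A real implementation would also need to decide where in the NVPTX pass pipeline to run and to handle the remaining signed, sub-byte, and vector cases covered by the MLIR patch.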