Commit ecd33d5

[mlir][nvvm] Expand sitofp/uitofp to faster ops
`sitofp` and `uitofp` are lowered to `cvt.rn` PTX instructions by the LLVM NVPTX backend, which have lower throughput than integer and float arithmetic ops. Doing this optimization in LLVM itself would only work for i16->fp32, because the NVPTX backend has no i8 registers and promotes i8 to i16.
1 parent 2f4232d commit ecd33d5
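
For readers new to the trick: the expansion computes float(x) by adding the sign- or zero-extended integer to the bit pattern of a float whose mantissa has room for it, bitcasting, and subtracting that float again. A minimal scalar sketch of what the pass emits for the u16 -> f32 case in the tests below (illustration only, not part of the commit; assumes C++20 for std::bit_cast, and the function name is hypothetical):

#include <bit>
#include <cstdint>

// 0x4B400000 is the bit pattern of 12582912.0f (1.5 * 2^23); its mantissa
// bits below bit 22 are zero, so adding x < 2^16 cannot carry into the
// exponent, and the bitcast result is exactly 12582912.0f + x.
float u16_to_f32(uint16_t x) {
  uint32_t bits = 0x4B400000u + x;                  // llvm.add
  return std::bit_cast<float>(bits) - 12582912.0f;  // llvm.bitcast, llvm.fsub
}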

File tree

2 files changed: +274 −1 lines changed

mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp

Lines changed: 96 additions & 1 deletion
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
+// Replaces sitofp or uitofp on src types no wider than the dst type mantissa
+// with a faster combination of bit ops and add/sub.
+template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
+struct ExpandIToFP : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override;
+};
+
 struct NVVMOptimizeForTarget
     : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
   void runOnOperation() override;
@@ -92,10 +104,93 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
   return success();
 }
 
+template <typename OpTy>
+LogicalResult
+ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
+  Type srcType = op.getOperand().getType();
+  auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
+  if (!intType)
+    return rewriter.notifyMatchFailure(op, "src type is not integer");
+  Type dstType = op.getType();
+  auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
+  if (!floatType)
+    return rewriter.notifyMatchFailure(op, "dst type is not float");
+
+  // Mantissa width includes the integer bit, e.g. 24 for fp32.
+  auto mantissaWidth = floatType.getFPMantissaWidth();
+  if (mantissaWidth < 2)
+    return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
+  auto intWidth = intType.getWidth();
+  if (intWidth > mantissaWidth)
+    return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");
+
+  Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
+                                  intType.getSignedness());
+  if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
+    extType = shapedType.clone(extType);
+  auto getAttr = [&](APInt value) -> TypedAttr {
+    if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
+      return DenseElementsAttr::get(shapedType, value);
+    return IntegerAttr::get(extType, value);
+  };
+  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+  if (intWidth == mantissaWidth) {
+    // Create a float bit-pattern with zero biased-exponent and zero mantissa.
+    APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
+    APFloat floatBits(floatType.getFloatSemantics(), intPart);
+    if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
+      return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
+    TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+
+    // Combine zero-extended src and float bit-pattern. The msb of src becomes
+    // the lsb of the exponent.
+    Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
+    Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+    Value pattern = builder.create<LLVM::OrOp>(zext, intConst);
+
+    // Mask the exponent-lsb and the mantissa to get two separate values.
+    auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
+    Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
+    Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
+    Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
+    Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);
+
+    // Bitcast these values to float and subtract or add them.
+    Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
+    Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
+    using SubOrAddOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+                                          LLVM::FSubOp, LLVM::FAddOp>;
+    rewriter.replaceOpWithNewOp<SubOrAddOp>(op, mantissaCast, exponentCast);
+    return success();
+  }
+
+  // Create a float with zero biased-exponent and msb-set mantissa.
+  APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
+  APFloat floatBits(floatType.getFloatSemantics(), intPart);
+  TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
+  TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
+  if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
+    floatAttr = DenseElementsAttr::get(shapedType, floatAttr);
+
+  // Add extended src and bit-pattern of float, then subtract float.
+  using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
+                                   LLVM::SExtOp, LLVM::ZExtOp>;
+  Value ext = builder.create<ExtOp>(extType, op.getOperand());
+  Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
+  Value add = builder.create<LLVM::AddOp>(ext, intConst);
+  Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
+  Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
+  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
+  return success();
+}
+
 void NVVMOptimizeForTarget::runOnOperation() {
   MLIRContext *ctx = getOperation()->getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<ExpandDivF16>(ctx);
+  patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
+               ExpandIToFP<LLVM::UIToFPOp>>(ctx);
+
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     return signalPassFailure();
 }
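
To sanity-check the intWidth == mantissaWidth branch above outside of MLIR, here is a scalar sketch of the same or/and/fsub sequence for the signed case, using an i24 source so the width matches the f32 mantissa width of 24 (illustration only, not part of the commit; assumes C++20 for std::bit_cast, function name hypothetical):

#include <bit>
#include <cstdint>

// Signed i24 -> f32: or in the bits of 2^23 (0x4B000000, biased exponent
// 150; 150 is even, so the exponent lsb is clear and the guard above passes),
// then split off the exponent lsb and subtract.
float si24_to_f32(int32_t v) {              // v holds a signed 24-bit value
  uint32_t zext = (uint32_t)v & 0xFFFFFFu;  // zero-extend the 24-bit source
  uint32_t pattern = zext | 0x4B000000u;    // src msb lands on the exponent lsb
  float exponent = std::bit_cast<float>(pattern & 0xFF800000u);
  float mantissa = std::bit_cast<float>(pattern & 0xFF7FFFFFu);
  return mantissa - exponent;               // FSubOp for the signed case
}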

mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir

Lines changed: 178 additions & 0 deletions
@@ -22,3 +22,181 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
   // CHECK: llvm.return %[[result]] : f16
   llvm.return %result : f16
 }
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+  %result = llvm.uitofp %arg0 : i16 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+  %result = llvm.sitofp %arg0 : i32 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+  // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+  %result = llvm.sitofp %arg0 : i8 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
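
A worked check of the si8_to_f16 constants (not from the commit, plain f16 arithmetic): 26112 is 0x6600, which with f16's 5 exponent bits (bias 15) and 10 mantissa bits decodes to 2^10 * (1 + 512/1024) = 1536.0, the second constant. For input -5, llvm.sext yields 0xFFFB; the add wraps to 0x65FB, which decodes to 2^10 * (1 + 507/1024) = 1531.0, and the fsub gives 1531.0 - 1536.0 = -5.0.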
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+  %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+  // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+  %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
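
The dense<-128> and dense<-129> constants are easier to read as bit patterns: in i16, -128 keeps the sign and exponent bits (15..7), while -129 keeps everything except bit 7, bf16's exponent lsb. A compile-time check (illustration only, not from the commit):

#include <cstdint>

// bf16 layout: sign bit 15, exponent bits 14..7, mantissa bits 6..0.
static_assert((uint16_t)-128 == 0xFF80, "APInt::getBitsSetFrom(16, 7)");
static_assert((uint16_t)-129 == 0xFF7F, "mask - 1: all bits except bit 7");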
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i8 to i16
+  // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(17152 : i16) : i16
+  // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : i16
+  // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(-128 : i16) : i16
+  // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(-129 : i16) : i16
+  // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : i16
+  // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : i16
+  // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : i16 to bf16
+  // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : i16 to bf16
+  // CHECK-DAG: %[[result:.*]] = llvm.fadd %[[man_cast]], %[[exp_cast]] : bf16
+  %result = llvm.uitofp %arg0 : i8 to bf16
+  // CHECK: llvm.return %[[result]] : bf16
+  llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @ui11_to_f16
+llvm.func @ui11_to_f16(%arg0 : i11) -> f16 {
+  // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i11 to f16
+  %result = llvm.uitofp %arg0 : i11 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @ui16_to_f32
+llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
+  %result = llvm.uitofp %arg0 : i16 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// Checks that expansion only applies to integer width up to mantissa width.
+// CHECK-LABEL: llvm.func @si32_to_float
+llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
+  // CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
+  %result = llvm.sitofp %arg0 : i32 to f32
+  // CHECK: llvm.return %[[result]] : f32
+  llvm.return %result : f32
+}
+
+// CHECK-LABEL: llvm.func @si8_to_f16
+llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
+  // CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
+  %result = llvm.sitofp %arg0 : i8 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
+
+// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
+llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
+  // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
+  %result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks code path when integer width is equal to mantissa width.
+// CHECK-LABEL: llvm.func @vec_si8_to_bf16
+llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
+  // CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
+  // CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
+  // CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
+  // CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
+  // CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
+  %result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
+  // CHECK: llvm.return %[[result]] : vector<4xbf16>
+  llvm.return %result : vector<4xbf16>
+}
+
+// Checks that expansion does not apply when unsigned integer width is equal to
+// mantissa width.
+// CHECK-LABEL: llvm.func @ui8_to_bf16
+llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
+  // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i8 to bf16
+  %result = llvm.uitofp %arg0 : i8 to bf16
+  // CHECK: llvm.return %[[result]] : bf16
+  llvm.return %result : bf16
+}
+
+// Checks that expansion does not apply when exponent bias lsb is set.
+// CHECK-LABEL: llvm.func @ui11_to_f16
+llvm.func @ui11_to_f16(%arg0 : i11) -> f16 {
+  // CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i11 to f16
+  %result = llvm.uitofp %arg0 : i11 to f16
+  // CHECK: llvm.return %[[result]] : f16
+  llvm.return %result : f16
+}
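
And a brute-force cross-check of the non-equal-width scheme against the plain cast, over every i8 value (illustration only, not part of the commit; assumes C++20):

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (int i = -128; i < 128; ++i) {
    // sitofp i8 -> f32: sign-extend, add the bias bits of 12582912.0f
    // (0x4B400000, the same constant as in the f32 tests above), bitcast,
    // then subtract the bias as a float.
    uint32_t bits = 0x4B400000u + (uint32_t)i;  // int -> uint32_t wraps mod 2^32
    float f = std::bit_cast<float>(bits) - 12582912.0f;
    assert(f == (float)i);
  }
  return 0;
}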
