[mlir][nvvm] Expand sitofp/uitofp to faster ops #107001

Open · wants to merge 3 commits into main
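This patch adds a pattern to the NVVM optimization pass that replaces `llvm.sitofp`/`llvm.uitofp` ops whose source integer type is no wider than the destination float's mantissa with an integer extend, an add (or an or/and pair in the equal-width signed case) against a magic constant, a bitcast, and an `fsub` — per the pattern comment, a faster combination of bit ops and add/sub.

Below is a minimal host-side sketch (my own illustration, not part of the patch; the helper name `uitofp_via_bits` is made up) that mirrors the IR generated for the `ui16 -> f32` test and checks that the expansion is exact. The constant `0x4B400000` is the f32 bit pattern of `1.5 * 2^23 = 12582912.0`, matching the `llvm.mlir.constant(1262485504 : i32)` in the tests.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Emulates the rewritten IR for: llvm.uitofp %x : i16 to f32.
static float uitofp_via_bits(uint16_t x) {
  uint32_t bits = 0x4B400000u + x;   // exponent stays fixed; x lands in the mantissa
  float f;
  std::memcpy(&f, &bits, sizeof(f)); // llvm.bitcast i32 -> f32
  return f - 12582912.0f;            // llvm.fsub of 1.5 * 2^23
}

int main() {
  for (uint32_t x = 0; x <= 0xFFFF; ++x)  // exact for every i16 value
    assert(uitofp_via_bits(static_cast<uint16_t>(x)) == static_cast<float>(x));
}
```

Because `12582912 + x` is an integer below `2^24`, every intermediate value is exactly representable in f32, so the final subtraction returns `x` with no rounding.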
99 changes: 98 additions & 1 deletion mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp
@@ -10,6 +10,7 @@

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -39,6 +40,17 @@ struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
                                PatternRewriter &rewriter) const override;
};

// Replaces sitofp or uitofp on src types no wider than the dst type mantissa
// with a faster combination of bit ops and add/sub.
template <typename OpTy> // OpTy should be LLVM::SIToFPOp or LLVM::UIToFPOp.
struct ExpandIToFP : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

private:
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override;
};

struct NVVMOptimizeForTarget
    : public NVVM::impl::NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
  void runOnOperation() override;
@@ -92,10 +104,95 @@ LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
  return success();
}

template <typename OpTy>
LogicalResult
ExpandIToFP<OpTy>::matchAndRewrite(OpTy op, PatternRewriter &rewriter) const {
  Type srcType = op.getOperand().getType();
  auto intType = dyn_cast<IntegerType>(getElementTypeOrSelf(srcType));
  if (!intType)
    return rewriter.notifyMatchFailure(op, "src type is not integer");
  Type dstType = op.getType();
  auto floatType = dyn_cast<FloatType>(getElementTypeOrSelf(dstType));
  if (!floatType)
    return rewriter.notifyMatchFailure(op, "dst type is not float");

  // Mantissa width includes the integer bit, e.g. 24 for fp32.
  auto mantissaWidth = floatType.getFPMantissaWidth();
  if (mantissaWidth < 2)
    return rewriter.notifyMatchFailure(op, "mantissa is less than 2 bits");
  auto intWidth = intType.getWidth();
  if (intWidth > mantissaWidth)
    return rewriter.notifyMatchFailure(op, "src is wider than dst mantissa");

  Type extType = IntegerType::get(rewriter.getContext(), floatType.getWidth(),
                                  intType.getSignedness());
  if (ShapedType shapedType = dyn_cast<ShapedType>(srcType))
    extType = shapedType.clone(extType);
  auto getAttr = [&](APInt value) -> TypedAttr {
    if (ShapedType shapedType = dyn_cast<ShapedType>(extType))
      return DenseElementsAttr::get(shapedType, value);
    return IntegerAttr::get(extType, value);
  };
  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);

  if (intWidth == mantissaWidth) {
    if (std::is_same_v<OpTy, LLVM::UIToFPOp>) {
      return rewriter.notifyMatchFailure(
          op, "unsigned src is as wide as dst mantissa");
    }
    // Create a float bit-pattern with zero mantissa; the lsb of its biased
    // exponent must be zero (checked below) so the msb of src can be or-ed
    // into it.
    APFloat::integerPart intPart = 1ull << (mantissaWidth - 1);
    APFloat floatBits(floatType.getFloatSemantics(), intPart);
    if (floatBits.bitcastToAPInt()[mantissaWidth - 1])
      return rewriter.notifyMatchFailure(op, "bias exponent lsb bit is set");
    TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());

    // Combine zero-extended src and float bit-pattern. The msb of src becomes
    // the lsb of the exponent.
    Value zext = builder.create<LLVM::ZExtOp>(extType, op.getOperand());
    Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
    Value pattern = builder.create<LLVM::OrOp>(zext, intConst);

    // Mask the exponent-lsb and the mantissa to get two separate values.
    auto mask = APInt::getBitsSetFrom(floatType.getWidth(), mantissaWidth - 1);
    Value exponentMask = builder.create<LLVM::ConstantOp>(getAttr(mask));
    Value mantissaMask = builder.create<LLVM::ConstantOp>(getAttr(mask - 1));
    Value exponentAnd = builder.create<LLVM::AndOp>(pattern, exponentMask);
    Value mantissaAnd = builder.create<LLVM::AndOp>(pattern, mantissaMask);

    // Bitcast these values to float and subtract them.
    Value exponentCast = builder.create<LLVM::BitcastOp>(dstType, exponentAnd);
    Value mantissaCast = builder.create<LLVM::BitcastOp>(dstType, mantissaAnd);
    rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, mantissaCast, exponentCast);
    return success();
  }

  // Create a float bit-pattern with only the msb of the mantissa set, i.e.
  // 1.5 * 2^(mantissaWidth - 1); the extended src is added into the mantissa.
  APFloat::integerPart intPart = 3ull << (mantissaWidth - 2);
  APFloat floatBits(floatType.getFloatSemantics(), intPart);
  TypedAttr intAttr = getAttr(floatBits.bitcastToAPInt());
  TypedAttr floatAttr = FloatAttr::get(floatType, floatBits);
  if (ShapedType shapedType = dyn_cast<ShapedType>(dstType))
    floatAttr = DenseElementsAttr::get(shapedType, floatAttr);

  // Add extended src and bit-pattern of float, then subtract float.
  using ExtOp = std::conditional_t<std::is_same_v<OpTy, LLVM::SIToFPOp>,
                                   LLVM::SExtOp, LLVM::ZExtOp>;
  Value ext = builder.create<ExtOp>(extType, op.getOperand());
  Value intConst = builder.create<LLVM::ConstantOp>(intAttr);
  Value add = builder.create<LLVM::AddOp>(ext, intConst);
  Value bitcast = builder.create<LLVM::BitcastOp>(dstType, add);
  Value floatConst = builder.create<LLVM::ConstantOp>(floatAttr);
  rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, bitcast, floatConst);
  return success();
}

void NVVMOptimizeForTarget::runOnOperation() {
  MLIRContext *ctx = getOperation()->getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<ExpandDivF16, ExpandIToFP<LLVM::SIToFPOp>,
               ExpandIToFP<LLVM::UIToFPOp>>(ctx);

  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
}
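The equal-width signed branch above deserves its own illustration. Here is a host-side sketch (again my own, not from the patch; a 24-bit source is hypothetical and chosen only because f32 has a 24-bit mantissa, so the check can use plain `float`):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static float bitcast_f32(uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Emulates the or/and/fsub expansion for a signed 24-bit source converted to
// f32. x carries the i24 value in its low 24 bits.
static float sitofp_i24_via_bits(int32_t x) {
  uint32_t zext = static_cast<uint32_t>(x) & 0xFFFFFFu; // llvm.zext i24 -> i32
  uint32_t pattern = zext | 0x4B000000u; // bits of 2^23; src msb -> exponent lsb
  float exponent = bitcast_f32(pattern & 0xFF800000u);  // 2^23 or 2^24
  float mantissa = bitcast_f32(pattern & 0xFF7FFFFFu);  // 2^23 + low 23 src bits
  return mantissa - exponent;  // recovers the signed value exactly
}

int main() {
  for (int32_t x = -(1 << 23); x < (1 << 23); ++x)
    assert(sitofp_i24_via_bits(x) == static_cast<float>(x));
}
```

The two masks mirror `APInt::getBitsSetFrom(32, 23)` and that value minus one; the final `fsub` computes `(2^23 + low23) - 2^(23 + msb)`, which is exactly the two's-complement value of the source.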
178 changes: 178 additions & 0 deletions mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir
@@ -22,3 +22,181 @@ llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
// CHECK: llvm.return %[[result]] : f16
llvm.return %result : f16
}

// CHECK-LABEL: llvm.func @ui16_to_f32
llvm.func @ui16_to_f32(%arg0 : i16) -> f32 {
// CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : i16 to i32
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1262485504 : i32) : i32
// CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : i32
// CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i32 to f32
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(0x4B400000 : f32) : f32
// CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f32
%result = llvm.uitofp %arg0 : i16 to f32
// CHECK: llvm.return %[[result]] : f32
llvm.return %result : f32
}

// Checks that expansion only applies to integer width up to mantissa width.
// CHECK-LABEL: llvm.func @si32_to_float_no_rewrite
llvm.func @si32_to_float_no_rewrite(%arg0 : i32) -> f32 {
// CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i32 to f32
%result = llvm.sitofp %arg0 : i32 to f32
// CHECK: llvm.return %[[result]] : f32
llvm.return %result : f32
}
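This case is rejected by the `intWidth > mantissaWidth` check: the expansion is only exact while the extended source added into the mantissa stays an integer below 2^24. An i32 source can reach values such as 2^24 + 1 that f32 cannot represent exactly, so a correct conversion must round, and the bit trick cannot reproduce that.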

// CHECK-LABEL: llvm.func @si8_to_f16
llvm.func @si8_to_f16(%arg0 : i8) -> f16 {
// CHECK-DAG: %[[sext:.*]] = llvm.sext %arg0 : i8 to i16
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(26112 : i16) : i16
// CHECK-DAG: %[[add:.*]] = llvm.add %[[sext]], %[[bias]] : i16
// CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : i16 to f16
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(1.536000e+03 : f16) : f16
// CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : f16
%result = llvm.sitofp %arg0 : i8 to f16
// CHECK: llvm.return %[[result]] : f16
llvm.return %result : f16
}

// CHECK-LABEL: llvm.func @vec_ui4_to_bf16
llvm.func @vec_ui4_to_bf16(%arg0 : vector<4xi4>) -> vector<4xbf16> {
// CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi4> to vector<4xi16>
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<17216> : vector<4xi16>) : vector<4xi16>
// CHECK-DAG: %[[add:.*]] = llvm.add %[[zext]], %[[bias]] : vector<4xi16>
// CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[add]] : vector<4xi16> to vector<4xbf16>
// CHECK-DAG: %[[bias:.*]] = llvm.mlir.constant(dense<1.920000e+02> : vector<4xbf16>) : vector<4xbf16>
// CHECK-DAG: %[[result:.*]] = llvm.fsub %[[cast]], %[[bias]] : vector<4xbf16>
%result = llvm.uitofp %arg0 : vector<4xi4> to vector<4xbf16>
// CHECK: llvm.return %[[result]] : vector<4xbf16>
llvm.return %result : vector<4xbf16>
}
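For reference (my own arithmetic, not from the patch): bf16 is sign(1) | exponent(8, bias 127) | mantissa(7). The integer constant is the bit pattern of 1.5 * 2^7 = 192.0, i.e. 0 10000110 1000000 = 0x4340 = 17216, and the float constant is that same 192.0. For x in [0, 15], bitcast(17216 + x) equals 192 + x exactly, so the trailing fsub yields x with no rounding.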

// Checks code path when integer width is equal to mantissa width.
// CHECK-LABEL: llvm.func @vec_si8_to_bf16
llvm.func @vec_si8_to_bf16(%arg0 : vector<4xi8>) -> vector<4xbf16> {
// CHECK-DAG: %[[zext:.*]] = llvm.zext %arg0 : vector<4xi8> to vector<4xi16>
// CHECK-DAG: %[[const:.*]] = llvm.mlir.constant(dense<17152> : vector<4xi16>) : vector<4xi16>
// CHECK-DAG: %[[or:.*]] = llvm.or %[[zext]], %[[const]] : vector<4xi16>
// CHECK-DAG: %[[exp_mask:.*]] = llvm.mlir.constant(dense<-128> : vector<4xi16>) : vector<4xi16>
// CHECK-DAG: %[[man_mask:.*]] = llvm.mlir.constant(dense<-129> : vector<4xi16>) : vector<4xi16>
// CHECK-DAG: %[[exp_and:.*]] = llvm.and %[[or]], %[[exp_mask]] : vector<4xi16>
// CHECK-DAG: %[[man_and:.*]] = llvm.and %[[or]], %[[man_mask]] : vector<4xi16>
// CHECK-DAG: %[[exp_cast:.*]] = llvm.bitcast %[[exp_and]] : vector<4xi16> to vector<4xbf16>
// CHECK-DAG: %[[man_cast:.*]] = llvm.bitcast %[[man_and]] : vector<4xi16> to vector<4xbf16>
// CHECK-DAG: %[[result:.*]] = llvm.fsub %[[man_cast]], %[[exp_cast]] : vector<4xbf16>
%result = llvm.sitofp %arg0 : vector<4xi8> to vector<4xbf16>
// CHECK: llvm.return %[[result]] : vector<4xbf16>
llvm.return %result : vector<4xbf16>
}

// Checks that expansion does not apply when unsigned integer width is equal to
// mantissa width.
// CHECK-LABEL: llvm.func @ui8_to_bf16
llvm.func @ui8_to_bf16(%arg0 : i8) -> bf16 {
// CHECK: %[[result:.*]] = llvm.uitofp %arg0 : i8 to bf16
%result = llvm.uitofp %arg0 : i8 to bf16
// CHECK: llvm.return %[[result]] : bf16
llvm.return %result : bf16
}

// Checks that expansion does not apply when the exponent bias lsb is set
// (for f16 and an 11-bit source, the biased exponent of the magic constant is
// 15 + 10 = 25, whose lsb is 1).
// CHECK-LABEL: llvm.func @si11_to_f16
llvm.func @si11_to_f16(%arg0 : i11) -> f16 {
// CHECK: %[[result:.*]] = llvm.sitofp %arg0 : i11 to f16
%result = llvm.sitofp %arg0 : i11 to f16
// CHECK: llvm.return %[[result]] : f16
llvm.return %result : f16
}
