-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][spirv] Enable dot operation for bfloat16 #145409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Darren Wihandi (fairywreath) ChangesAllows dot operations to use vectors of bfloat16 type. Full diff: https://github.com/llvm/llvm-project/pull/145409.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..33af979a45bc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];
let arguments = (ins
- SPIRV_VectorOf<SPIRV_Float>:$vector1,
- SPIRV_VectorOf<SPIRV_Float>:$vector2
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+ SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);
let results = (outs
- SPIRV_Float:$result
+ SPIRV_AnyFloat:$result
);
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
// -----
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ return %0 : bf16
+}
+
+// -----
+
// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
- // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+ // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..84d301c608d7d 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
+ spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
+ // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+ %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+ spirv.Return
+ }
}
|
Is there a way to require extension/capability/availability for ops that use a specific type? In this case for example using dot with bf16 requires the |
ccd4281
to
71755ac
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to require extension/capability/availability for ops that use a specific type? In this case for example using dot with bf16 requires the
BFloat16DotProductKHR
capability.
Yes, I think we need to update:
getIntegerDotProductCapabilities(Operation *op) { |
// Integer Dot Product ops |
static std::optional<spirv::Version> getDotProductMaxVersion() { | ||
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not like hardcoding the maximum version here - is there currently a way to retrieve the "default" maximum spirv version? From my understanding the only place this is defined is inside the availability field inside the tablegen definition for SPIRV_Op
, and there is no way to retrieve that here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not aware of anything better. We could probably have a define for spirv::Version::Latest
but in tests we'd still have to test for the exact version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add a test for VectorReductionToFPDotProd
with BF16?
@@ -1,12 +1,12 @@ | |||
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===// | |||
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure this header is padded to the 80 column line limit
} | ||
|
||
SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() { | ||
if (getResult().getType().isBF16()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: prefer isa<BFloat16Type>(getType())
} | ||
|
||
SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() { | ||
if (getResult().getType().isBF16()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here
Allows dot operations to use vectors of bfloat16 type.