Skip to content

[mlir][spirv] Enable dot operation for bfloat16 #145409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

fairywreath
Copy link
Contributor

Allows dot operations to use vectors of bfloat16 type.

@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Darren Wihandi (fairywreath)

Changes

Allows dot operations to use vectors of bfloat16 type.


Full diff: https://github.com/llvm/llvm-project/pull/145409.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+3-3)
  • (modified) mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir (+10-1)
  • (modified) mlir/test/Target/SPIRV/arithmetic-ops.mlir (+5)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..33af979a45bc5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
   }];
 
   let arguments = (ins
-    SPIRV_VectorOf<SPIRV_Float>:$vector1,
-    SPIRV_VectorOf<SPIRV_Float>:$vector2
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
+    SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
   );
 
   let results = (outs
-    SPIRV_Float:$result
+    SPIRV_AnyFloat:$result
   );
 
   let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index d58c27598f2b8..3adafc15c79f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @dot_bf16
+func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
+  // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
+  %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+  return %0 : bf16
+}
+
+// -----
+
 // expected-note @+1 {{prior use here}}
 func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
   // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
 // -----
 
 func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
-  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
   %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
   return %0 : i32
 }
diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
index b1ea13c6854fd..84d301c608d7d 100644
--- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir
+++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
     spirv.Return
   }
+  spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
+    // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
+    %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
+    spirv.Return
+  }
 }

@fairywreath
Copy link
Contributor Author

Is there a way to require extension/capability/availability for ops that use a specific type? In this case for example using dot with bf16 requires the BFloat16DotProductKHR capability.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to require extension/capability/availability for ops that use a specific type? In this case for example using dot with bf16 requires the BFloat16DotProductKHR capability.

Yes, I think we need to update:

getIntegerDotProductCapabilities(Operation *op) {
. The tests are in

Comment on lines +32 to +34
static std::optional<spirv::Version> getDotProductMaxVersion() {
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like hardcoding the maximum version here - is there currently a way to retrieve the "default" maximum spirv version? From my understanding the only place this is defined is inside the availability field inside the tablegen definition for SPIRV_Op, and there is no way to retrieve that here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of anything better. We could probably have a define for spirv::Version::Latest but in tests we'd still have to test for the exact version.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test for VectorReductionToFPDotProd with BF16?

@@ -1,12 +1,12 @@
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure this header is padded to the 80 column line limit

}

SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
if (getResult().getType().isBF16()) {
Copy link
Member

@kuhar kuhar Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer isa<BFloat16Type>(getType())

}

SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
if (getResult().getType().isBF16()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants