Skip to content

[mlir][spirv] Enable dot operation for bfloat16 #145409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,16 +445,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];

let arguments = (ins
SPIRV_VectorOf<SPIRV_Float>:$vector1,
SPIRV_VectorOf<SPIRV_Float>:$vector2
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
);

let results = (outs
SPIRV_Float:$result
SPIRV_AnyFloat:$result
);

let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";

// Require dynamic availability specification based on operand/result type.
bit autogenAvailability = 0;

let hasVerifier = 0;
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect
CastOps.cpp
ControlFlowOps.cpp
CooperativeMatrixOps.cpp
DotProductOps.cpp
GroupOps.cpp
ImageOps.cpp
IntegerDotProductOps.cpp
MemoryOps.cpp
MeshOps.cpp
SPIRVAttributes.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===//
//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure this header is padded to the 80 column line limit

// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the Integer Dot Product operations in the SPIR-V dialect.
// Defines the Dot Product operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

Expand All @@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

//===----------------------------------------------------------------------===//
// Dot Product ops
//===----------------------------------------------------------------------===//

static std::optional<spirv::Version> getDotProductMinVersion() {
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getDotProductMaxVersion() {
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}

Comment on lines +32 to +34
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like hardcoding the maximum version here - is there currently a way to retrieve the "default" maximum spirv version? From my understanding the only place this is defined is inside the availability field inside the tablegen definition for SPIRV_Op, and there is no way to retrieve that here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not aware of anything better. We could probably have a define for spirv::Version::Latest but in tests we'd still have to test for the exact version.

SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
if (getResult().getType().isBF16()) {
static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
Copy link
Member

@kuhar kuhar Jul 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: prefer isa<BFloat16Type>(getType())

return {extension};
}

return {};
}

SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
if (getResult().getType().isBF16()) {
static const auto capability = spirv::Capability::BFloat16DotProductKHR;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

return {capability};
}

return {};
}

std::optional<spirv::Version> DotOp::getMinVersion() {
return getDotProductMinVersion();
}

std::optional<spirv::Version> DotOp::getMaxVersion() {
return getDotProductMaxVersion();
}

//===----------------------------------------------------------------------===//
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
return success();
}

static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
}

static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
}

static SmallVector<ArrayRef<spirv::Extension>, 1>
getIntegerDotProductExtensions() {
// Requires the SPV_KHR_integer_dot_product extension, specified either
Expand Down Expand Up @@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) {
return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
return getIntegerDotProductMinVersion(); \
return getDotProductMinVersion(); \
} \
std::optional<spirv::Version> OpName::getMaxVersion() { \
return getIntegerDotProductMaxVersion(); \
return getDotProductMaxVersion(); \
}

SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
Expand Down
11 changes: 10 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {

// -----

// CHECK-LABEL: @dot_bf16
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
return %0 : bf16
}

// -----

// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
Expand All @@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----

func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 {
return %r: i64
}

//===----------------------------------------------------------------------===//
// Dot Product op with bfloat16
//===----------------------------------------------------------------------===//

// CHECK-LABEL: dot_vector_4xbf16_bf16
func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 {
// CHECK: min version: v1.0
// CHECK: max version: v1.6
// CHECK: extensions: [ [SPV_KHR_bfloat16] ]
// CHECK: capabilities: [ [BFloat16DotProductKHR] ]
%r = spirv.Dot %a, %a: vector<4xbf16> -> bf16
return %r: bf16
}

//===----------------------------------------------------------------------===//
// Primitive ops
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/SPIRV/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
spirv.Return
}
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
spirv.Return
}
}
Loading