diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 309079e549846..2260ee85493c7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -445,16 +445,19 @@ def SPIRV_DotOp : SPIRV_Op<"Dot", }]; let arguments = (ins - SPIRV_VectorOf:$vector1, - SPIRV_VectorOf:$vector2 + SPIRV_VectorOf:$vector1, + SPIRV_VectorOf:$vector2 ); let results = (outs - SPIRV_Float:$result + SPIRV_AnyFloat:$result ); let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)"; + // Require dynamic availability specification based on operand/result type. + bit autogenAvailability = 0; + let hasVerifier = 0; } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt index 1a8f30dd39871..b9aa7b7491abf 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -7,9 +7,9 @@ add_mlir_dialect_library(MLIRSPIRVDialect CastOps.cpp ControlFlowOps.cpp CooperativeMatrixOps.cpp + DotProductOps.cpp GroupOps.cpp ImageOps.cpp - IntegerDotProductOps.cpp MemoryOps.cpp MeshOps.cpp SPIRVAttributes.cpp diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp similarity index 83% rename from mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp rename to mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp index f5676f36a0f5f..01ef1bdc42515 100644 --- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/DotProductOps.cpp @@ -1,4 +1,4 @@ -//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===// +//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// Defines the Integer Dot Product operations in the SPIR-V dialect. +// Defines the Dot Product operations in the SPIR-V dialect. // //===----------------------------------------------------------------------===// @@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames; namespace mlir::spirv { +//===----------------------------------------------------------------------===// +// Dot Product ops +//===----------------------------------------------------------------------===// + +static std::optional getDotProductMinVersion() { + return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. +} + +static std::optional getDotProductMaxVersion() { + return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. +} + +SmallVector, 1> DotOp::getExtensions() { + if (isa(getType())) { + static const auto extension = spirv::Extension::SPV_KHR_bfloat16; + return {extension}; + } + + return {}; +} + +SmallVector, 1> DotOp::getCapabilities() { + if (isa(getType())) { + static const auto capability = spirv::Capability::BFloat16DotProductKHR; + return {capability}; + } + + return {}; +} + +std::optional DotOp::getMinVersion() { + return getDotProductMinVersion(); +} + +std::optional DotOp::getMaxVersion() { + return getDotProductMaxVersion(); +} + //===----------------------------------------------------------------------===// // Integer Dot Product ops //===----------------------------------------------------------------------===// @@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) { return success(); } -static std::optional getIntegerDotProductMinVersion() { - return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. -} - -static std::optional getIntegerDotProductMaxVersion() { - return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. -} - static SmallVector, 1> getIntegerDotProductExtensions() { // Requires the SPV_KHR_integer_dot_product extension, specified either @@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) { return getIntegerDotProductCapabilities(*this); \ } \ std::optional OpName::getMinVersion() { \ - return getIntegerDotProductMinVersion(); \ + return getDotProductMinVersion(); \ } \ std::optional OpName::getMaxVersion() { \ - return getIntegerDotProductMaxVersion(); \ + return getDotProductMaxVersion(); \ } SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp) diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 4701ac5d96009..db283b31b1550 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -932,6 +932,22 @@ func.func @reduction_minui(%v : vector<3xi32>, %s: i32) -> i32 { // ----- +module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { + +// CHECK-LABEL: func @reduction_bf16_addf_mulf +// CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>, %[[ARG1:.+]]: vector<4xbf16>) +// CHECK: %[[DOT:.+]] = spirv.Dot %[[ARG0]], %[[ARG1]] : vector<4xbf16> -> bf16 +// CHECK: return %[[DOT]] : bf16 +func.func @reduction_bf16_addf_mulf(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 { + %mul = arith.mulf %arg0, %arg1 : vector<4xbf16> + %red = vector.reduction , %mul : vector<4xbf16> into bf16 + return %red : bf16 +} + +} // end module + +// ----- + // CHECK-LABEL: @shape_cast_same_type // CHECK-SAME: (%[[ARG0:.*]]: vector<2xf32>) // CHECK: return %[[ARG0]] diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir index d58c27598f2b8..3adafc15c79f6 100644 --- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir @@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 { // ----- +// CHECK-LABEL: @dot_bf16 +func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 { + // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16 + %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16 + return %0 : bf16 +} + +// ----- + // expected-note @+1 {{prior use here}} func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 { // expected-error @+1 {{use of value '%arg1' expects different type than prior uses}} @@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 { // ----- func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { - // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}} + // expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}} %0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32 return %0 : i32 } diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir index 64ba8e3fc249e..9c8665b1e4bbe 100644 --- a/mlir/test/Dialect/SPIRV/IR/availability.mlir +++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir @@ -234,6 +234,20 @@ func.func @udot_acc_sat_vector_4xi16_i64(%a: vector<4xi16>, %acc: i64) -> i64 { return %r: i64 } +//===----------------------------------------------------------------------===// +// Dot Product op with bfloat16 +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: dot_vector_4xbf16_bf16 +func.func @dot_vector_4xbf16_bf16(%a: vector<4xbf16>, %b: vector<4xbf16>) -> bf16 { + // CHECK: min version: v1.0 + // CHECK: max version: v1.6 + // CHECK: extensions: [ [SPV_KHR_bfloat16] ] + // CHECK: capabilities: [ [BFloat16DotProductKHR] ] + %r = spirv.Dot %a, %a: vector<4xbf16> -> bf16 + return %r: bf16 +} + //===----------------------------------------------------------------------===// // Primitive ops //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/arithmetic-ops.mlir b/mlir/test/Target/SPIRV/arithmetic-ops.mlir index b1ea13c6854fd..b80e17f979daa 100644 --- a/mlir/test/Target/SPIRV/arithmetic-ops.mlir +++ b/mlir/test/Target/SPIRV/arithmetic-ops.mlir @@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce { %0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32> spirv.Return } + spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" { + // CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16 + %0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16 + spirv.Return + } }