Skip to content

Commit ccd4281

Browse files
committed
[mlir][spirv] Enable dot operation for bfloat16
1 parent 89c6144 commit ccd4281

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
445445
}];
446446

447447
let arguments = (ins
448-
SPIRV_VectorOf<SPIRV_Float>:$vector1,
449-
SPIRV_VectorOf<SPIRV_Float>:$vector2
448+
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector1,
449+
SPIRV_VectorOf<SPIRV_AnyFloat>:$vector2
450450
);
451451

452452
let results = (outs
453-
SPIRV_Float:$result
453+
SPIRV_AnyFloat:$result
454454
);
455455

456456
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,15 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {
321321

322322
// -----
323323

324+
// CHECK-LABEL: @dot_bf16
325+
func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
326+
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> bf16
327+
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
328+
return %0 : bf16
329+
}
330+
331+
// -----
332+
324333
// expected-note @+1 {{prior use here}}
325334
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
326335
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
@@ -339,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
339348
// -----
340349

341350
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
342-
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
351+
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
343352
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
344353
return %0 : i32
345354
}

mlir/test/Target/SPIRV/arithmetic-ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,9 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
8686
%0 = spirv.VectorTimesScalar %arg0, %arg1 : (vector<4xf32>, f32) -> vector<4xf32>
8787
spirv.Return
8888
}
89+
spirv.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) "None" {
90+
// CHECK: spirv.Dot %{{.+}}, %{{.+}} : vector<4xbf16> -> f16
91+
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
92+
spirv.Return
93+
}
8994
}

0 commit comments

Comments
 (0)