-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][spirv] Enable dot operation for bfloat16 #145409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===// | ||
//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops ----===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Defines the Integer Dot Product operations in the SPIR-V dialect. | ||
// Defines the Dot Product operations in the SPIR-V dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
@@ -21,6 +21,44 @@ using namespace mlir::spirv::AttrNames; | |
|
||
namespace mlir::spirv { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Dot Product ops | ||
//===----------------------------------------------------------------------===// | ||
|
||
static std::optional<spirv::Version> getDotProductMinVersion() { | ||
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. | ||
} | ||
|
||
static std::optional<spirv::Version> getDotProductMaxVersion() { | ||
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. | ||
} | ||
|
||
Comment on lines
+32
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not like hardcoding the maximum version here - is there currently a way to retrieve the "default" maximum spirv version? From my understanding the only place this is defined is inside the availability field inside the tablegen definition for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not aware of anything better. We could probably have a define for |
||
SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() { | ||
if (getResult().getType().isBF16()) { | ||
static const auto extension = spirv::Extension::SPV_KHR_bfloat16; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: prefer |
||
return {extension}; | ||
} | ||
|
||
return {}; | ||
} | ||
|
||
SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() { | ||
if (getResult().getType().isBF16()) { | ||
static const auto capability = spirv::Capability::BFloat16DotProductKHR; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also here |
||
return {capability}; | ||
} | ||
|
||
return {}; | ||
} | ||
|
||
std::optional<spirv::Version> DotOp::getMinVersion() { | ||
return getDotProductMinVersion(); | ||
} | ||
|
||
std::optional<spirv::Version> DotOp::getMaxVersion() { | ||
return getDotProductMaxVersion(); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Integer Dot Product ops | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -71,14 +109,6 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) { | |
return success(); | ||
} | ||
|
||
static std::optional<spirv::Version> getIntegerDotProductMinVersion() { | ||
return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. | ||
} | ||
|
||
static std::optional<spirv::Version> getIntegerDotProductMaxVersion() { | ||
return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. | ||
} | ||
|
||
static SmallVector<ArrayRef<spirv::Extension>, 1> | ||
getIntegerDotProductExtensions() { | ||
// Requires the SPV_KHR_integer_dot_product extension, specified either | ||
|
@@ -136,10 +166,10 @@ getIntegerDotProductCapabilities(Operation *op) { | |
return getIntegerDotProductCapabilities<OpName>(*this); \ | ||
} \ | ||
std::optional<spirv::Version> OpName::getMinVersion() { \ | ||
return getIntegerDotProductMinVersion(); \ | ||
return getDotProductMinVersion(); \ | ||
} \ | ||
std::optional<spirv::Version> OpName::getMaxVersion() { \ | ||
return getIntegerDotProductMaxVersion(); \ | ||
return getDotProductMaxVersion(); \ | ||
} | ||
|
||
SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure this header is padded to the 80 column line limit