-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir] ResourceAttrInterface
to abstract AsmResourceBlob from resource handle.
#101780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-llvm Author: Pavel Prokofyev (integralpro) ChangesFull diff: https://github.com/llvm/llvm-project/pull/101780.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index c4a42020d1389..bb96de2ef1f92 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -19,6 +19,8 @@
namespace mlir {
+class AsmResourceBlob;
+
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 954429c7d8eae..768964ac578ac 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,35 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// ResourceAttrInterface
+//===----------------------------------------------------------------------===//
+
+def ResourceAttrInterface : AttrInterface<"ResourceAttr", [TypedAttrInterface]> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ The interface abstracts the nature of underlying resource blob from its handle.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ "Get blob key associated with the resource.",
+ "::mlir::StringRef", "getBlobKey", (ins),
+ [{}],
+ [{
+ return $_attr.getRawHandle().getKey();
+ }]
+ >,
+ InterfaceMethod<
+ "Get blob associated with the resource.",
+ "::mlir::AsmResourceBlob *", "getBlob", (ins),
+ [{}],
+ [{
+ return $_attr.getRawHandle().getBlob();
+ }]
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97b..aa0b29ffb6fd4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -431,7 +431,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
- "dense_resource_elements", [ElementsAttrInterface]> {
+ "dense_resource_elements", [ElementsAttrInterface, ResourceAttrInterface]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
let description = [{
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b468228ea78b7..9549d4628fac0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -458,10 +458,11 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
/// of the innermost dimension. Constants for other dimensions are still
/// constructed recursively. Returns nullptr on failure and emits errors at
/// `loc`.
-static llvm::Constant *convertDenseResourceElementsAttr(
- Location loc, DenseResourceElementsAttr denseResourceAttr,
- llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
- assert(denseResourceAttr && "expected non-null attribute");
+static llvm::Constant *
+convertResourceAttr(Location loc, ResourceAttr resourceAttr,
+ llvm::Type *llvmType,
+ const ModuleTranslation &moduleTranslation) {
+ assert(resourceAttr && "expected non-null attribute");
llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
if (!llvm::ConstantDataSequential::isElementTypeCompatible(
@@ -470,10 +471,10 @@ static llvm::Constant *convertDenseResourceElementsAttr(
return nullptr;
}
- ShapedType type = denseResourceAttr.getType();
+ ShapedType type = mlir::cast<ShapedType>(resourceAttr.getType());
assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
- AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
+ AsmResourceBlob *blob = resourceAttr.getBlob();
if (!blob) {
emitError(loc, "resource does not exist");
return nullptr;
@@ -486,7 +487,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
// raw data.
// TODO: we may also need to consider endianness when cross-compiling to an
// architecture where it is different.
- int64_t numElements = denseResourceAttr.getType().getNumElements();
+ int64_t numElements = type.getNumElements();
int64_t elementByteSize = rawData.size() / numElements;
if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
emitError(loc, "raw data size does not match element type size");
@@ -497,9 +498,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
// innermost dimension may be that of the vector element type.
bool hasVectorElementType = isa<VectorType>(type.getElementType());
int64_t numAggregates =
- numElements / (hasVectorElementType
- ? 1
- : denseResourceAttr.getType().getShape().back());
+ numElements / (hasVectorElementType ? 1 : type.getShape().back());
ArrayRef<int64_t> outerShape = type.getShape();
if (!hasVectorElementType)
outerShape = outerShape.drop_back();
@@ -533,8 +532,8 @@ static llvm::Constant *convertDenseResourceElementsAttr(
// Create innermost constants and defer to the default constant creation
// mechanism for other dimensions.
SmallVector<llvm::Constant *> constants;
- int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
- (innermostLLVMType->getScalarSizeInBits() / 8);
+ int64_t aggregateSize =
+ type.getShape().back() * (innermostLLVMType->getScalarSizeInBits() / 8);
constants.reserve(numAggregates);
for (unsigned i = 0; i < numAggregates; ++i) {
StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
@@ -679,9 +678,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
return result;
}
- if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
- return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
- moduleTranslation);
+ if (auto resourceAttr = dyn_cast<ResourceAttr>(attr)) {
+ return convertResourceAttr(loc, resourceAttr, llvmType, moduleTranslation);
}
// Fall back to element-by-element construction otherwise.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e72bfe9d82e7c..207373722dcbb 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -382,6 +382,23 @@ TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
},
"invalid shape element type for provided type `T`");
}
+
+TEST(DenseResourceElementsAttrTest, CheckResourceInterface) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ ArrayRef<double> data = {0, 1, 2};
+ auto elementType = builder.getF64Type();
+ auto type = RankedTensorType::get(data.size(), elementType);
+ auto attr = DenseF64ResourceElementsAttr::get(
+ type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
+
+ EXPECT_TRUE(isa<DenseF64ResourceElementsAttr>(attr));
+ auto resourceAttr = dyn_cast<ResourceAttr>(attr);
+ EXPECT_TRUE(resourceAttr);
+ EXPECT_TRUE(resourceAttr.getBlobKey() == "resource");
+ EXPECT_TRUE(resourceAttr.getBlob());
+}
} // namespace
//===----------------------------------------------------------------------===//
|
Have you considered extending ElementsAttr to support the case of returning the full underlying buffer (if it can)? This PR doesn't seem like a great reason to have a new interface, but rather (to me) points out a lack of API in existing interfaces. Would be good to continue consolidating interfaces if we can. |
DenseElementsAttr has the following two-
Another way is to have them in ElementsAttr interface, then DenseResourceElementsAttr and similar may provide implementation specific to the One thing that I'm worrying that it will open up for more uses of structured attributes (like DenseIntOrFPElementsAttr) with opaque storage to be accessed by their What do you think? @River707 |
I think that probably makes sense, and I believe it's what River is suggesting.
Agreed - but it's already noted that this is a 'use-at-your-own-risk' API so most users shouldn't be using it. DenseIntOrFPElementsAttr could use an empty implementation, or return an error, or something else. |
nice! Thank you for the feedback! I'll make a change that follows |
There's a use-case when projects create attributes similar to
DenseResourceElementsAttr
but having custom dialect-specific resource blob handle.The added interface allows to have logic for processing resource attributes of different kinds independently from their concrete type by retrieving their AsmResourceBlob.