-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[LLVM][Clang] Enable strict mode for getTrailingObjects
#144930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
f5216d4
to
e985849
Compare
@llvm/pr-subscribers-clang @llvm/pr-subscribers-llvm-support Author: Rahul Joshi (jurahul) ChangesUnder strict mode, the templated Full diff: https://github.com/llvm/llvm-project/pull/144930.diff 4 Files Affected:
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 2fa8fa529741e..b62ebd614e4c7 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of variables associated with this clause.
MutableArrayRef<Expr *> getVarRefs() {
- return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
}
/// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
/// Fetches list of all variables in the clause.
ArrayRef<const Expr *> getVarRefs() const {
- return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
- NumVars);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
}
};
@@ -380,7 +381,8 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
return static_cast<T *>(this)
- ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
+ ->template getTrailingObjects<OpenMPDirectiveKind, /*Strict=*/false>(
+ NumKinds);
}
void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5901,15 +5903,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the unique declarations that are in the trailing objects of the
/// class.
MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Get the unique declarations that are in the trailing objects of the
/// class.
ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
+ ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Set the unique declarations that are in the trailing objects of the
@@ -5923,15 +5927,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
MutableArrayRef<unsigned> getDeclNumListsRef() {
- return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Get the number of lists per declaration that are in the trailing
/// objects of the class.
ArrayRef<unsigned> getDeclNumListsRef() const {
- return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
- NumUniqueDeclarations);
+ return static_cast<const T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+ NumUniqueDeclarations);
}
/// Set the number of lists per declaration that are in the trailing
@@ -5946,7 +5952,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
MutableArrayRef<unsigned> getComponentListSizesRef() {
return MutableArrayRef<unsigned>(
- static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5955,7 +5962,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// objects of the class. They are appended after the number of lists.
ArrayRef<unsigned> getComponentListSizesRef() const {
return ArrayRef<unsigned>(
- static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
+ static_cast<const T *>(this)
+ ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
NumUniqueDeclarations,
NumComponentLists);
}
@@ -5971,13 +5979,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
/// Get the components that are in the trailing objects of the class.
MutableArrayRef<MappableComponent> getComponentsRef() {
return static_cast<T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+ NumComponents);
}
/// Get the components that are in the trailing objects of the class.
ArrayRef<MappableComponent> getComponentsRef() const {
return static_cast<const T *>(this)
- ->template getTrailingObjects<MappableComponent>(NumComponents);
+ ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+ NumComponents);
}
/// Set the components that are in the trailing objects of the class.
@@ -6084,7 +6094,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
assert(SupportsMapper &&
"Must be a clause that is possible to have user-defined mappers");
return llvm::MutableArrayRef<Expr *>(
- static_cast<T *>(this)->template getTrailingObjects<Expr *>() +
+ static_cast<T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
OMPVarListClause<T>::varlist_size(),
OMPVarListClause<T>::varlist_size());
}
@@ -6095,7 +6106,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
assert(SupportsMapper &&
"Must be a clause that is possible to have user-defined mappers");
return llvm::ArrayRef<Expr *>(
- static_cast<const T *>(this)->template getTrailingObjects<Expr *>() +
+ static_cast<const T *>(this)
+ ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
OMPVarListClause<T>::varlist_size(),
OMPVarListClause<T>::varlist_size());
}
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c3722c65abf6e..b93a31ca4ed36 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -2024,7 +2024,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
#define ABSTRACT_STMT(x)
#define CASTEXPR(Type, Base) \
case Stmt::Type##Class: \
- return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
+ return static_cast<Type *>(this) \
+ ->getTrailingObjects<CXXBaseSpecifier *, /*Strict=*/false>();
#define STMT(Type, Base)
#include "clang/AST/StmtNodes.inc"
default:
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index f25f2311a81a4..f6f8f9ebaa38c 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -228,12 +228,17 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
using ParentType::getTrailingObjectsImpl;
- // This function contains only a static_assert BaseTy is final. The
- // static_assert must be in a function, and not at class-level
- // because BaseTy isn't complete at class instantiation time, but
- // will be by the time this function is instantiated.
- static void verifyTrailingObjectsAssertions() {
+ template <bool Strict> static void verifyTrailingObjectsAssertions() {
+ // The static_assert for BaseTy must be in a function, and not at
+ // class-level because BaseTy isn't complete at class instantiation time,
+ // but will be by the time this function is instantiated.
static_assert(std::is_final<BaseTy>(), "BaseTy must be final.");
+
+ // Verify that templated getTrailingObjects() is used only with multiple
+ // trailing types, unless strict mode is disabled.
+ static_assert(!Strict || sizeof...(TrailingTys) > 1,
+ "Use templated getTrailingObjects() only when there are "
+ "multiple trailing types");
}
// These two methods are the base of the recursion for this method.
@@ -282,8 +287,9 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// Returns a pointer to the trailing object array of the given type
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
- template <typename T> const T *getTrailingObjects() const {
- verifyTrailingObjectsAssertions();
+ template <typename T, bool Strict = true>
+ const T *getTrailingObjects() const {
+ verifyTrailingObjectsAssertions<Strict>();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
return this->getTrailingObjectsImpl(
@@ -294,8 +300,8 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
/// Returns a pointer to the trailing object array of the given type
/// (which must be one of those specified in the class template). The
/// array may have zero or more elements in it.
- template <typename T> T *getTrailingObjects() {
- verifyTrailingObjectsAssertions();
+ template <typename T, bool Strict = true> T *getTrailingObjects() {
+ verifyTrailingObjectsAssertions<Strict>();
// Forwards to an impl function with overloads, since member
// function templates can't be specialized.
return this->getTrailingObjectsImpl(
@@ -310,23 +316,25 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
}
FirstTrailingType *getTrailingObjects() {
static_assert(sizeof...(TrailingTys) == 1,
"Can use non-templated getTrailingObjects() only when there "
"is a single trailing type");
- return getTrailingObjects<FirstTrailingType>();
+ return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
}
// Functions that return the trailing objects as ArrayRefs.
- template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
- return MutableArrayRef(getTrailingObjects<T>(), N);
+ template <typename T, bool Strict = true>
+ MutableArrayRef<T> getTrailingObjects(size_t N) {
+ return MutableArrayRef(getTrailingObjects<T, Strict>(), N);
}
- template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
- return ArrayRef(getTrailingObjects<T>(), N);
+ template <typename T, bool Strict = true>
+ ArrayRef<T> getTrailingObjects(size_t N) const {
+ return ArrayRef(getTrailingObjects<T, Strict>(), N);
}
MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index 2590f375b6598..83afb22f837aa 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -123,11 +123,12 @@ TEST(TrailingObjects, OneArg) {
EXPECT_EQ(Class1::totalSizeToAlloc<short>(3),
sizeof(Class1) + sizeof(short) * 3);
- EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1));
+ EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1));
EXPECT_EQ(C->get(0), 1);
EXPECT_EQ(C->get(2), 3);
- EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>());
+ EXPECT_EQ(C->getTrailingObjects(),
+ (C->getTrailingObjects<short, /*Strict=*/false>()));
delete C;
}
|
This is the(?) final step in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you better explain what 'strict' is supposed to mean here? it isn't clear to me the purpose of it. I don't really understand when it is necessary/what it is doing/etc.
Sure: The first line of the description attempts to define it:
I essentially used this strict mode to find all templated calls to |
Oh! I see... hmm... So this is for cases where we don't know how many items in the `TrailingObjects there are... I think I might prefer instead of a |
Just to be precise, it's not the number of items but the number of trailing types. I half expected the comment to make it a different function. I'll do that |
Ah right, I was using 'items' to mean 'types' too, so we were clear there :) Sorry for my lack of precision. |
Under strict mode, the templated `getTrailingObjects` can be called only when there is > 1 trailing types. The strict mode can be disabled on a per-call basis when its not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h).
e985849
to
ee982b8
Compare
That was fast! Thanks |
getTrailingObjects
getTrailingObjects
Disallow calls to templated
getTrailingObjects
if there is a single trailing type (strict mode). AddgetTrailingObjectsNonStrict
for cases when it's not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h) to bypass the struct checks.This will ensure that future users of TrailingObjects class do not accidently use the templated
getTrailingObjects
when they have a single trailing type.