Skip to content

[LLVM][Clang] Enable strict mode for getTrailingObjects #144930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jurahul
Copy link
Contributor

@jurahul jurahul commented Jun 19, 2025

Disallow calls to templated getTrailingObjects if there is a single trailing type (strict mode). Add getTrailingObjectsNonStrict for cases when it's not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h) to bypass the struct checks.

This will ensure that future users of TrailingObjects class do not accidently use the templated getTrailingObjects when they have a single trailing type.

@jurahul jurahul force-pushed the trailingObjects_strict branch 2 times, most recently from f5216d4 to e985849 Compare June 20, 2025 14:29
@jurahul jurahul marked this pull request as ready for review June 20, 2025 15:56
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" llvm:support clang:openmp OpenMP related changes to Clang labels Jun 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-llvm-support

Author: Rahul Joshi (jurahul)

Changes

Under strict mode, the templated getTrailingObjects can be called only when there is > 1 trailing types. The strict mode can be disabled on a per-call basis when its not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h).


Full diff: https://github.com/llvm/llvm-project/pull/144930.diff

4 Files Affected:

  • (modified) clang/include/clang/AST/OpenMPClause.h (+29-17)
  • (modified) clang/lib/AST/Expr.cpp (+2-1)
  • (modified) llvm/include/llvm/Support/TrailingObjects.h (+23-15)
  • (modified) llvm/unittests/Support/TrailingObjectsTest.cpp (+3-2)
diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 2fa8fa529741e..b62ebd614e4c7 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
 
   /// Fetches list of variables associated with this clause.
   MutableArrayRef<Expr *> getVarRefs() {
-    return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
   }
 
   /// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
 
   /// Fetches list of all variables in the clause.
   ArrayRef<const Expr *> getVarRefs() const {
-    return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
-        NumVars);
+    return static_cast<const T *>(this)
+        ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
   }
 };
 
@@ -380,7 +381,8 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
 
   MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
     return static_cast<T *>(this)
-        ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
+        ->template getTrailingObjects<OpenMPDirectiveKind, /*Strict=*/false>(
+            NumKinds);
   }
 
   void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5901,15 +5903,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the unique declarations that are in the trailing objects of the
   /// class.
   MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
-    return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
-        NumUniqueDeclarations);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Get the unique declarations that are in the trailing objects of the
   /// class.
   ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
     return static_cast<const T *>(this)
-        ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
+        ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Set the unique declarations that are in the trailing objects of the
@@ -5923,15 +5927,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the number of lists per declaration that are in the trailing
   /// objects of the class.
   MutableArrayRef<unsigned> getDeclNumListsRef() {
-    return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
-        NumUniqueDeclarations);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Get the number of lists per declaration that are in the trailing
   /// objects of the class.
   ArrayRef<unsigned> getDeclNumListsRef() const {
-    return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
-        NumUniqueDeclarations);
+    return static_cast<const T *>(this)
+        ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Set the number of lists per declaration that are in the trailing
@@ -5946,7 +5952,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// objects of the class. They are appended after the number of lists.
   MutableArrayRef<unsigned> getComponentListSizesRef() {
     return MutableArrayRef<unsigned>(
-        static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
+        static_cast<T *>(this)
+                ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
             NumUniqueDeclarations,
         NumComponentLists);
   }
@@ -5955,7 +5962,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// objects of the class. They are appended after the number of lists.
   ArrayRef<unsigned> getComponentListSizesRef() const {
     return ArrayRef<unsigned>(
-        static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
+        static_cast<const T *>(this)
+                ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
             NumUniqueDeclarations,
         NumComponentLists);
   }
@@ -5971,13 +5979,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the components that are in the trailing objects of the class.
   MutableArrayRef<MappableComponent> getComponentsRef() {
     return static_cast<T *>(this)
-        ->template getTrailingObjects<MappableComponent>(NumComponents);
+        ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+            NumComponents);
   }
 
   /// Get the components that are in the trailing objects of the class.
   ArrayRef<MappableComponent> getComponentsRef() const {
     return static_cast<const T *>(this)
-        ->template getTrailingObjects<MappableComponent>(NumComponents);
+        ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+            NumComponents);
   }
 
   /// Set the components that are in the trailing objects of the class.
@@ -6084,7 +6094,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
     assert(SupportsMapper &&
            "Must be a clause that is possible to have user-defined mappers");
     return llvm::MutableArrayRef<Expr *>(
-        static_cast<T *>(this)->template getTrailingObjects<Expr *>() +
+        static_cast<T *>(this)
+                ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
             OMPVarListClause<T>::varlist_size(),
         OMPVarListClause<T>::varlist_size());
   }
@@ -6095,7 +6106,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
     assert(SupportsMapper &&
            "Must be a clause that is possible to have user-defined mappers");
     return llvm::ArrayRef<Expr *>(
-        static_cast<const T *>(this)->template getTrailingObjects<Expr *>() +
+        static_cast<const T *>(this)
+                ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
             OMPVarListClause<T>::varlist_size(),
         OMPVarListClause<T>::varlist_size());
   }
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c3722c65abf6e..b93a31ca4ed36 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -2024,7 +2024,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
 #define ABSTRACT_STMT(x)
 #define CASTEXPR(Type, Base)                                                   \
   case Stmt::Type##Class:                                                      \
-    return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
+    return static_cast<Type *>(this)                                           \
+        ->getTrailingObjects<CXXBaseSpecifier *, /*Strict=*/false>();
 #define STMT(Type, Base)
 #include "clang/AST/StmtNodes.inc"
   default:
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index f25f2311a81a4..f6f8f9ebaa38c 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -228,12 +228,17 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
 
   using ParentType::getTrailingObjectsImpl;
 
-  // This function contains only a static_assert BaseTy is final. The
-  // static_assert must be in a function, and not at class-level
-  // because BaseTy isn't complete at class instantiation time, but
-  // will be by the time this function is instantiated.
-  static void verifyTrailingObjectsAssertions() {
+  template <bool Strict> static void verifyTrailingObjectsAssertions() {
+    // The static_assert for BaseTy must be in a function, and not at
+    // class-level  because BaseTy isn't complete at class instantiation time,
+    // but will be by the time this function is instantiated.
     static_assert(std::is_final<BaseTy>(), "BaseTy must be final.");
+
+    // Verify that templated getTrailingObjects() is used only with multiple
+    // trailing types, unless strict mode is disabled.
+    static_assert(!Strict || sizeof...(TrailingTys) > 1,
+                  "Use templated getTrailingObjects() only when there are "
+                  "multiple trailing types");
   }
 
   // These two methods are the base of the recursion for this method.
@@ -282,8 +287,9 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
   /// Returns a pointer to the trailing object array of the given type
   /// (which must be one of those specified in the class template). The
   /// array may have zero or more elements in it.
-  template <typename T> const T *getTrailingObjects() const {
-    verifyTrailingObjectsAssertions();
+  template <typename T, bool Strict = true>
+  const T *getTrailingObjects() const {
+    verifyTrailingObjectsAssertions<Strict>();
     // Forwards to an impl function with overloads, since member
     // function templates can't be specialized.
     return this->getTrailingObjectsImpl(
@@ -294,8 +300,8 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
   /// Returns a pointer to the trailing object array of the given type
   /// (which must be one of those specified in the class template). The
   /// array may have zero or more elements in it.
-  template <typename T> T *getTrailingObjects() {
-    verifyTrailingObjectsAssertions();
+  template <typename T, bool Strict = true> T *getTrailingObjects() {
+    verifyTrailingObjectsAssertions<Strict>();
     // Forwards to an impl function with overloads, since member
     // function templates can't be specialized.
     return this->getTrailingObjectsImpl(
@@ -310,23 +316,25 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
     static_assert(sizeof...(TrailingTys) == 1,
                   "Can use non-templated getTrailingObjects() only when there "
                   "is a single trailing type");
-    return getTrailingObjects<FirstTrailingType>();
+    return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
   }
 
   FirstTrailingType *getTrailingObjects() {
     static_assert(sizeof...(TrailingTys) == 1,
                   "Can use non-templated getTrailingObjects() only when there "
                   "is a single trailing type");
-    return getTrailingObjects<FirstTrailingType>();
+    return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
   }
 
   // Functions that return the trailing objects as ArrayRefs.
-  template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
-    return MutableArrayRef(getTrailingObjects<T>(), N);
+  template <typename T, bool Strict = true>
+  MutableArrayRef<T> getTrailingObjects(size_t N) {
+    return MutableArrayRef(getTrailingObjects<T, Strict>(), N);
   }
 
-  template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
-    return ArrayRef(getTrailingObjects<T>(), N);
+  template <typename T, bool Strict = true>
+  ArrayRef<T> getTrailingObjects(size_t N) const {
+    return ArrayRef(getTrailingObjects<T, Strict>(), N);
   }
 
   MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index 2590f375b6598..83afb22f837aa 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -123,11 +123,12 @@ TEST(TrailingObjects, OneArg) {
   EXPECT_EQ(Class1::totalSizeToAlloc<short>(3),
             sizeof(Class1) + sizeof(short) * 3);
 
-  EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1));
+  EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1));
   EXPECT_EQ(C->get(0), 1);
   EXPECT_EQ(C->get(2), 3);
 
-  EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>());
+  EXPECT_EQ(C->getTrailingObjects(),
+            (C->getTrailingObjects<short, /*Strict=*/false>()));
 
   delete C;
 }

@jurahul
Copy link
Contributor Author

jurahul commented Jun 20, 2025

This is the(?) final step in the getTrailingObjects change I started a month ago. All users are now moved to the simplified (non-templated) form in the upstream code. This change will ensure that any new users also continue to conform to this (i.e., are forced to use the non-templated form when possible).

Copy link
Collaborator

@erichkeane erichkeane left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you better explain what 'strict' is supposed to mean here? it isn't clear to me the purpose of it. I don't really understand when it is necessary/what it is doing/etc.

@jurahul
Copy link
Contributor Author

jurahul commented Jun 24, 2025

Sure: The first line of the description attempts to define it:

Under strict mode, the templated getTrailingObjects can be called only when there is > 1 trailing types.

I essentially used this strict mode to find all templated calls to getTrailingObjects<X> that could be converted to the simplified non-templated form getTrailingObjects in the series of PRs I did over last month. Essentially, what this will do is that any future user of TrailingObjects with a single trailing type will be forced to use the non-templated getTrailingObjects.

@jurahul jurahul requested a review from erichkeane June 24, 2025 18:00
@erichkeane
Copy link
Collaborator

Sure: The first line of the description attempts to define it:

Under strict mode, the templated getTrailingObjects can be called only when there is > 1 trailing types.

I essentially used this strict mode to find all templated calls to getTrailingObjects<X> that could be converted to the simplified non-templated form getTrailingObjects in the series of PRs I did over last month. Essentially, what this will do is that any future user of TrailingObjects with a single trailing type will be forced to use the non-templated getTrailingObjects.

Oh! I see... hmm... So this is for cases where we don't know how many items in the `TrailingObjects there are...

I think I might prefer instead of a strict template argument, a different name for getTrailingObjects here, like getNonCheckedTrailingObjects (though that has other implications... getNonStrictTrailingObjects<...> with a comment?

@jurahul
Copy link
Contributor Author

jurahul commented Jun 24, 2025

Just to be precise, it's not the number of items but the number of trailing types. I half expected the comment to make it a different function. I'll do that

@erichkeane
Copy link
Collaborator

Just to be precise, it's not the number of items but the number of trailing types. I half expected the comment to make it a different function. I'll do that

Ah right, I was using 'items' to mean 'types' too, so we were clear there :) Sorry for my lack of precision.

Under strict mode, the templated `getTrailingObjects` can be called
only when there is > 1 trailing types. The strict mode can be disabled
on a per-call basis when its not possible to know statically if there
will be a single or multiple trailing types (like in OpenMPClause.h).
@jurahul jurahul force-pushed the trailingObjects_strict branch from e985849 to ee982b8 Compare June 27, 2025 21:21
@jurahul
Copy link
Contributor Author

jurahul commented Jun 27, 2025

That was fast! Thanks

@jurahul jurahul changed the title [LLVM][Clang] Add and enable strict mode for getTrailingObjects [LLVM][Clang] Enable strict mode for getTrailingObjects Jun 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category llvm:support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants