Skip to content

Commit ee982b8

Browse files
committed
[LLVM][Clang] Add and enable strict mode for getTrailingObjects
Under strict mode, the templated `getTrailingObjects` can be called only when there is > 1 trailing types. The strict mode can be disabled on a per-call basis when its not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h).
1 parent e9c9adc commit ee982b8

File tree

4 files changed

+74
-28
lines changed

4 files changed

+74
-28
lines changed

clang/include/clang/AST/OpenMPClause.h

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
295295

296296
/// Fetches list of variables associated with this clause.
297297
MutableArrayRef<Expr *> getVarRefs() {
298-
return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
298+
return static_cast<T *>(this)->template getTrailingObjectsNonStrict<Expr *>(
299+
NumVars);
299300
}
300301

301302
/// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
334335

335336
/// Fetches list of all variables in the clause.
336337
ArrayRef<const Expr *> getVarRefs() const {
337-
return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
338-
NumVars);
338+
return static_cast<const T *>(this)
339+
->template getTrailingObjectsNonStrict<Expr *>(NumVars);
339340
}
340341
};
341342

@@ -380,7 +381,7 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
380381

381382
MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
382383
return static_cast<T *>(this)
383-
->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
384+
->template getTrailingObjectsNonStrict<OpenMPDirectiveKind>(NumKinds);
384385
}
385386

386387
void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5921,15 +5922,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
59215922
/// Get the unique declarations that are in the trailing objects of the
59225923
/// class.
59235924
MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
5924-
return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
5925-
NumUniqueDeclarations);
5925+
return static_cast<T *>(this)
5926+
->template getTrailingObjectsNonStrict<ValueDecl *>(
5927+
NumUniqueDeclarations);
59265928
}
59275929

59285930
/// Get the unique declarations that are in the trailing objects of the
59295931
/// class.
59305932
ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
59315933
return static_cast<const T *>(this)
5932-
->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
5934+
->template getTrailingObjectsNonStrict<ValueDecl *>(
5935+
NumUniqueDeclarations);
59335936
}
59345937

59355938
/// Set the unique declarations that are in the trailing objects of the
@@ -5943,15 +5946,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
59435946
/// Get the number of lists per declaration that are in the trailing
59445947
/// objects of the class.
59455948
MutableArrayRef<unsigned> getDeclNumListsRef() {
5946-
return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
5947-
NumUniqueDeclarations);
5949+
return static_cast<T *>(this)
5950+
->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations);
59485951
}
59495952

59505953
/// Get the number of lists per declaration that are in the trailing
59515954
/// objects of the class.
59525955
ArrayRef<unsigned> getDeclNumListsRef() const {
5953-
return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
5954-
NumUniqueDeclarations);
5956+
return static_cast<const T *>(this)
5957+
->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations);
59555958
}
59565959

59575960
/// Set the number of lists per declaration that are in the trailing
@@ -5966,7 +5969,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
59665969
/// objects of the class. They are appended after the number of lists.
59675970
MutableArrayRef<unsigned> getComponentListSizesRef() {
59685971
return MutableArrayRef<unsigned>(
5969-
static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
5972+
static_cast<T *>(this)
5973+
->template getTrailingObjectsNonStrict<unsigned>() +
59705974
NumUniqueDeclarations,
59715975
NumComponentLists);
59725976
}
@@ -5975,7 +5979,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
59755979
/// objects of the class. They are appended after the number of lists.
59765980
ArrayRef<unsigned> getComponentListSizesRef() const {
59775981
return ArrayRef<unsigned>(
5978-
static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
5982+
static_cast<const T *>(this)
5983+
->template getTrailingObjectsNonStrict<unsigned>() +
59795984
NumUniqueDeclarations,
59805985
NumComponentLists);
59815986
}
@@ -5991,13 +5996,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
59915996
/// Get the components that are in the trailing objects of the class.
59925997
MutableArrayRef<MappableComponent> getComponentsRef() {
59935998
return static_cast<T *>(this)
5994-
->template getTrailingObjects<MappableComponent>(NumComponents);
5999+
->template getTrailingObjectsNonStrict<MappableComponent>(
6000+
NumComponents);
59956001
}
59966002

59976003
/// Get the components that are in the trailing objects of the class.
59986004
ArrayRef<MappableComponent> getComponentsRef() const {
59996005
return static_cast<const T *>(this)
6000-
->template getTrailingObjects<MappableComponent>(NumComponents);
6006+
->template getTrailingObjectsNonStrict<MappableComponent>(
6007+
NumComponents);
60016008
}
60026009

60036010
/// Set the components that are in the trailing objects of the class.

clang/lib/AST/Expr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2020,7 +2020,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
20202020
#define ABSTRACT_STMT(x)
20212021
#define CASTEXPR(Type, Base) \
20222022
case Stmt::Type##Class: \
2023-
return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
2023+
return static_cast<Type *>(this) \
2024+
->getTrailingObjectsNonStrict<CXXBaseSpecifier *>();
20242025
#define STMT(Type, Base)
20252026
#include "clang/AST/StmtNodes.inc"
20262027
default:

llvm/include/llvm/Support/TrailingObjects.h

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,18 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
228228

229229
using ParentType::getTrailingObjectsImpl;
230230

231-
// This function contains only a static_assert BaseTy is final. The
232-
// static_assert must be in a function, and not at class-level
233-
// because BaseTy isn't complete at class instantiation time, but
234-
// will be by the time this function is instantiated.
235-
static void verifyTrailingObjectsAssertions() {
231+
template <bool Strict> static void verifyTrailingObjectsAssertions() {
232+
// The static_assert for BaseTy must be in a function, and not at
233+
// class-level because BaseTy isn't complete at class instantiation time,
234+
// but will be by the time this function is instantiated.
236235
static_assert(std::is_final<BaseTy>(), "BaseTy must be final.");
236+
237+
// Verify that templated getTrailingObjects() is used only with multiple
238+
// trailing types. Use getTrailingObjectsNonStrict() which does not check
239+
// this.
240+
static_assert(!Strict || sizeof...(TrailingTys) > 1,
241+
"Use templated getTrailingObjects() only when there are "
242+
"multiple trailing types");
237243
}
238244

239245
// These two methods are the base of the recursion for this method.
@@ -283,7 +289,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
283289
/// (which must be one of those specified in the class template). The
284290
/// array may have zero or more elements in it.
285291
template <typename T> const T *getTrailingObjects() const {
286-
verifyTrailingObjectsAssertions();
292+
verifyTrailingObjectsAssertions<true>();
287293
// Forwards to an impl function with overloads, since member
288294
// function templates can't be specialized.
289295
return this->getTrailingObjectsImpl(
@@ -295,7 +301,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
295301
/// (which must be one of those specified in the class template). The
296302
/// array may have zero or more elements in it.
297303
template <typename T> T *getTrailingObjects() {
298-
verifyTrailingObjectsAssertions();
304+
verifyTrailingObjectsAssertions<true>();
299305
// Forwards to an impl function with overloads, since member
300306
// function templates can't be specialized.
301307
return this->getTrailingObjectsImpl(
@@ -310,14 +316,20 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
310316
static_assert(sizeof...(TrailingTys) == 1,
311317
"Can use non-templated getTrailingObjects() only when there "
312318
"is a single trailing type");
313-
return getTrailingObjects<FirstTrailingType>();
319+
verifyTrailingObjectsAssertions<false>();
320+
return this->getTrailingObjectsImpl(
321+
static_cast<const BaseTy *>(this),
322+
TrailingObjectsBase::OverloadToken<FirstTrailingType>());
314323
}
315324

316325
FirstTrailingType *getTrailingObjects() {
317326
static_assert(sizeof...(TrailingTys) == 1,
318327
"Can use non-templated getTrailingObjects() only when there "
319328
"is a single trailing type");
320-
return getTrailingObjects<FirstTrailingType>();
329+
verifyTrailingObjectsAssertions<false>();
330+
return this->getTrailingObjectsImpl(
331+
static_cast<BaseTy *>(this),
332+
TrailingObjectsBase::OverloadToken<FirstTrailingType>());
321333
}
322334

323335
// Functions that return the trailing objects as ArrayRefs.
@@ -337,6 +349,31 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
337349
return ArrayRef(getTrailingObjects(), N);
338350
}
339351

352+
// Non-strict forms of templated `getTrailingObjects` that work with single
353+
// trailing type.
354+
template <typename T> const T *getTrailingObjectsNonStrict() const {
355+
verifyTrailingObjectsAssertions<false>();
356+
return this->getTrailingObjectsImpl(
357+
static_cast<const BaseTy *>(this),
358+
TrailingObjectsBase::OverloadToken<T>());
359+
}
360+
361+
template <typename T> T *getTrailingObjectsNonStrict() {
362+
verifyTrailingObjectsAssertions<false>();
363+
return this->getTrailingObjectsImpl(
364+
static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>());
365+
}
366+
367+
template <typename T>
368+
MutableArrayRef<T> getTrailingObjectsNonStrict(size_t N) {
369+
return MutableArrayRef(getTrailingObjectsNonStrict<T>(), N);
370+
}
371+
372+
template <typename T>
373+
ArrayRef<T> getTrailingObjectsNonStrict(size_t N) const {
374+
return ArrayRef(getTrailingObjectsNonStrict<T>(), N);
375+
}
376+
340377
/// Returns the size of the trailing data, if an object were
341378
/// allocated with the given counts (The counts are in the same order
342379
/// as the template arguments). This does not include the size of the

llvm/unittests/Support/TrailingObjectsTest.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ class Class1 final : private TrailingObjects<Class1, short> {
4545
template <typename... Ty>
4646
using FixedSizeStorage = TrailingObjects::FixedSizeStorage<Ty...>;
4747

48-
using TrailingObjects::totalSizeToAlloc;
4948
using TrailingObjects::additionalSizeToAlloc;
5049
using TrailingObjects::getTrailingObjects;
50+
using TrailingObjects::getTrailingObjectsNonStrict;
51+
using TrailingObjects::totalSizeToAlloc;
5152
};
5253

5354
// Here, there are two singular optional object types appended. Note
@@ -123,11 +124,11 @@ TEST(TrailingObjects, OneArg) {
123124
EXPECT_EQ(Class1::totalSizeToAlloc<short>(3),
124125
sizeof(Class1) + sizeof(short) * 3);
125126

126-
EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1));
127+
EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1));
127128
EXPECT_EQ(C->get(0), 1);
128129
EXPECT_EQ(C->get(2), 3);
129130

130-
EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>());
131+
EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjectsNonStrict<short>());
131132

132133
delete C;
133134
}

0 commit comments

Comments
 (0)