Skip to content

Commit 7d7a6a6

Browse files
committed
[ADT] Fix llvm::concat_iterator for ValueT == common_base_class *
Fix llvm::concat_iterator for the case of `ValueT` being a pointer to a common base class to which the result of dereferencing any iterator in `ItersT` can be casted to.
1 parent dfb5cad commit 7d7a6a6

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,13 +1032,22 @@ class concat_iterator
10321032

10331033
static constexpr bool ReturnsByValue =
10341034
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
1035+
static constexpr bool ReturnsConvertiblePointer =
1036+
std::is_pointer_v<ValueT> &&
1037+
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
10351038

10361039
using reference_type =
1037-
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
1038-
1039-
using handle_type =
1040-
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
1041-
ValueT *>;
1040+
typename std::conditional_t<ReturnsByValue || ReturnsConvertiblePointer,
1041+
ValueT, ValueT &>;
1042+
1043+
using optional_value_type =
1044+
std::conditional_t<ReturnsByValue, std::optional<ValueT>, ValueT *>;
1045+
// handle_type is used to return an optional value from `getHelper()`. If
1046+
// the type resulting from dereferencing all IterTs is a pointer that can be
1047+
// converted to `ValueT`, use that pointer type instead to avoid implicit
1048+
// conversion issues.
1049+
using handle_type = typename std::conditional_t<ReturnsConvertiblePointer,
1050+
ValueT, optional_value_type>;
10421051

10431052
/// We store both the current and end iterators for each concatenated
10441053
/// sequence in a tuple of pairs.
@@ -1088,7 +1097,7 @@ class concat_iterator
10881097
if (Begin == End)
10891098
return {};
10901099

1091-
if constexpr (ReturnsByValue)
1100+
if constexpr (ReturnsByValue || ReturnsConvertiblePointer)
10921101
return *Begin;
10931102
else
10941103
return &*Begin;
@@ -1105,8 +1114,12 @@ class concat_iterator
11051114

11061115
// Loop over them, and return the first result we find.
11071116
for (auto &GetHelperFn : GetHelperFns)
1108-
if (auto P = (this->*GetHelperFn)())
1109-
return *P;
1117+
if (auto P = (this->*GetHelperFn)()) {
1118+
if constexpr (ReturnsConvertiblePointer)
1119+
return P;
1120+
else
1121+
return *P;
1122+
}
11101123

11111124
llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
11121125
}

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ struct some_struct {
398398
std::string swap_val;
399399
};
400400

401+
struct derives_from_some_struct : some_struct {};
402+
401403
std::vector<int>::const_iterator begin(const some_struct &s) {
402404
return s.data.begin();
403405
}
@@ -532,6 +534,28 @@ TEST(STLExtrasTest, ConcatRangeADL) {
532534
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
533535
}
534536

537+
TEST(STLExtrasTest, ConcatRangePtr) {
538+
some_namespace::some_struct S0{};
539+
some_namespace::some_struct S1{};
540+
SmallVector<some_namespace::some_struct *> V0{&S0};
541+
SmallVector<some_namespace::some_struct *> V1{&S1, &S1};
542+
543+
EXPECT_THAT(concat<some_namespace::some_struct>(V0, V1),
544+
ElementsAre(&S0, &S1, &S1));
545+
}
546+
547+
TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
548+
some_namespace::some_struct S0{};
549+
some_namespace::derives_from_some_struct S1{};
550+
SmallVector<some_namespace::some_struct *> V0{&S0};
551+
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
552+
553+
// Use concat over ranges of pointers to different (but related) types.
554+
EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
555+
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
556+
static_cast<some_namespace::some_struct *>(&S1)));
557+
}
558+
535559
TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
536560
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
537561
// using ADL.

0 commit comments

Comments
 (0)