Skip to content

Commit 24cad3a

Browse files
committed
[ADT] Fix llvm::concat_iterator for ValueT == common_base_class *
Fix llvm::concat_iterator for the case of `ValueT` being a pointer to a common base class to which the result of dereferencing any iterator in `ItersT` can be casted to.
1 parent dfb5cad commit 24cad3a

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

llvm/include/llvm/ADT/STLExtras.h

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,13 +1032,22 @@ class concat_iterator
10321032

10331033
static constexpr bool ReturnsByValue =
10341034
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
1035+
static constexpr bool ReturnsConvertiblePointer =
1036+
std::is_pointer_v<ValueT> &&
1037+
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);
10351038

10361039
using reference_type =
1037-
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
1038-
1039-
using handle_type =
1040-
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
1041-
ValueT *>;
1040+
typename std::conditional_t<ReturnsByValue || ReturnsConvertiblePointer,
1041+
ValueT, ValueT &>;
1042+
1043+
using optional_value_type =
1044+
std::conditional_t<ReturnsByValue, std::optional<ValueT>, ValueT *>;
1045+
// handle_type is used to return an optional value from `getHelper()`. If
1046+
// the type resulting from dereferencing all IterTs is a pointer that can be
1047+
// converted to `ValueT`, use that pointer type instead to avoid implicit
1048+
// conversion issues.
1049+
using handle_type = typename std::conditional_t<ReturnsConvertiblePointer,
1050+
ValueT, optional_value_type>;
10421051

10431052
/// We store both the current and end iterators for each concatenated
10441053
/// sequence in a tuple of pairs.
@@ -1088,7 +1097,7 @@ class concat_iterator
10881097
if (Begin == End)
10891098
return {};
10901099

1091-
if constexpr (ReturnsByValue)
1100+
if constexpr (ReturnsByValue || ReturnsConvertiblePointer)
10921101
return *Begin;
10931102
else
10941103
return &*Begin;
@@ -1105,8 +1114,12 @@ class concat_iterator
11051114

11061115
// Loop over them, and return the first result we find.
11071116
for (auto &GetHelperFn : GetHelperFns)
1108-
if (auto P = (this->*GetHelperFn)())
1109-
return *P;
1117+
if (auto P = (this->*GetHelperFn)()) {
1118+
if constexpr (ReturnsConvertiblePointer)
1119+
return P;
1120+
else
1121+
return *P;
1122+
}
11101123

11111124
llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
11121125
}

llvm/unittests/ADT/STLExtrasTest.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ struct some_struct {
398398
std::string swap_val;
399399
};
400400

401+
struct derives_from_some_struct : some_struct {};
402+
401403
std::vector<int>::const_iterator begin(const some_struct &s) {
402404
return s.data.begin();
403405
}
@@ -532,6 +534,33 @@ TEST(STLExtrasTest, ConcatRangeADL) {
532534
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
533535
}
534536

537+
TEST(STLExtrasTest, ConcatRangeRef) {
538+
SmallVector<some_namespace::some_struct> V12{{{1, 2}, "V12[0]"}};
539+
SmallVector<some_namespace::some_struct> V3456{{{3, 4}, "V3456[0]"},
540+
{{5, 6}, "V3456[1]"}};
541+
542+
// Use concat with `iterator type = some_namespace::some_struct *` and value
543+
// being a reference type.
544+
std::vector<some_namespace::some_struct *> Expected = {&V12[0], &V3456[0],
545+
&V3456[1]};
546+
std::vector<some_namespace::some_struct *> Test;
547+
for (auto &i : concat<some_namespace::some_struct>(V12, V3456))
548+
Test.push_back(&i);
549+
EXPECT_EQ(Expected, Test);
550+
}
551+
552+
TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
553+
some_namespace::some_struct S0{};
554+
some_namespace::derives_from_some_struct S1{};
555+
SmallVector<some_namespace::some_struct *> V0{&S0};
556+
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};
557+
558+
// Use concat over ranges of pointers to different (but related) types.
559+
EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
560+
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
561+
static_cast<some_namespace::some_struct *>(&S1)));
562+
}
563+
535564
TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
536565
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
537566
// using ADL.

0 commit comments

Comments
 (0)