Skip to content

[ADT] Fix llvm::concat_iterator for ValueT == common_base_class * #144744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions llvm/include/llvm/ADT/STLExtras.h
Original file line number Diff line number Diff line change
Expand Up @@ -1032,13 +1032,22 @@ class concat_iterator

static constexpr bool ReturnsByValue =
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
static constexpr bool ReturnsConvertiblePointer =
std::is_pointer_v<ValueT> &&
(std::is_convertible_v<decltype(*std::declval<IterTs>()), ValueT> && ...);

using reference_type =
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;

using handle_type =
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
ValueT *>;
typename std::conditional_t<ReturnsByValue || ReturnsConvertiblePointer,
ValueT, ValueT &>;

using optional_value_type =
std::conditional_t<ReturnsByValue, std::optional<ValueT>, ValueT *>;
// handle_type is used to return an optional value from `getHelper()`. If
// the type resulting from dereferencing all IterTs is a pointer that can be
// converted to `ValueT`, use that pointer type instead to avoid implicit
// conversion issues.
using handle_type = typename std::conditional_t<ReturnsConvertiblePointer,
ValueT, optional_value_type>;

/// We store both the current and end iterators for each concatenated
/// sequence in a tuple of pairs.
Expand Down Expand Up @@ -1088,7 +1097,7 @@ class concat_iterator
if (Begin == End)
return {};

if constexpr (ReturnsByValue)
if constexpr (ReturnsByValue || ReturnsConvertiblePointer)
return *Begin;
else
return &*Begin;
Expand All @@ -1105,8 +1114,12 @@ class concat_iterator

// Loop over them, and return the first result we find.
for (auto &GetHelperFn : GetHelperFns)
if (auto P = (this->*GetHelperFn)())
return *P;
if (auto P = (this->*GetHelperFn)()) {
if constexpr (ReturnsConvertiblePointer)
return P;
else
return *P;
}

llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
}
Expand Down
29 changes: 29 additions & 0 deletions llvm/unittests/ADT/STLExtrasTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ struct some_struct {
std::string swap_val;
};

struct derives_from_some_struct : some_struct {};

std::vector<int>::const_iterator begin(const some_struct &s) {
return s.data.begin();
}
Expand Down Expand Up @@ -532,6 +534,33 @@ TEST(STLExtrasTest, ConcatRangeADL) {
EXPECT_THAT(concat<const int>(S0, S1), ElementsAre(1, 2, 3, 4));
}

TEST(STLExtrasTest, ConcatRangeRef) {
SmallVector<some_namespace::some_struct> V12{{{1, 2}, "V12[0]"}};
SmallVector<some_namespace::some_struct> V3456{{{3, 4}, "V3456[0]"},
{{5, 6}, "V3456[1]"}};

// Use concat with `iterator type = some_namespace::some_struct *` and value
// being a reference type.
std::vector<some_namespace::some_struct *> Expected = {&V12[0], &V3456[0],
&V3456[1]};
std::vector<some_namespace::some_struct *> Test;
for (auto &i : concat<some_namespace::some_struct>(V12, V3456))
Test.push_back(&i);
EXPECT_EQ(Expected, Test);
}

TEST(STLExtrasTest, ConcatRangePtrToDerivedClass) {
some_namespace::some_struct S0{};
some_namespace::derives_from_some_struct S1{};
SmallVector<some_namespace::some_struct *> V0{&S0};
SmallVector<some_namespace::derives_from_some_struct *> V1{&S1, &S1};

// Use concat over ranges of pointers to different (but related) types.
EXPECT_THAT(concat<some_namespace::some_struct *>(V0, V1),
ElementsAre(&S0, static_cast<some_namespace::some_struct *>(&S1),
static_cast<some_namespace::some_struct *>(&S1)));
}

TEST(STLExtrasTest, MakeFirstSecondRangeADL) {
// Make sure that we use the `begin`/`end` functions from `some_namespace`,
// using ADL.
Expand Down