Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] IdModel: Fix invalid promotion selection #3877

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 192 additions & 58 deletions csrc/id_model/loop_promotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,189 @@

namespace nvfuser {

std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>>
computeCoveredGroups(const ValGraph& exact_graph) {
// Map from an exact iter domain group, to all the exact iter domain groups it
// covers
std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>> covered_ids;

for (const ValGroup& id_group :
exact_graph.disjointValSets().disjointSets()) {
// Initialize inputs
const ExprGroups& id_group_defs = exact_graph.getDefinitions(id_group);
if (id_group_defs.empty()) {
auto init_groups = std::make_shared<CoveredGroups>();
init_groups->insert(CoveredGroup(id_group));
NVF_ERROR(covered_ids.emplace(id_group, init_groups).second);
}

// Initialize broadcast groups to empty since broadcast domains
// don't matter for indexing
if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) {
return id->as<IterDomain>()->isBroadcast();
})) {
covered_ids[id_group] = std::make_shared<CoveredGroups>();
}
}

ValGraphStmtSort exact_stmt_sort(exact_graph);

for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
// Initialize to empty group if not yet initialized
for (const ValGroup& output_group : exact_graph.outputGroups(exact_expr)) {
covered_ids.emplace(output_group, std::make_shared<CoveredGroups>());
}

const std::vector<ValGroup> input_groups =
exact_graph.inputGroups(exact_expr);
const std::vector<ValGroup> output_groups =
exact_graph.outputGroups(exact_expr);

// If this expr is a split, don't propagate the input coverage as
// is but set the covered group of each output group by itself.
// The input coverage info is propagated as the split input.
if (exact_expr->front()->isA<Split>()) {
NVF_ERROR(input_groups.size() == 1);
const std::shared_ptr<CoveredGroups>& covered_groups =
covered_ids.at(input_groups.at(0));

for (const ValGroup& output_group : output_groups) {
bool is_inner =
output_group->has(exact_expr->front()->as<Split>()->inner());
covered_ids[output_group]->insert(
CoveredGroup(output_group, covered_groups, is_inner));
}
continue;
}

for (const ValGroup& output_group : output_groups) {
// Note that an exact group may have multiple
// exact expr groups and may have different coverage groups depending on
// the expr groups. For example, this can happen with reshape or resize.
// See test LoopPromotionCoverage for a concrete example.
for (const ValGroup& inp_group : input_groups) {
const std::shared_ptr<CoveredGroups>& inp_covered_groups =
covered_ids.at(inp_group);
covered_ids[output_group]->insert(
inp_covered_groups->begin(), inp_covered_groups->end());
}
}
}

return covered_ids;
}

bool CoveredGroup::isEqualToOrSuperSetOf(const CoveredGroup& other) const {
if (*this == other) {
return true;
}

// When both are derived from split
if (split_in_.get() && other.split_in_.get()) {
// If they correspond to differnt outputs, they are obviously different.
if (is_inner_ != other.is_inner_) {
return false;
}

const CoveredGroups& split_in = *split_in_;
const CoveredGroups& other_split_in = *other.split_in_;

// When both have the same split input (and both correspond to
// either inner or outer), they should cover the same exact
// groups. This should only happen when broadcast is merged. For
// example, suppose there are two tensors and they are scheduled as
// follows;
//
// t0: [i0]
// t1: [i1, b2]
//
// t1->merge(0, 1)->split(0, 4);
// t0->split(0, 4)
//
// t0->inlineAt(t1, 1)
//
// In this case, t0->axis(0) and t1->axis(0) have the same
// split input group, {i0, i1}. Note that b2 is not included as
// it's a broadcast. Also note that both are the outer
// output. Here, group_ of t0->axis(0) is the exact group of
// t0->axis(0), while that of tv1->axis(0) is the exact group of
// the t1->merge(0, 1) output. In this case, however, this merge
// is just a merge of i1 and the b2 broadcast ID, so in terms of
// covered exact groups, it's effectively the same as that of
// t0->axis(0). All in all, as long as both correspond to either
// inner or outer of the same split input, they should be
// considered the same.
if (split_in == other_split_in) {
return true;
}

// Both are derived from a split but have differnt split input
// groups. If the input groups of this split is a superset of the
// input groups of the split of the other CoveredGroup, this
// CoveredGroup is a superset
if (std::all_of(
other_split_in.begin(),
other_split_in.end(),
[&](const CoveredGroup& other_split_in_group) {
return std::any_of(
split_in.begin(),
split_in.end(),
[&](const CoveredGroup& split_in_group) {
return split_in_group.isEqualToOrSuperSetOf(
other_split_in_group);
});
})) {
return true;
}
}

return false;
}

std::string CoveredGroup::toString() const {
std::stringstream ss;

ss << "{" << nvfuser::toString(group_);
if (split_in_.get()) {
ss << " (" << (is_inner_ ? "inner" : "outer") << " split from ";
bool is_first = true;
for (const auto& cg : *split_in_) {
if (!is_first) {
ss << ", ";
}
ss << cg.toString();
is_first = false;
}
ss << ")";
}
ss << "}";
return ss.str();
}

namespace {

// Returns true if covered_groups_x is equal to or a superset of
// covered_groups_y, that is, for all of CoveredGroup of
// covered_groups_y, if there's a CoveredGroup in covered_groups_x
// that is equal or a superset.
bool isEqualToOrSuperSetOf(
const CoveredGroups& covered_groups_x,
const CoveredGroups& covered_groups_y) {
return std::all_of(
covered_groups_y.begin(),
covered_groups_y.end(),
[&](const CoveredGroup& cover_group_y) {
return std::any_of(
covered_groups_x.begin(),
covered_groups_x.end(),
[&](const CoveredGroup& cover_group_x) {
return cover_group_x.isEqualToOrSuperSetOf(cover_group_y);
});
});
}

} // namespace

LoopPromotionMapBuilder::LoopPromotionMapBuilder(
IdModel& id_model,
const StatefulInliningInfo& inlining_info,
Expand Down Expand Up @@ -713,66 +896,14 @@ void LoopPromotionMapBuilder::propagatePromotionsInIELGraph(
iel_graph, iel_promotion_map, idGraph(IdMappingMode::LOOP), {});
}

namespace {

// Returns for each ValGroup in provided IdGraph what the input ValGroups are
// traversing on definitions. Ignoring broadcast ValGroups and resetting inputs
// at RFactor ValGroups.
std::unordered_map<ValGroup, ValGroups> computeCoveredGroups(
const ValGraph& graph) {
// Map from an exact iter domain group, to all the exact iter domain groups it
// covers
std::unordered_map<ValGroup, ValGroups> covered_ids;

for (const ValGroup& id_group : graph.disjointValSets().disjointSets()) {
// Initialize inputs
const ExprGroups& id_group_defs = graph.getDefinitions(id_group);
if (id_group_defs.empty()) {
covered_ids[id_group] = {id_group};
}

// Initialize broadcast groups to empty since broadcast domains
// don't matter for indexing
if (std::any_of(id_group->begin(), id_group->end(), [&](Val* id) {
return id->as<IterDomain>()->isBroadcast();
})) {
covered_ids[id_group] = {};
}
}

ValGraphStmtSort exact_stmt_sort(graph);

for (const ExprGroup& exact_expr : exact_stmt_sort.exprs()) {
std::vector<ValGroup> input_groups = graph.inputGroups(exact_expr);

ValGroups covered;
for (const ValGroup& inp_group : input_groups) {
covered.pushBack(covered_ids.at(inp_group));
}

for (const ValGroup& output_group : graph.outputGroups(exact_expr)) {
// Note that pushBack must be used instead of just
// `covered_ids[outputGroups] = covered`. An exact group may have multiple
// exact expr groups and may have different coverage groups depending on
// the expr groups. For example, this can happen with reshape or resize.
// See test LoopPromotionCoverage for a concrete example.
covered_ids[output_group].pushBack(covered);
}
}

return covered_ids;
}

}; // namespace

std::unordered_map<ValGroup, IterDomain*> LoopPromotionMapBuilder::
projectIELPromotionToLoopGraph(
const ValGraph& iel_graph,
const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map,
const ValGraph& loop_graph,
const StatefulInliningInfo& inlining_info) const {
const std::unordered_map<ValGroup, ValGroups> exact_covered_ids =
computeCoveredGroups(idGraph(IdMappingMode::EXACT));
const std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>>
exact_covered_ids = computeCoveredGroups(idGraph(IdMappingMode::EXACT));

// Grab terminal iter domain in the loop groups.
const VectorOfUniqueEntries<IterDomain*> terminal_loop_ids =
Expand Down Expand Up @@ -800,7 +931,8 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup(
const ValGroup& loop_group,
const ValGraph& iel_graph,
const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map,
const std::unordered_map<ValGroup, ValGroups>& exact_covered_ids,
const std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>>&
exact_covered_ids,
const VectorOfUniqueEntries<IterDomain*>& terminal_loop_ids) const {
const ValGraph& exact_graph = idGraph(IdMappingMode::EXACT);

Expand Down Expand Up @@ -854,11 +986,12 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup(
ValGroups exact_groups = exact_graph.toGroups(*loop_group);

// All exact groups covered by all iter domains in this loop group
ValGroups loop_group_covered_ids;
CoveredGroups loop_group_covered_ids;
for (const ValGroup& exact_group : exact_groups) {
auto covered_it = exact_covered_ids.find(exact_group);
NVF_ERROR(covered_it != exact_covered_ids.end());
loop_group_covered_ids.pushBack(covered_it->second);
loop_group_covered_ids.insert(
covered_it->second->begin(), covered_it->second->end());
}

// Check if any of the candidate Iter Domains we collected cover all the
Expand All @@ -869,7 +1002,8 @@ IterDomain* LoopPromotionMapBuilder::findPromotionOfLoopGroup(
IterDomain* terminal_id = entry.second;
auto covered_it = exact_covered_ids.find(terminal_id_group);
NVF_ERROR(covered_it != exact_covered_ids.end());
if (loop_group_covered_ids.computeSubtract(covered_it->second).empty()) {
const auto& covered_groups = covered_it->second;
if (isEqualToOrSuperSetOf(*covered_groups, loop_group_covered_ids)) {
return terminal_id;
}
}
Expand Down
79 changes: 78 additions & 1 deletion csrc/id_model/loop_promotion.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,82 @@ namespace nvfuser {
class IdModel;
struct StatefulInliningInfo;

struct CoveredGroup;

using CoveredGroups = std::unordered_set<CoveredGroup>;

// Represents an input (or split output) ID group that an exact group
// depends on (i.e., covers). If an input ID group is split, split_in_
// refers to the covered groups of the input ID group, and group_
// refers to either inner or outer output group.
struct CoveredGroup {
CoveredGroup() = default;
CoveredGroup(
ValGroup group,
const std::shared_ptr<CoveredGroups>& split_parent = nullptr,
bool is_inner = false)
: group_(std::move(group)),
split_in_(split_parent),
is_inner_(is_inner) {}

const ValGroup& group() const {
return group_;
}

const std::shared_ptr<CoveredGroups>& splitIn() const {
return split_in_;
}

bool isInner() const {
return is_inner_;
}

// Note that the equality of this information is only determined by
// group_ and that split_in_ does not matter.
bool operator==(const CoveredGroup& other) const {
return group_ == other.group_;
}

bool operator!=(const CoveredGroup& other) const {
return !(group_ == other.group_);
}

// Check if this CoveredGroup is equal to or covers a given other
// CoveredGroup
bool isEqualToOrSuperSetOf(const CoveredGroup& other) const;

std::string toString() const;

private:
// Covered group
ValGroup group_;
// If this group is an output of a split, keep track of the covered
// groups of the split input group.
std::shared_ptr<CoveredGroups> split_in_;
// Indicates if the split is inner or not. Not relevant if split_in_
// is nullptr.
bool is_inner_ = false;
};

} // namespace nvfuser

namespace std {
template <>
struct hash<nvfuser::CoveredGroup> {
size_t operator()(const nvfuser::CoveredGroup& x) const {
return std::hash<nvfuser::ValGroup>()(x.group());
}
};
} // namespace std

namespace nvfuser {

// Computes coverage info of each exact group. Coverage is
// represented as a set of CoveredGroup, which is either an exact
// group of input IDs or an output group of split.
std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>>
computeCoveredGroups(const ValGraph& exact_graph);

// Callback interface for LoopPromotionMapBuilder. Allow exposing the
// temporary maps for testing and debugging
class LoopPromotionMapBuilderCallback {
Expand Down Expand Up @@ -140,7 +216,8 @@ class LoopPromotionMapBuilder {
const ValGroup& loop_group,
const ValGraph& iel_graph,
const std::unordered_map<ValGroup, IterDomain*>& iel_promotion_map,
const std::unordered_map<ValGroup, ValGroups>& exact_covered_ids,
const std::unordered_map<ValGroup, std::shared_ptr<CoveredGroups>>&
exact_covered_ids,
const VectorOfUniqueEntries<IterDomain*>& terminal_loop_ids) const;

// Terminal loop ids are iteration domains in each loop group that:
Expand Down
Loading