Commit c8baf1e

Fixes #3710

1 parent 3d27e10 commit c8baf1e

File tree

2 files changed: +325 -2 lines changed

csrc/id_model/indexing.cpp (+37 -2)
@@ -1063,8 +1063,7 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
     ForLoop* unswitched_loop) const {
   const auto& zero_val = tv->fusion()->zeroVal();
 
-  const std::vector<IterDomain*>& predicate_domains =
-      getPredicateDomains(tv, expr);
+  std::vector<IterDomain*> predicate_domains = getPredicateDomains(tv, expr);
 
   const IndexingInfo& index_info =
       computeIndex(expr, predicate_domains, for_loops);
@@ -1093,6 +1092,42 @@ std::vector<PredicateInfo> TensorIndexer::getPredicates(
       /*is_start_predicate=*/false,
       /*unswitched_loop=*/unswitched_loop);
 
+  // When resize is involved, predicate its input ID as well to avoid
+  // redundancy. This is only strictly necessary when a predicated
+  // resize is preceded by a split, but for now it is always done,
+  // with an exception for static resizes. See
+  // PredicateIndexingTest.SplitThenPad for a concrete example.
+  for (const auto& [eg, direction] : index_info.traversal_path) {
+    auto resize = dynamic_cast<Resize*>(eg->front());
+    if (resize == nullptr) {
+      continue;
+    }
+
+    // If the input ID is guaranteed to cover the output ID, then
+    // the input index should never exceed its boundary.
+    if (resize->leftExpand()->isConstInt() &&
+        resize->rightExpand()->isConstInt()) {
+      auto left_int = resize->leftExpand()->evaluate().as<int64_t>();
+      auto right_int = resize->rightExpand()->evaluate().as<int64_t>();
+      // If the traversal direction is forward, the predicate is not
+      // necessary when both the left and right factors are
+      // non-negative, as the output ID is then guaranteed to cover
+      // the input ID. Similarly, if it is backward, the predicate is
+      // not necessary when both factors are non-positive.
+      if ((direction == Direction::Forward && left_int >= 0 &&
+           right_int >= 0) ||
+          (direction == Direction::Backward && left_int <= 0 &&
+           right_int <= 0)) {
+        continue;
+      }
+    }
+
+    IterDomain* id_to_predicate =
+        direction == Direction::Forward ? resize->out() : resize->in();
+
+    predicate_domains.push_back(id_to_predicate);
+  }
+
   const std::unordered_map<IterDomain*, ValGroup> contig_domains =
       isContigIndexingEnabled()
           ? getContigDomains(
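For intuition, the following is a minimal standalone sketch of the failure mode this loop guards against. It is plain C++, independent of nvFuser; the extents mirror the SplitThenPad test below (128 elements, outer split by 4 into [4, 32], then right-padded to [4, 64]). It shows that the logical predicate alone admits indices that land in the pad region.

#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t logical_extent = 128; // iS1{128}
  constexpr int64_t inner_extent = 32;    // iS12{32}, the resize input
  constexpr int64_t padded_extent = 64;   // iS13{64}, the resize output

  for (int64_t i0 = 0; i0 < 4; ++i0) {
    for (int64_t i1 = 0; i1 < padded_extent; ++i1) {
      // Mapping (i0, i1) backward through the resize (left expand 0)
      // leaves i1 unchanged, so the logical index is i0 * 32 + i1.
      const int64_t logical_idx = i0 * inner_extent + i1;
      const bool logical_ok =
          logical_idx >= 0 && logical_idx < logical_extent;
      // The extra predicate added by this commit: the resize input
      // index must stay within the pre-pad extent.
      const bool resize_input_ok = i1 >= 0 && i1 < inner_extent;
      if (i0 == 0 && i1 == 40) {
        // In-bounds for the logical domain, yet inside the pad region:
        // without the resize-input predicate this would be a real read.
        assert(logical_ok && !resize_input_ok);
      }
    }
  }
  return 0;
}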

tests/cpp/test_indexing.cpp (+288)
@@ -23,6 +23,7 @@
 #include <ops/all_ops.h>
 #include <scheduler/tools/abstract_tensor.h>
 #include <scheduler/tools/inlining.h>
+#include <scheduler/tools/loop_domain_scheduler.h>
 #include <scheduler/tools/resize_utils.h>
 #include <scheduler/utils.h>
@@ -5497,6 +5498,293 @@ TEST_F(PredicateIndexingTest, VectorizedResizeRotation) {
   testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
 }
 
+// Check if resize input IDs are predicated. Repro of issue
+// https://github.com/NVIDIA/Fuser/issues/3710.
+TEST_F(PredicateIndexingTest, SplitThenPad) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int64_t i0 = 4;
+  const int64_t i1 = 32;
+
+  auto zero = fusion.zeroVal();
+
+  auto tv0 = makeContigConcreteTensor({i0 * i1});
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 =
+      reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)});
+  auto tv3 = pad(tv2, {zero, IrBuilder::create<Val>(i1)});
+  auto tv4 = set(tv3);
+  fusion.addOutput(tv4);
+
+  scheduler_tools::propagateResizeToInputs(tv3->definition());
+
+  inlineMost();
+
+  // tv1 should be scheduled as:
+  //
+  // T1_l_float[iS11{4}, iS13{64}]
+  //  logical domain : (iS1{128})
+  //  contiguity: t
+  //   Outer split: iS1{128} by factor 4 -> iS11{4}, iS12{32}
+  //   Resize: iS12{32} by 0 and 32 -> iS13{64}
+  //  loop domain : (iS11{4}, iS13{64})
+  //
+  // In addition to its logical ID, the resize input ID should be
+  // predicated.
+
+  struct GetReference : AbstractGetReference {
+    GetReference(const TensorIndexer& indexer, const IdModel& id_model)
+        : AbstractGetReference(indexer, id_model) {}
+
+    Val* getInlinePredicate(TensorView* tv) const override {
+      if (tv->name() != 1) {
+        return nullptr;
+      }
+
+      // Without index hoisting and expr simplification, the predicate
+      // should look like:
+      //
+      // (((((((i0 * 32LL) + i1) >= 0LL) &&
+      //      (((i0 * 32LL) + i1) < 128LL)) &&
+      //     (i1 >= 0LL)) &&
+      //    (i1 < 32LL)))
+
+      std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);
+
+      Val* zero = tv->fusion()->zeroVal();
+
+      auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition());
+      NVF_ERROR(resize != nullptr);
+
+      auto logical_idx = addExpr(
+          mulExpr(loop_indices.at(0), createInt(i1)), loop_indices.at(1));
+
+      auto resize_idx = loop_indices.at(1);
+
+      return andExpr(
+          andExpr(
+              andExpr(
+                  geExpr(logical_idx, zero),
+                  ltExpr(logical_idx, createInt(i0 * i1))),
+              geExpr(resize_idx, zero)),
+          ltExpr(resize_idx, createInt(i1)));
+    }
+  };
+
+  PredicateIndexValidator<GetReference>::validate(&fusion, false);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({i0 * i1}, options);
+  std::vector<c10::IValue> inputs{t0};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
+TEST_F(PredicateIndexingTest, SplitThenPadTwice) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int64_t i0 = 4;
+  const int64_t i1 = 32;
+
+  auto zero = fusion.zeroVal();
+
+  auto tv0 = makeContigConcreteTensor({i0 * i1});
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 =
+      reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)});
+  auto tv3 = pad(tv2, {zero, IrBuilder::create<Val>(1L)});
+  auto tv4 = pad(tv3, {IrBuilder::create<Val>(1L), zero});
+  auto tv5 = set(tv4);
+  fusion.addOutput(tv5);
+
+  scheduler_tools::propagateResizeToInputs(tv3->definition());
+  scheduler_tools::propagateResizeToInputs(tv4->definition());
+
+  inlineMost();
+
+  // tv1 should be scheduled as:
+  //
+  // T1_l_float[iS14{4}, iS18{34}] ca_pos( 2 )
+  //  logical domain : (iS1{128})
+  //  contiguity: t
+  //   Outer split: iS1{128} by factor 4 -> iS14{4}, iS15{32}
+  //   Resize: iS15{32} by 0 and 1 -> iS16{33}
+  //   Resize: iS16{33} by 1 and 0 -> iS18{34}
+  //  loop domain : (iS14{4}, iS18{34})
+  //
+  // In addition to its logical ID, the two resize input IDs should be
+  // predicated.
+
+  struct GetReference : AbstractGetReference {
+    GetReference(const TensorIndexer& indexer, const IdModel& id_model)
+        : AbstractGetReference(indexer, id_model) {}
+
+    Val* getInlinePredicate(TensorView* tv) const override {
+      if (tv->name() != 1) {
+        return nullptr;
+      }
+
+      // Without index hoisting and expr simplification, the predicate
+      // should look like:
+      //
+      // (((((((((i0 * 32LL) + (i1 - 1LL)) >= 0LL) &&
+      //        (((i0 * 32LL) + (i1 - 1LL)) < 128LL)) &&
+      //       ((i1 - 1LL) >= 0LL)) &&
+      //      ((i1 - 1LL) < 33LL)) &&
+      //     ((i1 - 1LL) >= 0LL)) &&
+      //    ((i1 - 1LL) < 32LL)))
+
+      std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);
+
+      Val* zero = tv->fusion()->zeroVal();
+      Val* one = tv->fusion()->oneVal();
+
+      auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition());
+      NVF_ERROR(resize != nullptr);
+
+      auto logical_idx = addExpr(
+          mulExpr(loop_indices.at(0), createInt(i1)),
+          subExpr(loop_indices.at(1), one));
+
+      auto resize_idx = subExpr(loop_indices.at(1), one);
+
+      return andExpr(
+          andExpr(
+              andExpr(
+                  andExpr(
+                      andExpr(
+                          geExpr(logical_idx, zero),
+                          ltExpr(logical_idx, createInt(i0 * i1))),
+                      geExpr(resize_idx, zero)),
+                  ltExpr(resize_idx, createInt(i1 + 1))),
+              geExpr(resize_idx, zero)),
+          ltExpr(resize_idx, createInt(i1)));
+    }
+  };
+
+  PredicateIndexValidator<GetReference>::validate(&fusion, false);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({i0 * i1}, options);
+  std::vector<c10::IValue> inputs{t0};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
+// Testing a split reshape followed by slice and pad, which is a
+// common pattern in RoPE.
+TEST_F(PredicateIndexingTest, SplitThenSliceAndPad) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  const int64_t i0 = 4;
+  const int64_t i1 = 32;
+
+  auto zero = fusion.zeroVal();
+
+  auto tv0 = makeContigConcreteTensor({i0 * i1});
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 =
+      reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)});
+  auto tv3 = slice(
+      tv2,
+      {{zero, IrBuilder::create<Val>(i0)},
+       {IrBuilder::create<Val>(i1 / 2), IrBuilder::create<Val>(i1)}});
+  auto tv4 = pad(tv3, {zero, IrBuilder::create<Val>(i1 / 2)});
+  auto tv5 = set(tv4);
+  fusion.addOutput(tv5);
+
+  scheduler_tools::propagateResizeToInputs(tv3->definition());
+  scheduler_tools::propagateResizeToInputs(tv4->definition());
+
+  inlineMost();
+
+  // tv1 should be scheduled as:
+  //
+  // T1_l_float[iS14{4}, iS18{32}] ca_pos( 2 )
+  //  logical domain : (iS1{128})
+  //  contiguity: t
+  //   Outer split: iS1{128} by factor 4 -> iS14{4}, iS15{32}
+  //   Resize: iS15{32} by -16 and 0 -> iS16{16}
+  //   Resize: iS16{16} by 0 and 16 -> iS18{32}
+  //  loop domain : (iS14{4}, iS18{32})
+  //
+  // In addition to its logical ID, the input of the second resize
+  // should be predicated. The first resize needs no predicate: its
+  // expansion factors are static, so its input is known to cover its
+  // output, and as long as the output index is within its boundary,
+  // the input index is as well.
+
+  struct GetReference : AbstractGetReference {
+    GetReference(const TensorIndexer& indexer, const IdModel& id_model)
+        : AbstractGetReference(indexer, id_model) {}
+
+    Val* getInlinePredicate(TensorView* tv) const override {
+      if (tv->name() != 1) {
+        return nullptr;
+      }
+
+      // Without index hoisting and expr simplification, the predicate
+      // should look like:
+      //
+      // (((((((i0 * 32LL) + (i1 + 16LL)) >= 0LL) &&
+      //      (((i0 * 32LL) + (i1 + 16LL)) < 128LL)) &&
+      //     (i1 >= 0LL)) &&
+      //    (i1 < 16LL)))
+
+      std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_);
+
+      Val* zero = tv->fusion()->zeroVal();
+
+      auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition());
+      NVF_ERROR(resize != nullptr);
+
+      auto logical_idx = addExpr(
+          mulExpr(loop_indices.at(0), createInt(i1)),
+          addExpr(loop_indices.at(1), createInt(i1 / 2)));
+
+      auto resize_idx = loop_indices.at(1);
+
+      return andExpr(
+          andExpr(
+              andExpr(
+                  geExpr(logical_idx, zero),
+                  ltExpr(logical_idx, createInt(i0 * i1))),
+              geExpr(resize_idx, zero)),
+          ltExpr(resize_idx, createInt(i1 / 2)));
+    }
+  };
+
+  PredicateIndexValidator<GetReference>::validate(&fusion, false);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({i0 * i1}, options);
+  std::vector<c10::IValue> inputs{t0};
+
+  KernelExecutor ke;
+  ke.compile(&fusion, inputs);
+  auto outputs = ke.run(inputs);
+
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+}
+
 // Repro of issue #3505. The indexing WAR for resize triggered an
 // assertion due to loop promotion.
 TEST_F(IndexingTest, Issue3505Repro1) {
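The static-resize exception in the indexing.cpp hunk above reads as a small decision table. Here is a minimal sketch of it in plain C++; the standalone Direction enum and the needsResizePredicate helper are illustrative stand-ins, not nvFuser API, and the two assertions correspond to the slice and the pad of SplitThenSliceAndPad.

#include <cassert>
#include <cstdint>

enum class Direction { Forward, Backward };

// Whether a resize's input ID must be added to the predicate domains,
// given constant left/right expansion factors.
bool needsResizePredicate(Direction dir, int64_t left, int64_t right) {
  if (dir == Direction::Forward && left >= 0 && right >= 0) {
    return false; // output covers input: a forward index stays in bounds
  }
  if (dir == Direction::Backward && left <= 0 && right <= 0) {
    return false; // input covers output: a backward index stays in bounds
  }
  return true;
}

int main() {
  // SplitThenSliceAndPad, traversed backward from the loop domain:
  // the slice (factors -16 and 0) needs no extra predicate ...
  assert(!needsResizePredicate(Direction::Backward, -16, 0));
  // ... but the pad (factors 0 and 16) does.
  assert(needsResizePredicate(Direction::Backward, 0, 16));
  return 0;
}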
