23 | 23 | #include <ops/all_ops.h>
24 | 24 | #include <scheduler/tools/abstract_tensor.h>
25 | 25 | #include <scheduler/tools/inlining.h>
| 26 | +#include <scheduler/tools/loop_domain_scheduler.h> |
26 | 27 | #include <scheduler/tools/resize_utils.h>
27 | 28 | #include <scheduler/utils.h>
28 | 29 |
@@ -5497,6 +5498,293 @@ TEST_F(PredicateIndexingTest, VectorizedResizeRotation) {
5497 | 5498 | testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
5498 | 5499 | }
5499 | 5500 |
| 5501 | +// Check that resize input IDs are predicated. Repro of issue |
| 5502 | +// https://github.com/NVIDIA/Fuser/issues/3710. |
| 5503 | +TEST_F(PredicateIndexingTest, SplitThenPad) { |
| 5504 | + Fusion fusion; |
| 5505 | + FusionGuard fg(&fusion); |
| 5506 | + |
| 5507 | + const int64_t i0 = 4; |
| 5508 | + const int64_t i1 = 32; |
| 5509 | + |
| 5510 | + auto zero = fusion.zeroVal(); |
| 5511 | + |
| 5512 | + auto tv0 = makeContigConcreteTensor({i0 * i1}); |
| 5513 | + fusion.addInput(tv0); |
| 5514 | + |
| 5515 | + auto tv1 = set(tv0); |
| 5516 | + auto tv2 = |
| 5517 | + reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)}); |
| 5518 | + auto tv3 = pad(tv2, {zero, IrBuilder::create<Val>(i1)}); |
| 5519 | + auto tv4 = set(tv3); |
| 5520 | + fusion.addOutput(tv4); |
| 5521 | + |
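|  | +  // propagateResizeToInputs replays the pad of tv3 on the tensors |
|  | +  // between the fusion inputs and tv3 (i.e., tv1 and tv2), giving |
|  | +  // them the split-then-resize loop domain shown below. |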
| 5522 | + scheduler_tools::propagateResizeToInputs(tv3->definition()); |
| 5523 | + |
| 5524 | + inlineMost(); |
| 5525 | + |
| 5526 | + // tv1 should be scheduled as: |
| 5527 | + // |
| 5528 | + // T1_l_float[iS11{4}, iS13{64}] |
| 5529 | + // logical domain : (iS1{128}) |
| 5530 | + // contiguity: t |
| 5531 | + // Outer split: iS1{128} by factor 4 -> iS11{4}, iS12{32} |
| 5532 | + // Resize: iS12{32} by 0 and 32 -> iS13{64} |
| 5533 | + // loop domain : (iS11{4}, iS13{64}) |
| 5534 | + // |
| 5535 | + // In addition to its logical ID, the resize input ID should be |
| 5536 | + // predicated. |
| 5537 | + |
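|  | +  // Hand-built reference predicate for tv1, which |
|  | +  // PredicateIndexValidator compares against the actual predicate. |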
| 5538 | + struct GetReference : AbstractGetReference { |
| 5539 | + GetReference(const TensorIndexer& indexer, const IdModel& id_model) |
| 5540 | + : AbstractGetReference(indexer, id_model) {} |
| 5541 | + |
| 5542 | + Val* getInlinePredicate(TensorView* tv) const override { |
| 5543 | + if (tv->name() != 1) { |
| 5544 | + return nullptr; |
| 5545 | + } |
| 5546 | + |
| 5547 | +    // Without index hoisting and expr simplification, the predicate |
| 5548 | + // should look like: |
| 5549 | + // |
| 5550 | + // (((((((i0 * 32LL) + i1) >= 0LL) && |
| 5551 | + // (((i0 * 32LL) + i1) < 128LL)) && |
| 5552 | + // (i1 >= 0LL)) && |
| 5553 | + // (i1 < 32LL))) |
| 5554 | + |
| 5555 | + std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_); |
| 5556 | + |
| 5557 | + Val* zero = tv->fusion()->zeroVal(); |
| 5558 | + |
| 5559 | + auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition()); |
| 5560 | + NVF_ERROR(resize != nullptr); |
| 5561 | + |
| 5562 | + auto logical_idx = addExpr( |
| 5563 | + mulExpr(loop_indices.at(0), createInt(i1)), loop_indices.at(1)); |
| 5564 | + |
| 5565 | + auto resize_idx = loop_indices.at(1); |
| 5566 | + |
| 5567 | + return andExpr( |
| 5568 | + andExpr( |
| 5569 | + andExpr( |
| 5570 | + geExpr(logical_idx, zero), |
| 5571 | + ltExpr(logical_idx, createInt(i0 * i1))), |
| 5572 | + geExpr(resize_idx, zero)), |
| 5573 | + ltExpr(resize_idx, createInt(i1))); |
| 5574 | + } |
| 5575 | + }; |
| 5576 | + |
| 5577 | + PredicateIndexValidator<GetReference>::validate(&fusion, false); |
| 5578 | + |
| 5579 | + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| 5580 | + auto t0 = at::randn({i0 * i1}, options); |
| 5581 | + std::vector<c10::IValue> inputs{t0}; |
| 5582 | + |
| 5583 | + KernelExecutor ke; |
| 5584 | + ke.compile(&fusion, inputs); |
| 5585 | + auto outputs = ke.run(inputs); |
| 5586 | + |
| 5587 | + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); |
| 5588 | +} |
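As a host-side sanity check of the arithmetic in the expected predicate above, here is a small standalone sketch (hypothetical; not part of the test suite) that enumerates tv1's loop domain `[iS11{4}, iS13{64}]` and applies the same bounds. Exactly the unpadded `4 * 32` elements survive the predicate:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the reference predicate of SplitThenPad:
//   (i0 * 32 + i1 >= 0) && (i0 * 32 + i1 < 128) && (i1 >= 0) && (i1 < 32)
int main() {
  constexpr int64_t kI0 = 4, kI1 = 32;
  int64_t active = 0;
  for (int64_t i0 = 0; i0 < kI0; ++i0) {       // loop ID iS11{4}
    for (int64_t i1 = 0; i1 < kI1 * 2; ++i1) { // loop ID iS13{64} (padded)
      const int64_t logical_idx = i0 * kI1 + i1; // index into iS1{128}
      const bool pred = logical_idx >= 0 && logical_idx < kI0 * kI1 &&
          i1 >= 0 && i1 < kI1; // bound on the resize input ID iS12{32}
      if (pred) {
        ++active;
      }
    }
  }
  assert(active == kI0 * kI1); // only the unpadded half of each row
  return 0;
}
```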
| 5589 | + |
| 5590 | +TEST_F(PredicateIndexingTest, SplitThenPadTwice) { |
| 5591 | + Fusion fusion; |
| 5592 | + FusionGuard fg(&fusion); |
| 5593 | + |
| 5594 | + const int64_t i0 = 4; |
| 5595 | + const int64_t i1 = 32; |
| 5596 | + |
| 5597 | + auto zero = fusion.zeroVal(); |
| 5598 | + |
| 5599 | + auto tv0 = makeContigConcreteTensor({i0 * i1}); |
| 5600 | + fusion.addInput(tv0); |
| 5601 | + |
| 5602 | + auto tv1 = set(tv0); |
| 5603 | + auto tv2 = |
| 5604 | + reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)}); |
| 5605 | + auto tv3 = pad(tv2, {zero, IrBuilder::create<Val>(1L)}); |
| 5606 | + auto tv4 = pad(tv3, {IrBuilder::create<Val>(1L), zero}); |
| 5607 | + auto tv5 = set(tv4); |
| 5608 | + fusion.addOutput(tv5); |
| 5609 | + |
| 5610 | + scheduler_tools::propagateResizeToInputs(tv3->definition()); |
| 5611 | + scheduler_tools::propagateResizeToInputs(tv4->definition()); |
| 5612 | + |
| 5613 | + inlineMost(); |
| 5614 | + |
| 5615 | + // tv1 should be scheduled as: |
| 5616 | + // |
| 5617 | + // T1_l_float[iS14{4}, iS18{34}] ca_pos( 2 ) |
| 5618 | + // logical domain : (iS1{128}) |
| 5619 | + // contiguity: t |
| 5620 | + // Outer split: iS1{128} by factor 4 -> iS14{4}, iS15{32} |
| 5621 | + // Resize: iS15{32} by 0 and 1 -> iS16{33} |
| 5622 | + // Resize: iS16{33} by 1 and 0 -> iS18{34} |
| 5623 | + // loop domain : (iS14{4}, iS18{34}) |
| 5624 | + // |
| 5625 | + // In addition to its logical ID, the two resize input IDs should be |
| 5626 | + // predicated. |
| 5627 | + |
| 5628 | + struct GetReference : AbstractGetReference { |
| 5629 | + GetReference(const TensorIndexer& indexer, const IdModel& id_model) |
| 5630 | + : AbstractGetReference(indexer, id_model) {} |
| 5631 | + |
| 5632 | + Val* getInlinePredicate(TensorView* tv) const override { |
| 5633 | + if (tv->name() != 1) { |
| 5634 | + return nullptr; |
| 5635 | + } |
| 5636 | + |
| 5637 | +    // Without index hoisting and expr simplification, the predicate |
| 5638 | + // should look like: |
| 5639 | + // |
| 5640 | + // (((((((((i0 * 32LL) + (i1 - 1LL)) >= 0LL) && |
| 5641 | + // (((i0 * 32LL) + (i1 - 1LL)) < 128LL)) && |
| 5642 | + // ((i1 - 1LL) >= 0LL)) && |
| 5643 | + // ((i1 - 1LL) < 33LL)) && |
| 5644 | + // ((i1 - 1LL) >= 0LL)) && |
| 5645 | + // ((i1 - 1LL) < 32LL))) |
| 5646 | + |
| 5647 | + std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_); |
| 5648 | + |
| 5649 | + Val* zero = tv->fusion()->zeroVal(); |
| 5650 | + Val* one = tv->fusion()->oneVal(); |
| 5651 | + |
| 5652 | + auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition()); |
| 5653 | + NVF_ERROR(resize != nullptr); |
| 5654 | + |
| 5655 | + auto logical_idx = addExpr( |
| 5656 | + mulExpr(loop_indices.at(0), createInt(i1)), |
| 5657 | + subExpr(loop_indices.at(1), one)); |
| 5658 | + |
| 5659 | + auto resize_idx = subExpr(loop_indices.at(1), one); |
| 5660 | + |
| 5661 | + return andExpr( |
| 5662 | + andExpr( |
| 5663 | + andExpr( |
| 5664 | + andExpr( |
| 5665 | + andExpr( |
| 5666 | + geExpr(logical_idx, zero), |
| 5667 | + ltExpr(logical_idx, createInt(i0 * i1))), |
| 5668 | + geExpr(resize_idx, zero)), |
| 5669 | + ltExpr(resize_idx, createInt(i1 + 1))), |
| 5670 | + geExpr(resize_idx, zero)), |
| 5671 | + ltExpr(resize_idx, createInt(i1))); |
| 5672 | + } |
| 5673 | + }; |
| 5674 | + |
| 5675 | + PredicateIndexValidator<GetReference>::validate(&fusion, false); |
| 5676 | + |
| 5677 | + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| 5678 | + auto t0 = at::randn({i0 * i1}, options); |
| 5679 | + std::vector<c10::IValue> inputs{t0}; |
| 5680 | + |
| 5681 | + KernelExecutor ke; |
| 5682 | + ke.compile(&fusion, inputs); |
| 5683 | + auto outputs = ke.run(inputs); |
| 5684 | + |
| 5685 | + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); |
| 5686 | +} |
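The same style of sketch (again hypothetical) covers the double-pad case. The loop index is shifted by the left pad of 1, and the two resize input IDs, iS16{33} and iS15{32}, contribute the `< 33` and `< 32` upper bounds of the reference predicate:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the reference predicate of SplitThenPadTwice. The repeated
// lower bound mirrors the duplicated (i1 - 1) >= 0 term in the
// expected predicate shown in the test.
int main() {
  constexpr int64_t kI0 = 4, kI1 = 32;
  int64_t active = 0;
  for (int64_t i0 = 0; i0 < kI0; ++i0) {       // loop ID iS14{4}
    for (int64_t i1 = 0; i1 < kI1 + 2; ++i1) { // loop ID iS18{34}
      const int64_t resize_idx = i1 - 1;       // undo the left pad of 1
      const int64_t logical_idx = i0 * kI1 + resize_idx;
      const bool pred = logical_idx >= 0 && logical_idx < kI0 * kI1 &&
          resize_idx >= 0 && resize_idx < kI1 + 1 && // resize input iS16{33}
          resize_idx >= 0 && resize_idx < kI1;       // resize input iS15{32}
      if (pred) {
        ++active;
      }
    }
  }
  assert(active == kI0 * kI1); // the 128 unpadded elements
  return 0;
}
```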
| 5687 | + |
| 5688 | +// Testing a split reshape followed by slice and pad, which is a |
| 5689 | +// common pattern in RoPE. |
| 5690 | +TEST_F(PredicateIndexingTest, SplitThenSliceAndPad) { |
| 5691 | + Fusion fusion; |
| 5692 | + FusionGuard fg(&fusion); |
| 5693 | + |
| 5694 | + const int64_t i0 = 4; |
| 5695 | + const int64_t i1 = 32; |
| 5696 | + |
| 5697 | + auto zero = fusion.zeroVal(); |
| 5698 | + |
| 5699 | + auto tv0 = makeContigConcreteTensor({i0 * i1}); |
| 5700 | + fusion.addInput(tv0); |
| 5701 | + |
| 5702 | + auto tv1 = set(tv0); |
| 5703 | + auto tv2 = |
| 5704 | + reshape(tv1, {IrBuilder::create<Val>(i0), IrBuilder::create<Val>(i1)}); |
| 5705 | + auto tv3 = slice( |
| 5706 | + tv2, |
| 5707 | + {{zero, IrBuilder::create<Val>(i0)}, |
| 5708 | + {IrBuilder::create<Val>(i1 / 2), IrBuilder::create<Val>(i1)}}); |
| 5709 | + auto tv4 = pad(tv3, {zero, IrBuilder::create<Val>(i1 / 2)}); |
| 5710 | + auto tv5 = set(tv4); |
| 5711 | + fusion.addOutput(tv5); |
| 5712 | + |
| 5713 | + scheduler_tools::propagateResizeToInputs(tv3->definition()); |
| 5714 | + scheduler_tools::propagateResizeToInputs(tv4->definition()); |
| 5715 | + |
| 5716 | + inlineMost(); |
| 5717 | + |
| 5718 | + // tv1 should be scheduled as: |
| 5719 | + // |
| 5720 | + // T1_l_float[iS14{4}, iS18{32}] ca_pos( 2 ) |
| 5721 | + // logical domain : (iS1{128}) |
| 5722 | + // contiguity: t |
| 5723 | + // Outer split: iS1{128} by factor 4 -> iS14{4}, iS15{32} |
| 5724 | + // Resize: iS15{32} by -16 and 0 -> iS16{16} |
| 5725 | + // Resize: iS16{16} by 0 and 16 -> iS18{32} |
| 5726 | + // loop domain : (iS14{4}, iS18{32}) |
| 5727 | + // |
| 5728 | +  // In addition to its logical ID, the input of the second resize |
| 5729 | +  // should be predicated. The first resize, however, needs no |
| 5730 | +  // predicate: its expansion factors are static, so its input is |
| 5731 | +  // known to cover its output, and any in-bounds index of the |
| 5732 | +  // output therefore always maps to an in-bounds index of the |
| 5733 | +  // input. |
| 5734 | + |
| 5735 | + struct GetReference : AbstractGetReference { |
| 5736 | + GetReference(const TensorIndexer& indexer, const IdModel& id_model) |
| 5737 | + : AbstractGetReference(indexer, id_model) {} |
| 5738 | + |
| 5739 | + Val* getInlinePredicate(TensorView* tv) const override { |
| 5740 | + if (tv->name() != 1) { |
| 5741 | + return nullptr; |
| 5742 | + } |
| 5743 | + |
| 5744 | +    // Without index hoisting and expr simplification, the predicate |
| 5745 | + // should look like: |
| 5746 | + // |
| 5747 | + // (((((((i0 * 32LL) + (i1 + 16LL)) >= 0LL) && |
| 5748 | + // (((i0 * 32LL) + (i1 + 16LL)) < 128LL)) && |
| 5749 | + // (i1 >= 0LL)) && |
| 5750 | + // (i1 < 16LL))) |
| 5751 | + |
| 5752 | + std::vector<Val*> loop_indices = getLoopIndices(tv, indexer_, for_loops_); |
| 5753 | + |
| 5754 | + Val* zero = tv->fusion()->zeroVal(); |
| 5755 | + |
| 5756 | + auto resize = dynamic_cast<Resize*>(tv->axis(1)->definition()); |
| 5757 | + NVF_ERROR(resize != nullptr); |
| 5758 | + |
| 5759 | + auto logical_idx = addExpr( |
| 5760 | + mulExpr(loop_indices.at(0), createInt(i1)), |
| 5761 | + addExpr(loop_indices.at(1), createInt(i1 / 2))); |
| 5762 | + |
| 5763 | + auto resize_idx = loop_indices.at(1); |
| 5764 | + |
| 5765 | + return andExpr( |
| 5766 | + andExpr( |
| 5767 | + andExpr( |
| 5768 | + geExpr(logical_idx, zero), |
| 5769 | + ltExpr(logical_idx, createInt(i0 * i1))), |
| 5770 | + geExpr(resize_idx, zero)), |
| 5771 | + ltExpr(resize_idx, createInt(i1 / 2))); |
| 5772 | + } |
| 5773 | + }; |
| 5774 | + |
| 5775 | + PredicateIndexValidator<GetReference>::validate(&fusion, false); |
| 5776 | + |
| 5777 | + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); |
| 5778 | + auto t0 = at::randn({i0 * i1}, options); |
| 5779 | + std::vector<c10::IValue> inputs{t0}; |
| 5780 | + |
| 5781 | + KernelExecutor ke; |
| 5782 | + ke.compile(&fusion, inputs); |
| 5783 | + auto outputs = ke.run(inputs); |
| 5784 | + |
| 5785 | + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); |
| 5786 | +} |
| 5787 | + |
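For the slice-then-pad pattern (once more a hypothetical sketch), only the pad's input ID iS16{16} is predicated; the slice just shifts the logical index by `kI1 / 2`, and, as explained in the test comment, its static extents make an explicit bound unnecessary:

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the reference predicate of SplitThenSliceAndPad.
int main() {
  constexpr int64_t kI0 = 4, kI1 = 32;
  int64_t active = 0;
  for (int64_t i0 = 0; i0 < kI0; ++i0) {    // loop ID iS14{4}
    for (int64_t i1 = 0; i1 < kI1; ++i1) {  // loop ID iS18{32}
      // The slice offset of 16 shifts the logical index; no separate
      // bound is needed for the slice's input ID iS15{32}.
      const int64_t logical_idx = i0 * kI1 + (i1 + kI1 / 2);
      const bool pred = logical_idx >= 0 && logical_idx < kI0 * kI1 &&
          i1 >= 0 && i1 < kI1 / 2; // bound on the pad's input ID iS16{16}
      if (pred) {
        ++active;
      }
    }
  }
  assert(active == kI0 * (kI1 / 2)); // the sliced half of each row
  return 0;
}
```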
5500 | 5788 | // Repro of issue #3505. The indexing WAR for resize triggered an
5501 | 5789 | // assertion due to loop promotion.
5502 | 5790 | TEST_F(IndexingTest, Issue3505Repro1) {