Pm/preseg reorder sharded #4256

Draft · wants to merge 19 commits into main

Conversation

Priya2698
Collaborator

No description provided.

@Priya2698
Collaborator Author

!test


Description

  • Added support for resharding in ReorderShardedAxisPass

  • Introduced getReshardingIdPair to identify resharding between producer and consumer (a usage sketch follows this list)

  • Updated axisIndex to work with any domain vector

  • Added tests for AllgatherLoopSplit and ReduceScatterLoopSplit
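
For orientation, here is a minimal usage sketch of the two new utilities, assembled only from the signatures quoted in the reviewer guide below. The wrapper function, header path, and the final reordering step are hypothetical and are not code from this PR.

#include <multidevice/utils.h>  // assumed location of the new declarations

// Hypothetical caller (illustration only): check whether a producer/consumer
// pair reshards, and locate the resharded axis in each loop domain.
void inspectResharding(TensorView* producer, TensorView* consumer, ValGraph& graph) {
  // getReshardingIdPair (per the quoted signature) returns the producer/consumer
  // IterDomain pair whose device sharding changes, or std::nullopt otherwise.
  std::optional<std::pair<IterDomain*, IterDomain*>> pair =
      getReshardingIdPair(producer, consumer, graph);
  if (!pair.has_value()) {
    return;
  }
  auto [p_id, c_id] = *pair;
  // axisIndex (per the quoted signature) finds an IterDomain's position within
  // any domain vector, here the loop domains of the two tensors.
  int64_t p_index = axisIndex(producer->getLoopDomain(), p_id);
  int64_t c_index = axisIndex(consumer->getLoopDomain(), c_id);
  // ReorderShardedAxisPass would then use these indices to move the sharded
  // axis where the communication lowering expects it (pass-specific logic).
  (void)p_index;
  (void)c_index;
}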


Changes walkthrough 📝

Relevant files

Enhancement

csrc/multidevice/utils.cpp: Add resharding support and utility functions (+57/-13)

  • Added axisIndex function to find axis index in a domain vector
  • Introduced getReshardingIdPair to identify resharding between producer and consumer
  • Updated isInnerResharding to use axisIndex

csrc/preseg_passes/reorder_sharded_axis.cpp: Implement resharding logic in pass (+49/-3)

  • Added IdModel and ValGraph for exact graph building
  • Implemented resharding logic in ReorderShardedAxisPass
  • Added propagateTransform function to propagate transformations

csrc/multidevice/utils.h: Add function declarations for resharding (+14/-0)

  • Added axisIndex declaration
  • Added getReshardingIdPair declaration
  • Added getInputsInTargetDomain declaration

Tests

tests/cpp/test_multidevice_communications.cpp: Add loop split tests for communications (+128/-0)

  • Added AllgatherLoopSplit test case (a minimal sketch of the loop-split sharding pattern follows this walkthrough)
  • Added ReduceScatterLoopSplit test case
  • Added AllreduceLoopSplit test case

tests/cpp/test_multidevice_lower_communication.cpp: Update communication lowering tests (+24/-17)

  • Disabled AllgatherLoopSplit_Noncontig test
  • Added ReorderShardedAxisPass to test setup
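
As referenced above, a hypothetical minimal shape of the new loop-split communication tests. This is an illustration, not one of the actual test bodies from this PR; the fixture and test name are placeholders, and the calls mirror the disabled test quoted further below.

TEST_P(LowerCollectiveTest, AllgatherLoopSplit_Sketch) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const auto d = communicator_->size();
  auto mesh = DeviceMesh::createForNumDevices(d);

  TensorView* tv0 = makeConcreteTensor({d * 3});
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);

  // Outer-split the axis by the mesh size on both tensors.
  for (auto tv : {tv0, tv1}) {
    tv->setDeviceMesh(mesh);
    tv->outer_split(0, d);
  }
  // Input is sharded across devices; output keeps the split but stays serial,
  // so the set lowers to an allgather across the mesh.
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::Serial);
}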

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The function axisIndex does not correctly handle the case where find_id is not found in the domain. It should return -1 in such cases, but the current implementation returns the index of the last non-trivial dimension instead.

    int64_t axisIndex(std::vector<IterDomain*> domain, IterDomain* find_id) {
      int64_t index = 0;
      for (auto* id : domain) {
        if (id == find_id) {
          return index;
        }
        if (!id->isDeviceDim() && !id->isReduction() &&
            !id->isBroadcast()) {
          index++;
        }
      }
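
The quoted snippet is truncated at the end of the loop; one possible fix, sketched here for illustration (not code from the PR), is to return -1 explicitly once the loop finishes without finding find_id:

int64_t axisIndex(std::vector<IterDomain*> domain, IterDomain* find_id) {
  int64_t index = 0;
  for (auto* id : domain) {
    if (id == find_id) {
      return index;
    }
    // Count only dimensions that contribute to the allocation layout.
    if (!id->isDeviceDim() && !id->isReduction() && !id->isBroadcast()) {
      index++;
    }
  }
  // find_id is not part of this domain.
  return -1;
}
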
    Redundant Code

The function getReshardingIdPair checks that neither the producer nor the consumer ID is a reduction or broadcast axis. This check is redundant because it is already ensured in the previous loop.

    std::optional<std::pair<IterDomain*, IterDomain*>> getReshardingIdPair(TensorView* producer, TensorView* consumer, ValGraph& graph) {
      auto p_loop_domain = producer->getLoopDomain();
      auto c_loop_domain = consumer->getLoopDomain();
      auto p2c_map = graph.buildMapBetween(
                p_loop_domain, c_loop_domain);
    
      std::vector<std::pair<IterDomain*, IterDomain*>> resharding_id_pairs;
    
      bool has_sharding_changes = false;
    
      IterDomain* resharded_p_id = nullptr;
      IterDomain* resharded_c_id = nullptr;
    
      for (auto [p_val, c_vals] : p2c_map) {
        auto p_id = p_val->as<IterDomain>();
        auto c_id = c_vals.front()->as<IterDomain>();
    
        if (!p_id->isDeviceDim() && !c_id->isDeviceDim()) {
          continue;
        }
    
        // No reordering for reduction and broadcast axes.
        if (p_id->isReduction() || p_id->isBroadcast()) {
          continue;
        }
        if (c_id->isReduction() || c_id->isBroadcast()) {
          continue;
        }
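
If reduction and broadcast IDs are indeed filtered out before this loop, the guard could shrink to the device-dimension check alone; a hypothetical simplification of the loop body shown above:

for (auto [p_val, c_vals] : p2c_map) {
  auto* p_id = p_val->as<IterDomain>();
  auto* c_id = c_vals.front()->as<IterDomain>();
  // Keep only pairs where at least one side is device-parallel; the
  // reduction/broadcast filtering is assumed to have happened earlier.
  if (!p_id->isDeviceDim() && !c_id->isDeviceDim()) {
    continue;
  }
  // ... the rest of the resharding detection stays as in the PR.
  resharded_p_id = p_id;
  resharded_c_id = c_id;
}
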
Disabled Test

The test AllgatherLoopSplit_Noncontig is disabled. It should be enabled and verified to ensure that the new functionality works as expected.

    TEST_P(LowerCollectiveTest, DISABLED_AllgatherLoopSplit_Noncontig) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      // ProcessGroupNCCL requires the gathered axis to be outermost.
      // We change the allocation of tensorviews to reflect this.
      // We do not modify the logical shape of the tensorview.
      // This would still require one copy on each device if the input tensor is in
      // a different layout.
      const auto d = communicator_->size();
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      TensorView* tv0 = makeConcreteTensor({5, d * 3});
      tv0->outer_split(1, d);
      tv0->axis(1)->parallelize(ParallelType::DIDx);
      // tv0->reorder({{1, 0}, {2, 1}, {0, 2}});
      // tv0: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5]
    
      TensorView* tv1 = set(tv0);
      tv1->outer_split(1, d);
      tv1->axis(1)->parallelize(ParallelType::Serial);
      // tv1->reorder({{1, 0}, {2, 1}, {0, 2}});
      // tv1: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5]
    
      for (auto tv : {tv0, tv1}) {
        tv->setDeviceMesh(mesh);
        // tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      preseg_passes::OptimizationPass<preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get());
      for (auto tv : fusion->allTvs()) {
        debug() << tv->toString() << std::endl;
        debug() << tv->getMaybeAllocationDomain() << std::endl;
      }
    
      // at::Tensor unsharded_in_tensor = at::randn({5, d * 3}, tensor_options);
      // at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 1, mesh);
    
      // FusionExecutorCache executor_cache(std::move(fusion));
      // at::Tensor out_tensor =
      //     executor_cache.runFusionWithInputs({in_tensor})[0].as<at::Tensor>();
    
      // testValidate(
      //     executor_cache.fusion(),
      //     {out_tensor},
      //     {in_tensor},
      //     {unsharded_in_tensor.transpose(0, 1)},
      //     __LINE__,
      //     __FILE__);
    }
