InsertReshardingsPass decomposes matmul/linear+ReduceScatter. #4239


Merged: wujingyue merged 8 commits into main from wjy/rs on Apr 16, 2025

Conversation
wujingyue (Collaborator) commented Apr 11, 2025

As a follow-up to #4209, handle ReduceScatter as well.

Fixes #4133
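
For context, here is a conceptual PyTorch sketch (not the pass's actual output, and with illustrative 2D shapes) of the computation being decomposed: a linear whose inner dimension k is sharded across ranks yields partial sums, and a reduce-scatter both sums the partials and re-shards the result. An initialized process group is assumed.

    import torch
    import torch.distributed as dist

    def linear_then_reduce_scatter(inp_shard, weight_shard):
        # Each rank holds a k-slice of the input and weight, so the local
        # matmul yields a partial sum of the full [s, e] result.
        partial = inp_shard @ weight_shard.t()
        out = torch.empty(
            partial.shape[0] // dist.get_world_size(),
            partial.shape[1],
            dtype=partial.dtype,
            device=partial.device,
        )
        # reduce_scatter sums the partials across ranks and scatters the
        # rows of the summed result, one contiguous chunk per rank.
        dist.reduce_scatter_tensor(out, partial)
        return out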

github-actions bot commented Apr 11, 2025

Review updated until commit 0036c5e

Description

  • Handle ReduceScatter in InsertReshardingsPass

  • Add test for linear with ReduceScatter

  • Reset FusionCache before each multi-device test (see the fixture sketch after this list)

  • Update setup_process_group to use Communicator
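
A minimal sketch of the per-test reset, assuming nvfuser.FusionCache.reset() as the reset entry point; the fixture body around it, including the _make_multidevice_fixture helper, is hypothetical:

    import pytest
    from nvfuser import FusionCache

    @pytest.fixture
    def multidevice_test():
        # Reset the cache before every test so fusions whose definitions
        # match but whose multidevice schedules differ cannot collide.
        FusionCache.reset()
        yield _make_multidevice_fixture()  # hypothetical helper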


Changes walkthrough 📝

Relevant files

Enhancement

csrc/multidevice/utils.cpp: Skip reduction IterDomains in device parallel type mapping
  • Skip reduction IterDomains in mapDeviceParallelTypeToId
  • Remove redundant reduction check in haveDifferentShardings
  +7/-10

csrc/preseg_passes/insert_reshardings.cpp: Enhance rFactorLoopSplits to handle reduced parallel types
  • Add handling for reduced parallel types in rFactorLoopSplits
  • Unparallelize the first iDIDx{d} after rFactor
  +37/-2

Tests

tests/python/multidevice/fixtures.py: Reset FusionCache before each multi-device test
  • Reset FusionCache in the multidevice_test fixture
  +14/-1

tests/python/multidevice/test_dtensor.py: Update setup_process_group and test function signatures
  • Update setup_process_group to use Communicator
  • Pass multidevice_test to test functions
  +9/-5

tests/python/multidevice/test_matmul.py: Add test for linear with ReduceScatter
  • Add test for linear with ReduceScatter
  +46/-0

tests/python/multidevice/test_overlap.py: Remove redundant FusionCache reset
  • Remove redundant FusionCache reset
  +0/-3

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The change in mapDeviceParallelTypeToId might skip reduction domains that still need to be considered for resharding.

    // rDIDx{i0}, usually a product of an Allreduce or a ReduceScatter, is
    // treated as replicated. This way `iDIDx{i0} => rDIDx{i0}` is considered
    // resharding.
    if (id->isReduction()) {
      continue;
    }
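
To illustrate why the skip matters, here is a hedged sketch of the allreduce counterpart from #4209, modeled on test_linear_reduce_scatter later in this page and reusing its d and mesh setup; the exact schedule is an assumption. Sharding only the inputs along k leaves the output reducing the sharded axis, so the quoted check classifies iDIDx{d} => rDIDx{d} as resharding and the pass inserts the communication:

    class RowParallelModel(FusionDefinition):
        def definition(self):
            self.inp = self.define_tensor([-1, -1, d * e])
            self.weight = self.define_tensor([e, d * e])
            self.out = self.ops.linear(self.inp, self.weight, None)
            self.add_output(self.out)

        def multidevice_schedule(self):
            # Shard only the inputs along k. The unsharded output then
            # reduces the sharded axis, i.e. iDIDx{d} => rDIDx{d}.
            for t in [self.inp, self.weight]:
                self.sched._set_device_mesh(t, mesh)
                self.sched.split(t, -1, d, False)
                self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x)
                self.sched.set_allocation_as_loop(t)
            self.sched._set_device_mesh(self.out, mesh)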
Code Clarity

The logic for handling reduction parallel types in rFactorLoopSplits could be more clearly documented or refactored for readability.

    for (auto&& [i, loop_id] : enumerate(tv->getLoopDomain())) {
      if (!loop_id->isReduction()) {
        // rFactor only applies to reduction dimensions.
        continue;
      }
    
      if (std::count(
              tv->getLogicalDomain().begin(),
              tv->getLogicalDomain().end(),
              loop_id) > 0) {
        // No need to rFactor if loop_id is in the logical domain.
        continue;
      }
    
      const ParallelType parallel_type = loop_id->getParallelType();
      if (parallel_type == ParallelType::Serial) {
        // rFactor non-parallelized IDs so they get reduced locally.
        rfactor_axes.push_back(i);
      } else {
        reduced_parallel_types.insert(parallel_type);
      }
    }
    
    if (!rfactor_axes.empty()) {
      TensorView* local = tv->rFactor(rfactor_axes);
      // Before rFactor:
      //
      // [i{m}         i{n}         r{k}]
      //               /  \         /   \.
      //         iDIDx{d} i{n/d} rDIDx{d}  r{k/d}
      //
      // After rFactor:
      //
      //                            r{k}
      //                            /  \.
      // [i{m}         i{n}    iDIDx{d}  r{k/d}]
      //               /  \.
      //         iDIDx{d} i{n/d}
      //
      //                 |
      //                 | reduce
      //                 v
      //
      // [i{m}         i{n}    rDIDx{d}]
      //               /  \.
      //         iDIDx{d} i{n/d}
      //
      // The TensorView returned by rFactor has two iDIDx, which is disallowed.
      // The following code unparallelizes the first iDIDx{d}.
      for (IterDomain* loop_id : local->getLoopDomain()) {
        if (!loop_id->isRFactorProduct() &&
            reduced_parallel_types.count(loop_id->getParallelType())) {
          loop_id->parallelize(ParallelType::Serial);
        }
      }
Test Coverage

The new test test_linear_reduce_scatter should be evaluated for its coverage and performance impact compared to existing tests.

    def test_linear_reduce_scatter(multidevice_test):
        d = multidevice_test.size
        mesh = nvfuser.DeviceMesh(range(d))
        e = 768
    
        class Model(FusionDefinition):
            def definition(self):
                self.inp = self.define_tensor([-1, -1, d * e])
                self.weight = self.define_tensor([e, d * e])
                self.out = self.ops.linear(self.inp, self.weight, None)
                self.add_output(self.out)
    
            def multidevice_schedule(self):
                for t in [self.inp, self.weight, self.out]:
                    self.sched._set_device_mesh(t, mesh)
                    self.sched.split(t, -1, d, False)
                    self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x)
                    self.sched.set_allocation_as_loop(t)
    
                # Scatter
                self.sched.split(self.out, 1, d, False)
                self.sched.parallelize(self.out, 1, nvfuser.ParallelType.mesh_x)
    
        torch.cuda.set_device(multidevice_test.local_rank)
    
        b, s = 2, 1024
        unsharded_inp = torch.randn(b, s, d * e)
        unsharded_weight = torch.randn(e, d * e)
    
        inp = multidevice_test.shard_tensor(unsharded_inp, -1, mesh)
        weight = multidevice_test.shard_tensor(unsharded_weight, -1, mesh)
    
        fd = Model()
        (out,), _ = fd.execute([inp, weight])
    
        unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None)
        # rtol is the same as the default for fp32. atol is slightly increased.
        torch.testing.assert_close(
            out,
            multidevice_test.shard_tensor(unsharded_out, 1, mesh),
            rtol=1.3e-6,
            atol=1e-3,
        )
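
As a usage note, multi-device tests in this repo appear to be launched under MPI; the exact invocation below is an assumption:

    mpirun -np 2 pytest tests/python/multidevice/test_matmul.py -k test_linear_reduce_scatter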
    

wujingyue (Collaborator, Author) commented on the diff:

    // rDIDx{i0}, usually a product of an Allreduce or a ReduceScatter, is
    // treated as replicated. This way `iDIDx{i0} => rDIDx{i0}` is considered
    // resharding.
    if (id->isReduction()) {

This is moved from Line 529 so getShardedLogicalAxis also enjoys the fix.

@wujingyue requested a review from Priya2698 on April 11, 2025 06:06

wujingyue commented:

!test

wujingyue commented:

    The failing test is due to a bug in FusionCache, which I'll sort out in a separate PR. test_row_parallel_linear and test_linear_reduce_scatter have the same FusionDefinition.definition but different FusionDefinition.multidevice_schedule. They are incorrectly mapped to the same FusionSchedule and therefore the same FusionExecutorCache.

    I believe this is the same error as https://github.com/NVIDIA/Fuser/blob/main/tests/python/multidevice/test_overlap.py#L91

    cc @rdspring1
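
A self-contained sketch of the collision this comment describes (illustrative Python, not nvFuser's cache code): if the cache key captures only the definition, the second fusion silently reuses the first fusion's executor.

    class FakeFusionCache:
        """Illustrative stand-in for a cache keyed only on definition."""

        def __init__(self):
            self._entries = {}

        def get_or_compile(self, definition_key, compile_fn):
            # Bug being described: multidevice_schedule is not part of
            # the key, so schedules that differ still share one entry.
            if definition_key not in self._entries:
                self._entries[definition_key] = compile_fn()
            return self._entries[definition_key]

    cache = FakeFusionCache()
    a = cache.get_or_compile("linear(inp, weight)", lambda: ["row-parallel"])
    b = cache.get_or_compile("linear(inp, weight)", lambda: ["reduce-scatter"])
    assert a is b  # the second test incorrectly reuses the first executor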

Priya2698 approved these changes on Apr 15, 2025

wujingyue commented:

!test

wujingyue commented:

!test

    The former is per module and the latter is per function.
wujingyue commented:

!test

wujingyue merged commit 85a9463 into main on Apr 16, 2025; 46 of 47 checks passed.
wujingyue deleted the wjy/rs branch on April 16, 2025 20:04.
Successfully merging this pull request may close these issues:

  • InsertReshardingsPass decomposes matmul/linear+allreduce.