@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Oct 27, 2025

Description

When not using Flax logical axes, a global MeshResource context was required by TE even when no JAX mesh was active. This PR fixes this case and avoids raising an unnecessary error.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update TE's logical axis system to return an empty mapping instead of raising an error when no JAX mesh is active. Use case: a new user trying out TE on a single GPU. This avoids unnecessary complexity in cases where it isn't required.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR removes an unnecessary requirement for a global MeshResource context when no JAX mesh is active in the distributed training setup. The fix modifies get_sharding_map_logic_axis_to_mesh_axis() in transformer_engine/jax/sharding.py to check if a mesh exists before attempting to retrieve the global mesh resource. When no mesh is defined (common in single-GPU scenarios), the function now returns an empty dictionary instead of raising an error. This change simplifies the onboarding experience for new users experimenting with TransformerEngine on a single GPU, as they no longer need to configure distributed training infrastructure that they don't intend to use. The modification maintains backward compatibility - when a mesh is present, the original behavior is preserved.
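The before/after behavior described in this summary can be sketched in plain Python. This is a minimal illustration under stated assumptions, not TE's actual code: `FakeMesh`, `require_mesh_resource`, `sharding_map_before`, and `sharding_map_after` are hypothetical stand-ins for the mesh object, the MeshResource lookup, and the two versions of the function.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeMesh:
    """Hypothetical stand-in for a JAX mesh; only the fields used below."""

    axis_names: tuple = ()

    @property
    def empty(self) -> bool:
        return len(self.axis_names) == 0


def require_mesh_resource() -> dict:
    """Stand-in for a lookup that fails when no MeshResource context is set."""
    raise RuntimeError("no global MeshResource context")


def sharding_map_before(mesh: Optional[FakeMesh]) -> dict:
    # Old behavior as described: the resource lookup runs unconditionally,
    # so it fails even when there is no mesh to shard over.
    return require_mesh_resource()


def sharding_map_after(mesh: Optional[FakeMesh]) -> dict:
    # New behavior as described: with no mesh there is nothing to map,
    # so return an empty mapping before touching the resource lookup.
    if mesh is None or mesh.empty:
        return {}
    return require_mesh_resource()
```

In this sketch, `sharding_map_after(None)` and `sharding_map_after(FakeMesh())` both return `{}`, while `sharding_map_before(None)` raises. When a non-empty mesh is passed, both versions still reach the resource lookup, which mirrors the backward-compatibility claim above.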

Important Files Changed

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/jax/sharding.py | 5/5 | Added early return check for mesh existence/emptiness to avoid requiring MeshResource context in single-GPU scenarios |

Confidence score: 5/5

  • This PR is safe to merge with minimal risk as it only adds a protective check for edge cases without modifying existing behavior
  • The change is well-isolated, adds defensive logic for the no-mesh scenario, and preserves all existing functionality when a mesh is present
  • No files require special attention; the single modified file contains a straightforward, well-commented fix that addresses the exact issue described

Sequence Diagram

sequenceDiagram
    participant User
    participant TE as TransformerEngine
    participant ShardingModule as sharding.py
    participant JAX
    
    User->>TE: Initialize TE without mesh setup
    TE->>ShardingModule: get_sharding_map_logic_axis_to_mesh_axis()
    ShardingModule->>JAX: Get physical_mesh from thread_resources
    JAX-->>ShardingModule: Returns mesh (None or empty)
    
    alt Mesh is None or empty (FIX)
        ShardingModule-->>TE: Return empty dict {}
        Note over ShardingModule: No MeshResource context required
    else Mesh exists
        ShardingModule->>ShardingModule: global_mesh_resource()
        ShardingModule->>ShardingModule: Build logical-to-mesh axis mapping
        ShardingModule-->>TE: Return axis mapping dict
    end
    
    TE-->>User: Continue execution without error
    
    User->>TE: Call with_sharding_constraint_by_logical_axes()
    TE->>ShardingModule: Check logical_axis_names
    
    alt No logical axes provided
        ShardingModule-->>TE: Return original tensor
    else Logical axes provided
        ShardingModule->>ShardingModule: Try Flax logical axis rules
        alt Flax rules available
            ShardingModule-->>TE: Use Flax constraint
        else Fallback to TE rules (deprecated)
            ShardingModule->>ShardingModule: _generate_pspec()
            ShardingModule->>ShardingModule: get_sharding_map_logic_axis_to_mesh_axis()
            Note over ShardingModule: Returns empty dict if no mesh
            ShardingModule->>ShardingModule: with_sharding_constraint()
            ShardingModule-->>TE: Return constrained tensor
        end
    end
    
    TE-->>User: Return result
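The fallback branch in the diagram relies on an empty mapping being harmless downstream. A rough sketch of why, assuming the fallback resolves each logical axis name through the mapping and treats a missing entry as "not sharded": `generate_pspec` and the logical axis names below are hypothetical, and a plain tuple stands in for a JAX partition spec.

```python
from typing import Dict, Optional, Sequence, Tuple


def generate_pspec(
    logical_axis_names: Sequence[Optional[str]],
    logical_to_mesh: Dict[str, Optional[str]],
) -> Tuple[Optional[str], ...]:
    """Map each logical axis to a mesh axis; None means 'not sharded'."""
    return tuple(
        logical_to_mesh.get(name) if name is not None else None
        for name in logical_axis_names
    )


# With an empty mapping every dimension resolves to None (unsharded).
unsharded = generate_pspec(("batch", "hidden"), {})

# With a populated mapping only the listed logical axes get a mesh axis.
sharded = generate_pspec(("batch", "hidden"), {"batch": "dp"})
```

Under that assumption, the empty dict produces a spec with no sharded dimensions, which is the expected outcome on a single GPU.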

1 file reviewed, no comments

@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made in this single-file PR, which modifies transformer_engine/jax/sharding.py to eliminate the requirement for a MeshResource context when no JAX mesh is active. The function get_sharding_map_logic_axis_to_mesh_axis() now returns an empty dictionary in two new cases: (1) when no physical mesh exists or it's empty, and (2) when all mesh axes are configured as manual axes. This change integrates into TE's JAX sharding infrastructure by short-circuiting the logical-axis-to-mesh-axis mapping logic that would previously fail without a global MeshResource context, making single-GPU scenarios simpler and more user-friendly.

Important Files Changed

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/jax/sharding.py | 3/5 | Adds early returns to skip MeshResource requirement when no mesh or all-manual-axes detected |

Confidence score: 3/5

  • This PR improves usability for single-GPU cases but contains potential logic errors in the mesh-axis comparison that could cause incorrect behavior in edge cases
  • Score lowered due to: (1) line 84 condition may incorrectly match when only partial overlap exists between abstract_mesh.manual_axes and mesh.axis_names rather than exact equality, and (2) the sorted() comparison doesn't account for duplicate axes which could lead to false positives
  • Pay close attention to the conditional logic on lines 84-86; verify that sorted list comparison correctly handles all mesh configurations, especially when manual_axes is a subset or superset of mesh.axis_names
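The subset and superset cases raised above can be checked directly in plain Python. The sketch below mirrors the shape of the PR's comparison; `all_axes_manual` and the axis names `dp`, `tp`, and `pp` are hypothetical and chosen only for illustration.

```python
def all_axes_manual(manual_axes, mesh_axis_names) -> bool:
    """Same shape as the PR's check: compare the two collections after sorting."""
    return sorted(manual_axes) == sorted(mesh_axis_names)


mesh_axes = ("dp", "tp")

same_names = all_axes_manual(("tp", "dp"), mesh_axes)      # order differs only
subset = all_axes_manual(("dp",), mesh_axes)               # partial overlap
superset = all_axes_manual(("dp", "tp", "pp"), mesh_axes)  # extra manual axis

print(same_names, subset, superset)  # True False False
```

Sorted lists compare equal only when they contain exactly the same elements with the same multiplicities, so a partial overlap does not satisfy the condition. Duplicate names are the only case where the list comparison and a set comparison could disagree.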

1 file reviewed, 2 comments

Comment on lines 83 to 86
abstract_mesh = get_abstract_mesh()
if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
    # If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present
    return {}
Contributor

logic: check will fail when abstract_mesh is defined with some axes but not matching mesh axes - this would incorrectly skip the MeshResource check. Should this condition also verify that all mesh axes are covered by manual axes when abstract_mesh exists, or is the current comparison sufficient for your use case?

return {}

abstract_mesh = get_abstract_mesh()
if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
Contributor

logic: comparison does not account for duplicate axes in either list - sorted() removes ordering but not duplicates

return {}

abstract_mesh = get_abstract_mesh()
if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
Collaborator

Hi, why do we need to check whether abstract_mesh is None here?

Collaborator Author

I wasn't exactly sure how the abstract mesh relates to the PXLA physical mesh above. If both are always set via `with jax.sharding.Mesh():`, then this `abstract_mesh is None` check is redundant.

Collaborator

Collaborator Author

Thanks @phu0ngng! I've removed the abstract_mesh is None check in the latest version

@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Adds early-return logic to avoid requiring MeshResource context when no JAX mesh is active or when all mesh axes are manually managed. This improves single-GPU usability by removing unnecessary complexity.

Key changes:

  • Lines 78-81: Return empty dict when no mesh exists or mesh is empty
  • Lines 83-86: Return empty dict when all mesh axes are in abstract_mesh.manual_axes
  • Both paths bypass the global_mesh_resource() call that would otherwise raise an error

Issues found:

  • Line 84: The sorted() comparison doesn't deduplicate axes, which could cause unexpected behavior if duplicate axis names exist (though JAX meshes should have unique names)

Confidence Score: 3/5

  • This PR is moderately safe to merge with one edge case to verify
  • The fix correctly addresses the stated problem of unnecessary MeshResource requirements for single-GPU use cases. However, the axis comparison logic on line 84 has a minor issue with duplicate handling that should be verified or addressed
  • transformer_engine/jax/sharding.py - verify the duplicate axis handling in the comparison on line 84

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| transformer_engine/jax/sharding.py | 3/5 | Added early returns to handle no-mesh and all-manual-axes cases, avoiding unnecessary MeshResource requirement. One potential issue with duplicate axis handling. |

Sequence Diagram

sequenceDiagram
    participant User
    participant get_sharding_map_logic_axis_to_mesh_axis
    participant PXLA_THREAD_RESOURCES
    participant get_abstract_mesh
    participant global_mesh_resource
    
    User->>get_sharding_map_logic_axis_to_mesh_axis: Call to get sharding map
    get_sharding_map_logic_axis_to_mesh_axis->>PXLA_THREAD_RESOURCES: Get physical_mesh
    
    alt No mesh or empty mesh
        PXLA_THREAD_RESOURCES-->>get_sharding_map_logic_axis_to_mesh_axis: mesh is None or empty
        get_sharding_map_logic_axis_to_mesh_axis-->>User: Return empty dict {}
    else Mesh exists
        PXLA_THREAD_RESOURCES-->>get_sharding_map_logic_axis_to_mesh_axis: mesh with axes
        get_sharding_map_logic_axis_to_mesh_axis->>get_abstract_mesh: Get abstract mesh
        get_abstract_mesh-->>get_sharding_map_logic_axis_to_mesh_axis: abstract_mesh with manual_axes
        
        alt All mesh axes are manual
            get_sharding_map_logic_axis_to_mesh_axis->>get_sharding_map_logic_axis_to_mesh_axis: Compare manual_axes with mesh.axis_names
            get_sharding_map_logic_axis_to_mesh_axis-->>User: Return empty dict {}
        else Need MeshResource mapping
            get_sharding_map_logic_axis_to_mesh_axis->>global_mesh_resource: Get global MeshResource
            global_mesh_resource-->>get_sharding_map_logic_axis_to_mesh_axis: MeshResource config
            get_sharding_map_logic_axis_to_mesh_axis->>get_sharding_map_logic_axis_to_mesh_axis: Build logical-to-mesh mapping
            get_sharding_map_logic_axis_to_mesh_axis-->>User: Return mapping dict
        end
    end

1 file reviewed, 1 comment

return {}

abstract_mesh = get_abstract_mesh()
if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
Contributor

logic: comparison doesn't handle duplicate axes correctly - sorted(['a', 'a', 'b']) == sorted(['a', 'b']) is False, but sorted(set(['a', 'a', 'b'])) == sorted(set(['a', 'b'])) would be True. Consider whether duplicates in axis names are valid in your use case

Suggested change
- if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
+ if sorted(set(abstract_mesh.manual_axes)) == sorted(set(mesh.axis_names)):
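The reviewer's example can be reproduced as a short standalone check. The lists below are the literal values from the comment, not real mesh axis names, and the direct set comparison on the last line is only an equivalent spelling of the suggested change, offered as a sketch.

```python
with_dup = ["a", "a", "b"]
without_dup = ["a", "b"]

# Sorting keeps duplicates, so the lists differ in length and compare unequal.
list_equal = sorted(with_dup) == sorted(without_dup)

# Converting to sets first drops duplicates, so the comparison succeeds.
set_equal = sorted(set(with_dup)) == sorted(set(without_dup))

# Comparing the sets directly gives the same answer without sorting.
direct_set_equal = set(with_dup) == set(without_dup)

print(list_equal, set_equal, direct_set_equal)  # False True True
```

Whether the difference matters depends on whether duplicate axis names can occur at all. The review itself notes that JAX meshes should have unique names, in which case both forms behave identically.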

@jberchtold-nvidia jberchtold-nvidia merged commit 006670d into NVIDIA:main Oct 31, 2025
24 of 25 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/fix-meshresource-requirement-when-no-mesh branch October 31, 2025 15:36
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request Nov 4, 2025
* Fix mesh resource requirement when no mesh

Signed-off-by: Jeremy Berchtold <[email protected]>

* do not require meshresource if all axes are manual axes

Signed-off-by: Jeremy Berchtold <[email protected]>

* remove abstract_mesh is None check

Signed-off-by: Jeremy Berchtold <[email protected]>

---------

Signed-off-by: Jeremy Berchtold <[email protected]>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request Nov 6, 2025