-
Notifications
You must be signed in to change notification settings - Fork 543
[JAX] Fix mesh resource requirement when no mesh #2307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Fix mesh resource requirement when no mesh #2307
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
|
/te-ci L1 jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This PR removes an unnecessary requirement for a global MeshResource context when no JAX mesh is active in the distributed training setup. The fix modifies get_sharding_map_logic_axis_to_mesh_axis() in transformer_engine/jax/sharding.py to check if a mesh exists before attempting to retrieve the global mesh resource. When no mesh is defined (common in single-GPU scenarios), the function now returns an empty dictionary instead of raising an error. This change simplifies the onboarding experience for new users experimenting with TransformerEngine on a single GPU, as they no longer need to configure distributed training infrastructure that they don't intend to use. The modification maintains backward compatibility - when a mesh is present, the original behavior is preserved.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/sharding.py | 5/5 | Added early return check for mesh existence/emptiness to avoid requiring MeshResource context in single-GPU scenarios |
Confidence score: 5/5
- This PR is safe to merge with minimal risk as it only adds a protective check for edge cases without modifying existing behavior
- The change is well-isolated, adds defensive logic for the no-mesh scenario, and preserves all existing functionality when a mesh is present
- No files require special attention; the single modified file contains a straightforward, well-commented fix that addresses the exact issue described
Sequence Diagram
sequenceDiagram
participant User
participant TE as TransformerEngine
participant ShardingModule as sharding.py
participant JAX
User->>TE: Initialize TE without mesh setup
TE->>ShardingModule: get_sharding_map_logic_axis_to_mesh_axis()
ShardingModule->>JAX: Get physical_mesh from thread_resources
JAX-->>ShardingModule: Returns mesh (None or empty)
alt Mesh is None or empty (FIX)
ShardingModule-->>TE: Return empty dict {}
Note over ShardingModule: No MeshResource context required
else Mesh exists
ShardingModule->>ShardingModule: global_mesh_resource()
ShardingModule->>ShardingModule: Build logical-to-mesh axis mapping
ShardingModule-->>TE: Return axis mapping dict
end
TE-->>User: Continue execution without error
User->>TE: Call with_sharding_constraint_by_logical_axes()
TE->>ShardingModule: Check logical_axis_names
alt No logical axes provided
ShardingModule-->>TE: Return original tensor
else Logical axes provided
ShardingModule->>ShardingModule: Try Flax logical axis rules
alt Flax rules available
ShardingModule-->>TE: Use Flax constraint
else Fallback to TE rules (deprecated)
ShardingModule->>ShardingModule: _generate_pspec()
ShardingModule->>ShardingModule: get_sharding_map_logic_axis_to_mesh_axis()
Note over ShardingModule: Returns empty dict if no mesh
ShardingModule->>ShardingModule: with_sharding_constraint()
ShardingModule-->>TE: Return constrained tensor
end
end
TE-->>User: Return result
1 file reviewed, no comments
Signed-off-by: Jeremy Berchtold <[email protected]>
|
/te-ci L1 jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
This review covers only the changes made in this single-file PR, which modifies transformer_engine/jax/sharding.py to eliminate the requirement for a MeshResource context when no JAX mesh is active. The function get_sharding_map_logic_axis_to_mesh_axis() now returns an empty dictionary in two new cases: (1) when no physical mesh exists or it's empty, and (2) when all mesh axes are configured as manual axes. This change integrates into TE's JAX sharding infrastructure by short-circuiting the logical-axis-to-mesh-axis mapping logic that would previously fail without a global MeshResource context, making single-GPU scenarios simpler and more user-friendly.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/sharding.py | 3/5 | Adds early returns to skip MeshResource requirement when no mesh or all-manual-axes detected |
Confidence score: 3/5
- This PR improves usability for single-GPU cases but contains potential logic errors in the mesh-axis comparison that could cause incorrect behavior in edge cases
- Score lowered due to: (1) line 84 condition may incorrectly match when only partial overlap exists between abstract_mesh.manual_axes and mesh.axis_names rather than exact equality, and (2) the sorted() comparison doesn't account for duplicate axes which could lead to false positives
- Pay close attention to the conditional logic on lines 84-86; verify that sorted list comparison correctly handles all mesh configurations, especially when manual_axes is a subset or superset of mesh.axis_names
1 file reviewed, 2 comments
| abstract_mesh = get_abstract_mesh() | ||
| if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): | ||
| # If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present | ||
| return {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: check will fail when abstract_mesh is defined with some axes but not matching mesh axes - this would incorrectly skip the MeshResource check. Should this condition also verify that all mesh axes are covered by manual axes when abstract_mesh exists, or is the current comparison sufficient for your use case?
transformer_engine/jax/sharding.py
Outdated
| return {} | ||
|
|
||
| abstract_mesh = get_abstract_mesh() | ||
| if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: comparison does not account for duplicate axes in either list - sorted() removes ordering but not duplicates
transformer_engine/jax/sharding.py
Outdated
| return {} | ||
|
|
||
| abstract_mesh = get_abstract_mesh() | ||
| if abstract_mesh is None or sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, why do we need to check for if abstract_mesh is None here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't exactly sure how the abstract mesh relates to the PXLA physical mesh above. If they both are always set via with jax.sharding.Mesh():, then this abstract_mesh is None check is redundant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the get_abstract_mesh() won't return None.
https://github.com/jax-ml/jax/blob/eace2086ab710c2b95b40e926a44212d73cfd7ae/jax/_src/mesh.py#L625
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @phu0ngng! I've removed the abstract_mesh is None check in the latest version
Signed-off-by: Jeremy Berchtold <[email protected]>
|
/te-ci L1 jax |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
Adds early-return logic to avoid requiring MeshResource context when no JAX mesh is active or when all mesh axes are manually managed. This improves single-GPU usability by removing unnecessary complexity.
Key changes:
- Lines 78-81: Return empty dict when no mesh exists or mesh is empty
- Lines 83-86: Return empty dict when all mesh axes are in
abstract_mesh.manual_axes - Both paths bypass the
global_mesh_resource()call that would otherwise raise an error
Issues found:
- Line 84: The
sorted()comparison doesn't deduplicate axes, which could cause unexpected behavior if duplicate axis names exist (though JAX meshes should have unique names)
Confidence Score: 3/5
- This PR is moderately safe to merge with one edge case to verify
- The fix correctly addresses the stated problem of unnecessary MeshResource requirements for single-GPU use cases. However, the axis comparison logic on line 84 has a minor issue with duplicate handling that should be verified or addressed
- transformer_engine/jax/sharding.py - verify the duplicate axis handling in the comparison on line 84
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/sharding.py | 3/5 | Added early returns to handle no-mesh and all-manual-axes cases, avoiding unnecessary MeshResource requirement. One potential issue with duplicate axis handling. |
Sequence Diagram
sequenceDiagram
participant User
participant get_sharding_map_logic_axis_to_mesh_axis
participant PXLA_THREAD_RESOURCES
participant get_abstract_mesh
participant global_mesh_resource
User->>get_sharding_map_logic_axis_to_mesh_axis: Call to get sharding map
get_sharding_map_logic_axis_to_mesh_axis->>PXLA_THREAD_RESOURCES: Get physical_mesh
alt No mesh or empty mesh
PXLA_THREAD_RESOURCES-->>get_sharding_map_logic_axis_to_mesh_axis: mesh is None or empty
get_sharding_map_logic_axis_to_mesh_axis-->>User: Return empty dict {}
else Mesh exists
PXLA_THREAD_RESOURCES-->>get_sharding_map_logic_axis_to_mesh_axis: mesh with axes
get_sharding_map_logic_axis_to_mesh_axis->>get_abstract_mesh: Get abstract mesh
get_abstract_mesh-->>get_sharding_map_logic_axis_to_mesh_axis: abstract_mesh with manual_axes
alt All mesh axes are manual
get_sharding_map_logic_axis_to_mesh_axis->>get_sharding_map_logic_axis_to_mesh_axis: Compare manual_axes with mesh.axis_names
get_sharding_map_logic_axis_to_mesh_axis-->>User: Return empty dict {}
else Need MeshResource mapping
get_sharding_map_logic_axis_to_mesh_axis->>global_mesh_resource: Get global MeshResource
global_mesh_resource-->>get_sharding_map_logic_axis_to_mesh_axis: MeshResource config
get_sharding_map_logic_axis_to_mesh_axis->>get_sharding_map_logic_axis_to_mesh_axis: Build logical-to-mesh mapping
get_sharding_map_logic_axis_to_mesh_axis-->>User: Return mapping dict
end
end
1 file reviewed, 1 comment
| return {} | ||
|
|
||
| abstract_mesh = get_abstract_mesh() | ||
| if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: comparison doesn't handle duplicate axes correctly - sorted(['a', 'a', 'b']) == sorted(['a', 'b']) is False, but sorted(set(['a', 'a', 'b'])) == sorted(set(['a', 'b'])) would be True. Consider whether duplicates in axis names are valid in your use case
| if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): | |
| if sorted(set(abstract_mesh.manual_axes)) == sorted(set(mesh.axis_names)): |
* Fix mesh resource requirement when no mesh Signed-off-by: Jeremy Berchtold <[email protected]> * do not require meshresource if all axes are manual axes Signed-off-by: Jeremy Berchtold <[email protected]> * remove abstract_mesh is None check Signed-off-by: Jeremy Berchtold <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]>
* Fix mesh resource requirement when no mesh Signed-off-by: Jeremy Berchtold <[email protected]> * do not require meshresource if all axes are manual axes Signed-off-by: Jeremy Berchtold <[email protected]> * remove abstract_mesh is None check Signed-off-by: Jeremy Berchtold <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]>
Description
When not using Flax logical axes, a global MeshResource context was required by TE even when no JAX mesh was active. This PR fixes this case and avoids raising an unnecessary error.
Type of change
Changes
Checklist: