Skip to content

Commit b57887e

Browse files
jberchtold-nvidiapggPL
authored andcommitted
[JAX] Fix mesh resource requirement when no mesh (NVIDIA#2307)
* Fix mesh resource requirement when no mesh Signed-off-by: Jeremy Berchtold <[email protected]> * do not require meshresource if all axes are manual axes Signed-off-by: Jeremy Berchtold <[email protected]> * remove abstract_mesh is None check Signed-off-by: Jeremy Berchtold <[email protected]> --------- Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent b3cef3a commit b57887e

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

transformer_engine/jax/sharding.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def get_sharding_map_logic_axis_to_mesh_axis():
7575
"""
7676
Generate a dict to map logical axes to mesh axes.
7777
"""
78+
mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh
79+
if mesh is None or mesh.empty:
80+
# If no mesh is defined, return an empty dict and do not require a MeshResource context to be present
81+
return {}
82+
83+
abstract_mesh = get_abstract_mesh()
84+
if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names):
85+
# If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present
86+
return {}
87+
7888
gsr = global_mesh_resource()
7989

8090
is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1

0 commit comments

Comments
 (0)