Skip to content

Commit 2e23ad7

Browse files
authored
[JAX] Add Shardy warning in GEMM custom call (#2101)
* added shardy warning Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]>
1 parent 78e097f commit 2e23ad7

File tree

1 file changed

+7
-0
lines changed
  • transformer_engine/jax/cpp_extensions

1 file changed

+7
-0
lines changed

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Iterable
99
from typing import Tuple, Sequence, Union
1010
from functools import partial, reduce
11+
import warnings
1112

1213
import jax
1314
import jax.numpy as jnp
@@ -658,6 +659,12 @@ def shardy_sharding_rule(
658659

659660
prefix = "GemmPrimitive_"
660661

662+
warnings.warn(
663+
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
664+
" please turn off Shardy by exporting the environment variable"
665+
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
666+
)
667+
661668
def _generate_operand_rules(name, ndim, cdims):
662669
specs = []
663670
ldims = tuple(i for i in range(ndim) if i not in cdims)

0 commit comments

Comments
 (0)