We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 78e097f commit 2e23ad7Copy full SHA for 2e23ad7
transformer_engine/jax/cpp_extensions/gemm.py
@@ -8,6 +8,7 @@
8
from collections.abc import Iterable
9
from typing import Tuple, Sequence, Union
10
from functools import partial, reduce
11
+import warnings
12
13
import jax
14
import jax.numpy as jnp
@@ -658,6 +659,12 @@ def shardy_sharding_rule(
658
659
660
prefix = "GemmPrimitive_"
661
662
+ warnings.warn(
663
+ "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
664
+ " please turn off Shardy by exporting the environment variable"
665
+ " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
666
+ )
667
+
668
def _generate_operand_rules(name, ndim, cdims):
669
specs = []
670
ldims = tuple(i for i in range(ndim) if i not in cdims)
0 commit comments