[JAX] Fixes for CI failures with the latest JAX (#1469)
* fix L1 test

* fix test_multigpu_encoder

* fixes for other multi-encoder tests

* migrate jax.extend.ffi to jax.ffi

* initialize with float32

* add init_dtype as an optional arg to all modules

* update use_scan query from XLA flags

* relax threshold for test_encoder FP8

* relax the tolerances

---------

Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng authored Feb 14, 2025
1 parent 24e4f95 commit e19b828
Showing 14 changed files with 37 additions and 35 deletions.
6 changes: 3 additions & 3 deletions examples/jax/encoder/test_model_parallel_encoder.py
@@ -239,7 +239,7 @@ def to_device_axis(logical_axis):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
-lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
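
For context: newer JAX releases expect `NamedSharding` to be constructed with a `PartitionSpec` rather than a bare tuple, and `PartitionSpec(None)` maps no mesh axis to any array dimension, i.e. the parameter is fully replicated. A minimal standalone sketch, assuming a 1-D mesh over all local devices (illustrative only, not code from this repo):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

x = jnp.ones((4, 8))
# No mesh axis is mapped to any dimension of x, so it is replicated on every device.
replicated = jax.device_put(x, NamedSharding(mesh, PartitionSpec(None)))
```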
@@ -447,7 +447,7 @@ def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
-assert actual[0] < 0.45 and actual[1] > 0.79
+assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
@@ -462,7 +462,7 @@ def test_te_fp8_sp(self):
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
-assert actual[0] < 0.45 and actual[1] > 0.79
+assert actual[0] < 0.455 and actual[1] > 0.785


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/jax/encoder/test_multigpu_encoder.py
@@ -218,7 +218,7 @@ def to_device_axis(logical_axis):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
-lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_multiprocessing_encoder.py
@@ -320,7 +320,7 @@ def to_device_axis(logical_axis):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
-lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
+lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
@@ -587,7 +587,7 @@ def test_te_bf16(self):
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
result = self.exec(True)
-assert result[0] < 0.45 and result[1] > 0.79
+assert result[0] < 0.455 and result[1] > 0.79


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/jax/encoder/test_single_gpu_encoder.py
@@ -334,7 +334,7 @@ def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
-assert actual[0] < 0.45 and actual[1] > 0.79
+assert actual[0] < 0.455 and actual[1] > 0.79


if __name__ == "__main__":
8 changes: 1 addition & 7 deletions qa/L1_jax_distributed_unittest/test.sh
@@ -6,10 +6,4 @@ set -xe

: ${TE_PATH:=/opt/transformerengine}

-# Skip ring attention tests since they need fixed environment vars
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
-
-# Test ring attention with and without scan loop
-NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
-NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
-pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
20 changes: 15 additions & 5 deletions tests/jax/test_distributed_fused_attn.py
@@ -2,6 +2,8 @@
#
# See LICENSE for license information.

+import os
+import pytest
import jax
import jax.numpy as jnp
import numpy as np
@@ -11,7 +13,7 @@
generate_context_parallel_configs,
generate_collectives_count,
)
-from transformer_engine.jax import fp8_autocast
+from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
@@ -22,10 +24,7 @@
inverse_reorder_causal_load_balancing,
CPStrategy,
)
-from transformer_engine.jax.sharding import MeshResource
-import pytest

-from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat

DTYPES = [jnp.bfloat16]

@@ -355,6 +354,10 @@ def test_context_parallel_allgather_attn(
CPStrategy.ALL_GATHER,
)

+@pytest.mark.parametrize(
+    "use_scan",
+    [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
+)
def test_context_parallel_ring_attn(
self,
device_count,
@@ -367,8 +370,14 @@ def test_context_parallel_ring_attn(
dtype,
qkv_layout,
load_balanced,
+use_scan,
):
-return self.impl_test_context_parallel_attn(
+if use_scan:
+    os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
+else:
+    os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
+
+self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
@@ -381,6 +390,7 @@
load_balanced,
CPStrategy.RING,
)
+del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]


class TestReorderCausalLoadBalancing:
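
Side note on the env-var toggle above: a hedged alternative (an assumption, not what the commit does) is pytest's `monkeypatch` fixture, which scopes `NVTE_FUSED_RING_ATTENTION_USE_SCAN` to the test and restores it even if the test raises:

```python
import pytest

@pytest.mark.parametrize("use_scan", [False, True])
def test_ring_attn_scan_toggle(monkeypatch, use_scan):
    # monkeypatch.setenv is undone automatically when the test finishes.
    monkeypatch.setenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1" if use_scan else "0")
    # ... invoke the ring-attention test body here ...
```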
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/activation.py
@@ -11,7 +11,7 @@
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
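
The same `jax.extend.ffi` -> `jax.ffi` rename is applied in the extension modules below. If compatibility with older JAX releases were also desired (an assumption; the commit simply switches to the new path), a guarded import would be one option:

```python
try:
    from jax import ffi  # newer JAX, where the FFI helpers live under jax.ffi
except ImportError:  # older JAX releases still expose jax.extend.ffi
    from jax.extend import ffi
```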
6 changes: 2 additions & 4 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -15,7 +15,7 @@
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor

@@ -1602,9 +1602,7 @@ def use_scanloop():
def truthy(val):
return val.lower() in ["1", "true"]

-x = use_scan and get_xla_flag(
-    "--xla_experimental_ignore_channel_id", default=False, cast=truthy
-)
+x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy)
return x

def check_supported(self):
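
`get_xla_flag` is Transformer Engine's own helper; a rough sketch of how such a lookup can work by parsing the `XLA_FLAGS` environment variable (hypothetical helper name and behavior, shown only to illustrate the `--xla_ignore_channel_id` query above):

```python
import os

def query_xla_flag(flag, default=None, cast=None):
    """Return the value of `flag` from XLA_FLAGS, else `default` (illustrative)."""
    for token in os.environ.get("XLA_FLAGS", "").split():
        if token == flag:  # bare flag, e.g. "--xla_ignore_channel_id"
            return True
        if token.startswith(flag + "="):
            value = token.split("=", 1)[1]
            return cast(value) if cast else value
    return default

# Mirrors the call site above:
use_scan_ok = query_xla_flag(
    "--xla_ignore_channel_id", default=True, cast=lambda v: v.lower() in ("1", "true")
)
```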
7 changes: 3 additions & 4 deletions transformer_engine/jax/cpp_extensions/custom_call.py
@@ -5,9 +5,8 @@
from dataclasses import dataclass
from enum import IntEnum

+import jax
from jax.interpreters import mlir
-import jax.extend as jex
-
from transformer_engine import transformer_engine_jax

from .misc import is_ffi_enabled
@@ -30,11 +29,11 @@ class CustomCallAPIVersion(IntEnum):
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
-jex.ffi.register_ffi_target(
+jax.ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
-jex.ffi.register_ffi_target(
+jax.ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)

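
Once a target is registered through `jax.ffi.register_ffi_target`, the calling side goes through `jax.ffi.ffi_call`. A hedged sketch with a placeholder target name (not an op registered by this repo):

```python
import jax

def call_my_op(x):
    # ffi_call returns a callable; the output shape/dtype must be declared up front.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.ffi.ffi_call("my_custom_op_ffi", out_type)(x)
```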
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/normalization.py
@@ -13,7 +13,7 @@
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine import transformer_engine_jax

2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/quantization.py
@@ -9,7 +9,7 @@
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/softmax.py
@@ -12,7 +12,7 @@
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine import transformer_engine_jax

2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/transpose.py
@@ -11,7 +11,7 @@
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
-from jax.extend import ffi
+from jax import ffi

from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
7 changes: 4 additions & 3 deletions transformer_engine/jax/flax/transformer.py
@@ -150,8 +150,8 @@ def __call__(
del self.scale_factor

if self.float32_logits:
-query = query.astype(jnp.float32)
-key = key.astype(jnp.float32)
+query = query.astype(self.dtype)
+key = key.astype(self.dtype)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
@@ -989,6 +989,7 @@ def __post_init__(self):
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", self.weight_dtype
)
+self.kernel_init = _kernel_init.astype(self.dtype)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
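
Related to the "add init_dtype as an optional arg to all modules" item in the commit message: the general pattern is to run the initializer in float32 (or another explicit dtype) while computing in a lower-precision dtype. A minimal Flax sketch with illustrative names (`init_dtype` and `SimpleDense` as used here are not TE's API):

```python
from typing import Any

import flax.linen as nn
import jax.numpy as jnp

class SimpleDense(nn.Module):
    features: int
    dtype: Any = jnp.bfloat16       # compute dtype
    init_dtype: Any = jnp.float32   # dtype the initializer produces

    @nn.compact
    def __call__(self, x):
        kernel_init = nn.initializers.variance_scaling(
            1.0, "fan_in", "normal", dtype=self.init_dtype
        )
        kernel = self.param("kernel", kernel_init, (x.shape[-1], self.features))
        # Parameters stay in init_dtype; cast only for the computation.
        return jnp.dot(x.astype(self.dtype), kernel.astype(self.dtype))
```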
@@ -1281,7 +1282,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
f"expected query shape {expected_shape} instead got {query.shape}."
)

-cur_index = cache_index.value
+cur_index = cache_index.value.astype(jnp.int32)
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
key = cached_key.value + key * one_hot_indices
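
The `astype(jnp.int32)` above guards the one-hot index computation during autoregressive decoding. A standalone sketch of that cache-update pattern with illustrative shapes (not the module's actual cache layout):

```python
import jax
import jax.numpy as jnp

def update_kv_cache(cached_key, new_key, cur_index):
    """Scatter new_key into position cur_index of the cache via a one-hot mask.

    cached_key: [batch, length, heads, head_dim]
    new_key:    [batch, 1, heads, head_dim]
    """
    length = cached_key.shape[1]
    cur_index = jnp.asarray(cur_index, jnp.int32)  # one_hot expects integer indices
    one_hot = jax.nn.one_hot(cur_index, length, dtype=new_key.dtype)  # shape [length]
    one_hot = one_hot.reshape((1, length, 1, 1))
    # The mask is zero everywhere except cur_index, so only that (zero-initialized)
    # slot of the cache receives the new key.
    return cached_key + new_key * one_hot
```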
