Fix dataclass init and small bugs in the unit test code.
Signed-off-by: Michael Goldfarb <[email protected]>
mgoldfarb-nvidia committed Jan 13, 2025
1 parent f14924c commit ffc8268
Showing 2 changed files with 11 additions and 8 deletions.
tests/jax/test_distributed_fused_attn.py (8 changes: 4 additions & 4 deletions)
@@ -62,7 +62,7 @@ def generate_collectives_count_ref(
     return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)
 
 @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
+@pytest.mark.parametrize("data_shape", [[32, 512, 12, 64], [32, 1024, 16, 128]])
 @pytest.mark.parametrize(
     "attn_bias_type",
     [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
@@ -85,7 +85,7 @@ def test_self_attn(
     dropout_prob = 0.0
     is_training = True
 
-    batch, seqlen, seqlen, num_head, hidden = data_shape
+    batch, seqlen, num_head, hidden = data_shape
 
     if not is_fused_attn_kernel_available(
         dtype,
@@ -115,7 +115,7 @@ def test_self_attn(
         dropout_prob,
         dtype,
         is_training,
-        qkv_layout,
+        QKVLayout.BS3HD,
         None,
         None,
         number_of_devices=device_count,
@@ -147,7 +147,7 @@ def test_cross_attn(
     dropout_prob = 0.0
     is_training = True
 
-    _, seqlen, num_head, hidden = data_shape
+    batch, seqlen, num_head, hidden = data_shape
 
     if not is_fused_attn_kernel_available(
         dtype,
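Note on the data_shape fixes above: the old 5-element shape baked the packed QKV axis into the parametrization (and the duplicated seqlen in the unpacking silently rebound it to 3), while the tests now parametrize the unpacked [batch, seqlen, num_head, hidden] shape and pass the packed layout explicitly as QKVLayout.BS3HD. A minimal sketch of how the two shapes relate, assuming the BS3HD layout stacks Q, K, and V along a third axis as the name suggests; the values and variable names are illustrative, not part of the tests:

```python
import jax
import jax.numpy as jnp

batch, seqlen, num_head, hidden = 32, 512, 12, 64  # first parametrized case
qkv_shape = (batch, seqlen, 3, num_head, hidden)   # packed BS3HD tensor

qkv = jax.random.normal(jax.random.PRNGKey(0), qkv_shape, dtype=jnp.bfloat16)
q, k, v = jnp.split(qkv, 3, axis=2)                # split the packed axis
q = q.squeeze(2)                                   # unpacked [b, s, h, d] view
assert q.shape == (batch, seqlen, num_head, hidden)
```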
tests/jax/test_fused_attn.py (11 changes: 7 additions & 4 deletions)
@@ -3,7 +3,7 @@
 # See LICENSE for license information.
 """Tests for fused attention"""
 from enum import Enum
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from math import sqrt
 from typing import Tuple, Optional
@@ -314,7 +314,7 @@ class FusedAttnRunner:
     number_of_devices: int = 1
     mesh_shape: tuple[int, ...] = (1, 1, 1)
     mesh_axes: tuple[str, ...] = ("dp", "cp", "tp")
-    mesh_resource: MeshResource = MeshResource("dp", "cp", "tp")
+    mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp"))
 
     # Context parallel aux arguments
     cp_strategy: CPStrategy = CPStrategy.DEFAULT
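The field(default_factory=...) change above is the headline dataclass-init fix: a bare MeshResource("dp", "cp", "tp") default is evaluated once at class-definition time, and CPython 3.11+ rejects unhashable instances as dataclass field defaults outright. A minimal sketch of the pitfall and the fix, using a stand-in Resource class rather than the real MeshResource:

```python
from dataclasses import dataclass, field
from functools import partial

@dataclass
class Resource:  # stand-in for MeshResource
    dp: str
    cp: str
    tp: str

# Writing `res: Resource = Resource("dp", "cp", "tp")` as a field default raises
# "ValueError: mutable default ... is not allowed: use default_factory" on
# Python 3.11+, because a regular dataclass instance is unhashable.

@dataclass
class Runner:  # stand-in for FusedAttnRunner
    # default_factory runs per __init__, so every Runner gets a fresh instance
    res: Resource = field(default_factory=partial(Resource, "dp", "cp", "tp"))

a, b = Runner(), Runner()
assert a.res == b.res and a.res is not b.res
```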
@@ -377,12 +377,15 @@ def _check_configs(self):
         )
 
     def _setup_inputs(self):
-        self._check_configs()
-
         # Create a mesh for distributed tests
         self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape)
         self.mesh = Mesh(self.devices, self.mesh_axes)
-        self.dp_size, self.cp_size, self.tp_size = self.mesh_shape
+        self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1)
+        self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
+        self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)
+
+        self._check_configs()
         key = jax.random.PRNGKey(0)
         q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
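The _setup_inputs change above does two things: _check_configs() now runs after the mesh is built, so the checks can consult the actual mesh, and the dp/cp/tp sizes are looked up by the axis names recorded in mesh_resource rather than unpacked positionally from mesh_shape, which only worked for a fixed three-axis (dp, cp, tp) ordering. jax.sharding.Mesh.shape is a mapping from axis name to size, so .get(name, 1) degrades gracefully when an axis is absent. A small sketch of the lookup pattern, runnable on a single-device host; the axis names are illustrative:

```python
import numpy as np
import jax
from jax.sharding import Mesh

devices = np.asarray(jax.devices()[:1]).reshape(1)
mesh = Mesh(devices, ("dp",))      # a mesh with only a data-parallel axis

dp_size = mesh.shape.get("dp", 1)  # 1 on a single-device host
cp_size = mesh.shape.get("cp", 1)  # axis not in the mesh -> falls back to 1
tp_size = mesh.shape.get("tp", 1)
print(dp_size, cp_size, tp_size)   # 1 1 1
```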
