
Commit 320d282

change sharding based on cross/self attention.

1 parent 8202837 · commit 320d282

2 files changed: 28 additions, 5 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 25 additions & 4 deletions
@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""

@@ -202,8 +203,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategy for self-attn vs cross-attn
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: Context Parallelism (sharding along num_heads)
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: Sequence Parallelism for Q
+      # Q's sequence is sharded; K/V are replicated
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fallback to original maxdiffusion behavior if the flag isn't provided
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
@@ -420,6 +435,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -440,7 +456,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -575,6 +591,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -594,6 +611,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention

   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -614,6 +632,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )


@@ -702,6 +721,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -731,6 +751,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
        dtype=dtype,
        quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
@@ -1524,4 +1545,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
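To make the two layouts concrete, here is a minimal sketch (not part of the commit) of what these PartitionSpecs do to a query laid out as [batch, heads, seq, head_dim] on a mesh with "data", "fsdp", and "tensor" axes. The mesh construction, the tensor shapes, and the assumption that _reshape_data_for_flash produces this layout are illustrative, not taken from the repository.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed mesh: every device on the "fsdp" axis; "data" and "tensor" have size 1.
devices = np.array(jax.devices()).reshape(1, -1, 1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Query assumed to be [batch, heads, seq, head_dim]; the sharded axis (heads here,
# seq below) must be divisible by the product of the mesh axes it maps to.
q = jnp.zeros((2, 8, 1024, 128))

self_attn_spec = PartitionSpec("data", ("fsdp", "tensor"), None, None)     # shard heads
cross_attn_q_spec = PartitionSpec("data", None, ("fsdp", "tensor"), None)  # shard Q's sequence

q_self = jax.device_put(q, NamedSharding(mesh, self_attn_spec))
q_cross = jax.device_put(q, NamedSharding(mesh, cross_attn_q_spec))

# Per-device block shapes: heads are split in the first case, the sequence in the second.
print(q_self.addressable_shards[0].data.shape)   # (2, 8 // n_devices, 1024, 128)
print(q_cross.addressable_shards[0].data.shape)  # (2, 8, 1024 // n_devices, 128)

On a single host this runs as-is; on a real TPU slice the reshape of jax.devices() would have to match the mesh shape actually configured for the run.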

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 3 additions & 1 deletion
@@ -282,6 +282,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
     )

     # 1. Cross-attention
@@ -300,6 +301,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -351,7 +353,7 @@ def __call__(
     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
     attn_output = self.attn2(
-        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs,
     )
     hidden_states = hidden_states + attn_output
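A note on why the cross-attention path (attn2, is_self_attention=False) replicates K/V while sharding only Q: the keys and values come from encoder_hidden_states, whose sequence is typically much shorter than the video-latent sequence behind hidden_states, so keeping the full K/V on every device is cheap, and each device only needs its slice of Q to produce its slice of the output. The toy shard_map sketch below illustrates that mechanism; the mesh layout, the shapes, and the toy_attention function are illustrative assumptions, not code from this commit.

import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# Assumed mesh: every device on the "fsdp" axis; "data" and "tensor" have size 1.
mesh = Mesh(np.array(jax.devices()).reshape(1, -1, 1), ("data", "fsdp", "tensor"))

q_spec = P("data", None, ("fsdp", "tensor"), None)  # shard Q's sequence axis
kv_spec = P("data", None, None, None)               # replicate K and V

@functools.partial(
    shard_map.shard_map, mesh=mesh,
    in_specs=(q_spec, kv_spec, kv_spec), out_specs=q_spec, check_rep=False,
)
def toy_attention(q, k, v):
  # Each device sees q of shape [batch, heads, seq_q // n_devices, d] but the
  # full k and v, so its Q slice attends to the entire encoder sequence.
  scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)

q = jnp.ones((1, 8, 512, 64))      # long query sequence (e.g. video latents)
k = v = jnp.ones((1, 8, 128, 64))  # shorter encoder sequence, cheap to replicate
print(toy_attention(q, k, v).shape)  # (1, 8, 512, 64)

The self-attention case is the mirror image: with the heads axis sharded, every device holds the full sequence for its subset of heads, and attention is independent per head, so no cross-device communication is needed inside the kernel either.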