@@ -174,6 +174,7 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
+    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -202,8 +203,22 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+
+  # Use different sharding strategies for self-attention vs. cross-attention.
+  if is_self_attention is not None:
+    if is_self_attention:
+      # Self-attention: context parallelism (sharding along num_heads).
+      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
+    else:
+      # Cross-attention: sequence parallelism for Q.
+      # Q's sequence axis is sharded; K/V are replicated.
+      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
+      kv_axis_names = PartitionSpec("data", None, None, None)
+  else:
+    # Fall back to the original maxdiffusion behavior if the flag isn't provided.
+    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
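A minimal, self-contained sketch (not part of the change above) of what the two explicit specs mean for a reshaped (batch, heads, seq, head_dim) activation. The mesh axis names ("data", "fsdp", "tensor") come from the diff; the mesh shape, array sizes, and dimension order are illustrative assumptions, and the snippet assumes the head and sequence sizes divide the local device count.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Put all local devices on the ("fsdp", "tensor") axes; "data" stays size 1.
mesh = Mesh(np.array(jax.devices()).reshape(1, 1, -1), ("data", "fsdp", "tensor"))

x = jnp.zeros((1, 8, 1024, 128))  # assumed (batch, heads, seq, head_dim) layout

# Self-attention spec: the heads axis is split across fsdp*tensor, so every
# device keeps the full sequence ("context") for its slice of heads.
x_self = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", ("fsdp", "tensor"), None, None)))

# Cross-attention Q spec: the sequence axis is split instead; each device
# holds a block of queries and attends to fully replicated K/V.
x_cross_q = jax.device_put(x, NamedSharding(mesh, PartitionSpec("data", None, ("fsdp", "tensor"), None)))

print(x_self.sharding, x_cross_q.sharding)

Splitting only Q's sequence is sound for cross-attention because each block of queries needs all of K/V, and the replicated kv spec keeps that available on every device.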
@@ -420,6 +435,7 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
+    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -440,7 +456,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -575,6 +591,7 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -594,6 +611,7 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
+    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -614,6 +632,7 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
+        is_self_attention=self.is_self_attention,
     )
 
 
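The axis_names_q / axis_names_kv values forwarded above are Flax logical axis names; they are only consulted in the fallback branch of _tpu_flash_attention (when is_self_attention is None), where nn.logical_to_mesh_axes resolves them against the active logical-axis rules. A small sketch of that resolution, using made-up logical names and rules (maxdiffusion's real names and rules are defined elsewhere and are not shown in this diff):

import flax.linen as nn

# Hypothetical logical names for the (batch, heads, seq, head_dim) dims.
axis_names_q = ("activation_batch", "activation_heads", "activation_length", "activation_kv")

# Hypothetical rules mapping logical names to mesh axes; names without a
# rule resolve to None (replicated).
rules = (
    ("activation_batch", "data"),
    ("activation_heads", ("fsdp", "tensor")),
)

with nn.logical_axis_rules(rules):
  print(nn.logical_to_mesh_axes(axis_names_q))
  # Expected: PartitionSpec('data', ('fsdp', 'tensor'), None, None), i.e. the
  # same placement the explicit self-attention branch writes out by hand.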
@@ -702,6 +721,7 @@ def __init__(
      precision: jax.lax.Precision = None,
      qkv_bias: bool = False,
      quant: Quant = None,
+      is_self_attention: bool = True,
  ):
    if attention_kernel == "cudnn_flash_te":
      raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -731,6 +751,7 @@ def __init__(
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
+        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
@@ -1524,4 +1545,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)