 EMBED = common_types.EMBED
 Quant = quantizations.AqtQuantization
 
+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+
+
 
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
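For reference, a minimal sketch of how this helper is typically used, assuming a configured `AqtQuantization` instance; the einsum equation and tensor names below are illustrative, not taken from this file:

```python
import jax.numpy as jnp

# Pick the plain or quantized einsum implementation.
einsum_op = _maybe_aqt_einsum(quant)  # falls back to jnp.einsum when quant is None

# Same contraction either way; only the underlying kernel changes.
attn_scores = einsum_op("bqhd,bkhd->bhqk", query, key)
```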
@@ -174,7 +182,6 @@ def _tpu_flash_attention(
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
     attention_kernel: str = "flash",
-    is_self_attention: Optional[bool] = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
 
@@ -203,22 +210,8 @@ def _tpu_flash_attention(
   query = _reshape_data_for_flash(query, heads)
   key = _reshape_data_for_flash(key, heads)
   value = _reshape_data_for_flash(value, heads)
-
-  # Use different sharding strategy for self-attn vs cross-attn
-  if is_self_attention is not None:
-    if is_self_attention:
-      # Self-attention: Context Parallelism (sharding along num_heads)
-      q_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-      kv_axis_names = PartitionSpec("data", ("fsdp", "tensor"), None, None)
-    else:
-      # Cross-attention: Sequence Parallelism for Q
-      # Q's sequence is sharded; K/V are replicated
-      q_axis_names = PartitionSpec("data", None, ("fsdp", "tensor"), None)
-      kv_axis_names = PartitionSpec("data", None, None, None)
-  else:
-    # Fallback to original maxdiffusion behavior if the flag isn't provided
-    q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
-    kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
+  q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
+  kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
   @functools.partial(
       shard_map.shard_map,
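For context, a rough sketch of what the replacement lines resolve to, assuming `flax.linen` is imported as `nn` (as in this file); the logical axis names and rules below are illustrative stand-ins for the real values in `common_types` and the sharding config:

```python
import flax.linen as nn

# Illustrative logical-to-mesh rules; the real ones come from the config's
# logical axis rules.
rules = (
    ("activation_batch", "data"),
    ("activation_self_attn_heads", ("fsdp", "tensor")),
)

# Logical names for (batch, heads, q_length, head_dim); None means no constraint.
axis_names_q = ("activation_batch", "activation_self_attn_heads", None, None)

q_axis_names = nn.logical_to_mesh_axes(axis_names_q, rules)
# -> PartitionSpec("data", ("fsdp", "tensor"), None, None), i.e. the same spec the
#    removed branch hard-coded for self-attention.
```

The same mechanism covers cross-attention: the rules decide whether heads or the query length get sharded, so the kernel itself no longer needs an `is_self_attention` flag.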
@@ -435,7 +428,6 @@ def _apply_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dpa_layer: Callable,
-    is_self_attention: bool = True,
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
@@ -456,7 +448,7 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel, is_self_attention,
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -591,7 +583,6 @@ def __init__(
       flash_block_sizes: BlockSizes = None,
       dtype: DType = jnp.float32,
       quant: Quant = None,
-      is_self_attention: bool = True,
   ):
     self.dpa_layer = None
     if attention_kernel == "cudnn_flash_te":
@@ -611,7 +602,6 @@ def __init__(
     self.flash_block_sizes = flash_block_sizes
     self.dtype = dtype
     self.quant = quant
-    self.is_self_attention = is_self_attention
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     return _apply_attention(
@@ -632,7 +622,6 @@ def apply_attention(self, query: Array, key: Array, value: Array):
         axis_names_kv=self.axis_names_kv,
         flash_block_sizes=self.flash_block_sizes,
         dpa_layer=self.dpa_layer,
-        is_self_attention=self.is_self_attention,
     )
 
 
@@ -738,6 +727,13 @@ def __init__(
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names
 
+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -747,11 +743,12 @@ def __init__(
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
         quant=quant,
-        is_self_attention=is_self_attention,
     )
     # None axes corresponds to the stacked weights across all blocks
     # because of the use of nnx.vmap and nnx.scan.
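Taken together, `is_self_attention` no longer threads through the attention kernels; it only selects which logical axis names are passed down, and the mesh rules determine the actual sharding. A condensed, hypothetical sketch of that flow (the helper name `pick_axis_names` is not part of the file; the constants are the ones imported from `common_types` above):

```python
def pick_axis_names(is_self_attention: bool):
  """Hypothetical condensation of the selection done in __init__ above."""
  if is_self_attention:
    # Per the old comments: context parallelism, heads are the sharded dimension.
    return (
        (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV),
        (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV),
    )
  # Per the old comments: sequence parallelism for cross-attention, Q length sharded.
  return (
      (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV),
      (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV),
  )

axis_names_q, axis_names_kv = pick_axis_names(is_self_attention=True)
# _tpu_flash_attention then resolves both uniformly:
#   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
#   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
```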