Commit cf2ce07 (1 parent: 67637df)

Use TE DPA for Grok MQA

File tree: 2 files changed, +50 -24 lines

  praxis/layers/grok.py
  praxis/layers/multi_query_attention.py

praxis/layers/grok.py (+4)

@@ -59,6 +59,7 @@ def GrokStackedTransformerHParams(
     combine_qkv=False,
     bidirectional=False,
     use_fp8=False,
+    use_te_dpa=False,
 ) -> pax_fiddle.Config[transformers.StackedTransformer]:
   """Common setup for Grok-1 Transformer layers.

@@ -168,6 +169,7 @@ def GrokStackedTransformerHParams(
   p.transformer_layer_params_tpl.tr_atten_tpl = pax_fiddle.Config(
       multi_query_attention.MultiQueryDotProductAttention,
       num_kv_heads=attention_num_groups,
+      use_te_dpa=use_te_dpa,
   )
   tr_atten_tpl = p.transformer_layer_params_tpl.tr_atten_tpl
   tr_atten_tpl.combine_qkv = False

@@ -225,6 +227,7 @@ def GrokUniTransformerLmHParams(
     model_type=LanguageModelType.CAUSAL,
     checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
     use_fp8=False,
+    use_te_dpa=False,
 ) -> pax_fiddle.Config[transformer_models.TransformerLm]:
   """Common setup for Grok-1 Decoder-only Transformer Model.

@@ -328,6 +331,7 @@ def GrokUniTransformerLmHParams(
       bidirectional=bidirectional,
       moe_gating_embedding_level=moe_gating_embedding_level,
       use_fp8=use_fp8,
+      use_te_dpa=use_te_dpa,
   )
   num_blocks = num_transformer_layers
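The new flag simply rides from the Grok-1 config helpers into the MQA attention template. Below is a minimal sketch of switching it on at the template level, mirroring the pax_fiddle.Config call in the hunk above; the num_kv_heads value is a placeholder, not something taken from this commit.

# Hedged sketch, not part of the commit: enable the TE DotProductAttention path
# on the Grok MQA attention template. num_kv_heads here is a placeholder value.
atten_tpl = pax_fiddle.Config(
    multi_query_attention.MultiQueryDotProductAttention,
    num_kv_heads=8,     # placeholder; the Grok helpers pass attention_num_groups
    use_te_dpa=True,    # new flag added by this commit, defaults to False
)

Callers of GrokStackedTransformerHParams or GrokUniTransformerLmHParams can equivalently pass use_te_dpa=True and let the helpers thread it through, as the diff above shows.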

praxis/layers/multi_query_attention.py (+46 -24)

@@ -17,7 +17,7 @@
 
 import math
 from typing import Any, Callable, Mapping, Sequence
-
+from absl import logging
 from flax import linen as nn
 import jax
 from jax import numpy as jnp

@@ -31,7 +31,7 @@
 from praxis.layers import base_ops
 from praxis.layers import embedding_softmax
 from praxis.layers import stochastics
-
+from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
 
 WeightInit = base_layer.WeightInit
 WeightHParams = base_layer.WeightHParams

@@ -215,6 +215,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
   scale_query_by_dim_per_head: bool = False
   chunked_attn_num_seq_split: int = 1
   local_window_size: tuple[int, int] | None = None
+  use_te_dpa: bool = False  # Experimental way to use TE flash attention when can't use standard TE
 
   # SPMD partition related params.
   #

@@ -353,6 +354,20 @@ def project_input_kv(input_dim, dim_per_head):
     self.create_child('post', post_proj_p)
     self.create_child('qk_einsum', self.qk_einsum_tpl.clone())
     self.create_child('pv_einsum', self.pv_einsum_tpl.clone())
+    self.dpa_layer = TEDotProductAttention(
+        head_dim=dim_per_head,
+        num_attention_heads=self.num_heads,
+        num_gqa_groups=self.num_kv_heads,
+        attn_mask_type='causal',  # 'causal' or 'padding'
+        attn_bias_type='no_bias',  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+        attention_dropout=0.,
+        dropout_rng_name='aqt',
+        dtype=jnp.bfloat16,
+        float32_logits=False,
+        qkv_layout='BSHD_BSHD_BSHD',  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+        scale_factor=1.0/math.sqrt(self.num_heads),
+        transpose_batch_sequence=False
+    )
 
   def _shard_bnh(self, x: JTensor) -> JTensor:
     """Shards tensors of shape [b, n, h].
@@ -889,29 +904,36 @@ def _rep_d(x):
     else:
       key_proj = self._shard_blnh(key_proj)
       value_proj = self._shard_blnh(value_proj)
-      b, t, n, h = query_proj.shape
-      _, s, nk, _ = key_proj.shape
-      assert n % nk == 0
-      v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
-      if relative_bias is not None:
-        v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
-      else:
-        v_rb = None
-      with self._context_for_kv_vmap():
-        encoded, atten_probs = jax.vmap(
-            self._dot_atten,
-            in_axes=(2, 2, 2, None, 1),
-            out_axes=(2, 1),
-        )(
-            v_q,
-            key_proj,
-            value_proj,
-            atten_mask,
-            v_rb,
+      if self.use_te_dpa:
+        logging.warning(
+            'use_te_dpa is set to True, so TE dpa is used as an experimental way to use TE flash attention.'
         )
-      encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
-      if atten_probs is not None:
-        atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
+        atten_probs = None
+        encoded = self.dpa_layer(query_proj, key_proj, value_proj)
+      else:
+        b, t, n, h = query_proj.shape
+        _, s, nk, _ = key_proj.shape
+        assert n % nk == 0
+        v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
+        if relative_bias is not None:
+          v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
+        else:
+          v_rb = None
+        with self._context_for_kv_vmap():
+          encoded, atten_probs = jax.vmap(
+              self._dot_atten,
+              in_axes=(2, 2, 2, None, 1),
+              out_axes=(2, 1),
+          )(
+              v_q,
+              key_proj,
+              value_proj,
+              atten_mask,
+              v_rb,
+          )
+        encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
+        if atten_probs is not None:
+          atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
 
     # Post projection
     encoded = self.post(encoded)
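Taken together, the pattern is: construct the TE layer once with the module's head configuration, then hand it the unmerged query/key/value projections and bypass the vmapped einsum path entirely. Below is a self-contained sketch of that wiring under placeholder hyperparameters; it is not the Praxis layer itself, it reuses the commit's scale_factor of 1/sqrt(num_heads) as-is, and it assumes the TE defaults are acceptable for the two constructor overrides it omits.

# Hedged, standalone sketch of the wiring used above; all values are placeholders.
import math
from flax import linen as nn
from jax import numpy as jnp
from transformer_engine.jax.flax.transformer import (
    DotProductAttention as TEDotProductAttention,
)


class TinyTEMultiQueryAttention(nn.Module):
  """Minimal wrapper mirroring how the commit drives TE DotProductAttention."""
  num_heads: int = 8       # placeholder
  num_kv_heads: int = 1    # placeholder
  dim_per_head: int = 64   # placeholder

  def setup(self):
    # Same constructor arguments as the commit's dpa_layer, minus the
    # dropout_rng_name / float32_logits overrides (their defaults are assumed).
    self.dpa_layer = TEDotProductAttention(
        head_dim=self.dim_per_head,
        num_attention_heads=self.num_heads,
        num_gqa_groups=self.num_kv_heads,
        attn_mask_type='causal',
        attn_bias_type='no_bias',
        attention_dropout=0.,
        dtype=jnp.bfloat16,
        qkv_layout='BSHD_BSHD_BSHD',
        scale_factor=1.0 / math.sqrt(self.num_heads),  # copied from the commit
        transpose_batch_sequence=False,
    )

  def __call__(self, query_proj, key_proj, value_proj):
    # query_proj: [B, S, num_heads, D]; key/value_proj: [B, S, num_kv_heads, D].
    # The TE branch above calls the layer the same way and skips atten_probs.
    return self.dpa_layer(query_proj, key_proj, value_proj)

Note that the TE branch returns atten_probs as None and does not forward atten_mask or relative_bias; causal masking comes solely from attn_mask_type='causal' set at construction time.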
