 
 import math
 from typing import Any, Callable, Mapping, Sequence
-
+from absl import logging
 from flax import linen as nn
 import jax
 from jax import numpy as jnp
@@ -31,7 +31,7 @@
 from praxis.layers import base_ops
 from praxis.layers import embedding_softmax
 from praxis.layers import stochastics
-
+from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention
 
 WeightInit = base_layer.WeightInit
 WeightHParams = base_layer.WeightHParams
@@ -215,6 +215,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
   scale_query_by_dim_per_head: bool = False
   chunked_attn_num_seq_split: int = 1
   local_window_size: tuple[int, int] | None = None
+  use_te_dpa: bool = False  # Experimental: use TE flash attention when the standard TE path can't be used.
 
   # SPMD partition related params.
   #
@@ -353,6 +354,20 @@ def project_input_kv(input_dim, dim_per_head):
     self.create_child('post', post_proj_p)
     self.create_child('qk_einsum', self.qk_einsum_tpl.clone())
     self.create_child('pv_einsum', self.pv_einsum_tpl.clone())
+    self.dpa_layer = TEDotProductAttention(
+        head_dim=dim_per_head,
+        num_attention_heads=self.num_heads,
+        num_gqa_groups=self.num_kv_heads,
+        attn_mask_type='causal',  # 'causal' or 'padding'
+        attn_bias_type='no_bias',  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+        attention_dropout=0.,
+        dropout_rng_name='aqt',
+        dtype=jnp.bfloat16,
+        float32_logits=False,
+        qkv_layout='BSHD_BSHD_BSHD',  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+        scale_factor=1.0 / math.sqrt(self.num_heads),
+        transpose_batch_sequence=False
+    )
 
   def _shard_bnh(self, x: JTensor) -> JTensor:
     """Shards tensors of shape [b, n, h].
@@ -889,29 +904,36 @@ def _rep_d(x):
     else:
       key_proj = self._shard_blnh(key_proj)
       value_proj = self._shard_blnh(value_proj)
-      b, t, n, h = query_proj.shape
-      _, s, nk, _ = key_proj.shape
-      assert n % nk == 0
-      v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
-      if relative_bias is not None:
-        v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
-      else:
-        v_rb = None
-      with self._context_for_kv_vmap():
-        encoded, atten_probs = jax.vmap(
-            self._dot_atten,
-            in_axes=(2, 2, 2, None, 1),
-            out_axes=(2, 1),
-        )(
-            v_q,
-            key_proj,
-            value_proj,
-            atten_mask,
-            v_rb,
+      if self.use_te_dpa:
+        logging.warning(
+            'use_te_dpa is set to True: using TE DotProductAttention as an experimental flash-attention path.'
         )
-      encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
-      if atten_probs is not None:
-        atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
+        atten_probs = None
+        encoded = self.dpa_layer(query_proj, key_proj, value_proj)
+      else:
+        b, t, n, h = query_proj.shape
+        _, s, nk, _ = key_proj.shape
+        assert n % nk == 0
+        v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
+        if relative_bias is not None:
+          v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
+        else:
+          v_rb = None
+        with self._context_for_kv_vmap():
+          encoded, atten_probs = jax.vmap(
+              self._dot_atten,
+              in_axes=(2, 2, 2, None, 1),
+              out_axes=(2, 1),
+          )(
+              v_q,
+              key_proj,
+              value_proj,
+              atten_mask,
+              v_rb,
+          )
+        encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
+        if atten_probs is not None:
+          atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
 
     # Post projection
     encoded = self.post(encoded)
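
For reference, a minimal usage sketch (not part of the diff above) of how the new `use_te_dpa` flag might be enabled on a Praxis attention template; the `pax_fiddle` config form and the concrete layer dimensions are illustrative assumptions, not values taken from this change.

```python
# Illustrative sketch only: enable the experimental TE flash-attention path on a
# MultiQueryDotProductAttention template. The dimensions below are made-up examples.
from praxis import pax_fiddle
from praxis.layers import multi_query_attention

atten_tpl = pax_fiddle.Config(
    multi_query_attention.MultiQueryDotProductAttention,
    input_dim=4096,
    hidden_dim=4096,
    num_heads=32,
    num_kv_heads=8,
    dim_per_head=128,
    use_te_dpa=True,  # New flag from this change; False keeps the existing vmap-based path.
)
```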