From 4901e004523465826b07255f4f13278ee4092b6c Mon Sep 17 00:00:00 2001
From: Abhinav Goel
Date: Fri, 27 Oct 2023 15:52:58 -0700
Subject: [PATCH] Added draft for grouped query attention

---
 praxis/layers/multi_query_attention.py | 527 +++++++++++++++++++++++++
 1 file changed, 527 insertions(+)

diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py
index 409c2e58..cac29113 100644
--- a/praxis/layers/multi_query_attention.py
+++ b/praxis/layers/multi_query_attention.py
@@ -1057,6 +1057,533 @@ def lazy_broadcast_prefix(self, num_suffix_samples: int,
                          'MultiQueryDotProductAttentionLPB instead.')
+class GroupedQueryDotProductAttention(base_layer.BaseLayer):
+  """Dot-product attention sharing groups of keys and values across heads.
+
+  This implementation heavily uses einsum to be efficient on TPUs. We use the
+  following capital letters to denote certain JTensor parameters.
+
+    B = batch size
+    G = number of groups
+    S = length of the key/value (source)
+    T = length of the query (target)
+    D = model dimension
+    N = number of query attention heads
+    H = dimensions of each attention head.
+
+  The algorithm is sketched as follows. Each intermediate JTensor or weight
+  JTensor is annotated with its shape. E.g., Wq, the weight JTensor for the
+  query projection, has shape [D, N, H].
+
+  Trainable weights:
+    Wq: [D, N, H]
+    Wk, Wv: [D, G, H]
+    Wout: [D, N, H]
+
+  Note it also allows k, v and q to have different input dimensions by setting
+  input_dim as a dict: {'key': key_dim, 'value': value_dim, 'query': query_dim}.
+
+  Input q:[B, T, D]; k:[B, S, D]; v:[B, S, D]
+  q_proj:[B, T, N, H] = einsum('BTD,DNH->BTNH', q, Wq)
+  k_proj:[B, S, G, H] = einsum('BSD,DGH->BSGH', k, Wk)
+  v_proj:[B, S, G, H] = einsum('BSD,DGH->BSGH', v, Wv)
+  logits:[B, G, T, S] = einsum('BTNH,BSGH->BGTS', q_proj, k_proj) / sqrt(H)
+  probs:[B, G, T, S] = softmax(logits)
+  context:[B, T, G, H] = einsum('BGTS,BSGH->BTGH', probs, v_proj)
+  Output y:[B, T, D] = einsum('BTGH,DNH->BTD', context, Wout)
+
+  Attributes:
+    input_dim: An integer or a dict of integer values as number of input nodes.
+      If input_dim is a dict, keys must be key, value and query.
+    hidden_dim: Number of hidden nodes.
+    num_heads: Number of query attention heads.
+    num_groups: Number of groups (number of kv attention heads).
+    num_kv_heads: Number of kv heads; num_heads must be divisible by
+      num_kv_heads.
+    dim_per_head: Dimension of each attention head. If None then dim_per_head
+      == hidden_dim // num_heads.
+    dropout_tpl: Parameterization for the dropout layer.
+    atten_dropout_prob: Probability at which we apply dropout to the attention
+      weights.
+    proj_tpl: Parameterization for the query projection_tpl layer.
+    headless_proj_tpl: Parameterization for the key/value projection_tpl layer.
+    use_bias: Whether to use bias for projection_tpl layers.
+    output_proj_use_nhd_shape: Whether to use NHD variable shape in output
+      projection layer.
+    internal_enable_query_scale: Internal. Enable scaling of query vector.
+    atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
+      positive value is specified. May not be supported by a subclass.
+    use_rotary_position_emb: Whether to add rotary position embedding to the
+      queries and keys before computing self attention scores. This was
+      proposed in https://arxiv.org/abs/2104.09864.
+    relative_bias_tpl: Optional parameterization of relative bias.
+    attention_extra_logit: Extra logit for attention softmax.
+    combine_qkv: Whether to combine qkv tensor for optimizing qkv input
+      gradient computation with SPMD. Only supports self-attention.
+    scale_query_by_dim_per_head: Whether to scale the query by dim_per_head,
+      instead of the default hidden_dim // num_heads.
+
+  Note: dconv_qkv and ngrammer are not supported.
+  """
+  input_dim: Union[int, Dict[str, int]] = 0
+  hidden_dim: int = 0
+  num_heads: int = 1
+  num_groups: int = 1
+  num_kv_heads: int = num_groups
+  dim_per_head: Optional[int] = None
+  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
+  atten_dropout_prob: float = 0.0
+  proj_tpl: LayerTpl = template_field(attentions.AttentionProjection)
+  headless_proj_tpl: LayerTpl = template_field(OneHeadedAttentionProjection)
+  internal_gshard_gaussian_init: bool = False
+  use_bias: bool = True
+  output_proj_use_nhd_shape: bool = False
+  internal_enable_query_scale: bool = True
+  atten_logit_cap: float = 0.0
+  use_rotary_position_emb: bool = False
+  relative_bias_tpl: Optional[LayerTpl] = template_field(None)
+  attention_extra_logit: Optional[float] = None
+  dconv_qkv: bool = False
+  combine_qkv: bool = False
+  qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
+  pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
+  scale_query_by_dim_per_head: bool = False
+
+  # SPMD partition related params.
+  #
+  # d - model_dim
+  # n - num_heads
+  # h - attention_dim_per_head
+  # b - batch_size
+  # l - seq_len
+  # g - num_groups
+
+  class WeightSharding(base_layer.BaseLayer.WeightSharding):
+    """Represents how layer's learned parameters are partitioned across a mesh.
+
+    Attributes:
+      proj: How the projection weights should be sharded. All projection
+        matrices share the same sharding.
+      dconv: How the dconv weights should be sharded. All dconv weights share
+        the same sharding.
+      proj_headless: How the headless key/value projection weights (used when
+        num_kv_heads == 1) should be sharded.
+    """
+    proj: SplitDimsMapping = None
+    dconv: SplitDimsMapping = None
+    proj_headless: SplitDimsMapping = None
+
+  class ActivationSharding(base_layer.BaseLayer.ActivationSharding):
+    """Represents how intermediate values should be partitioned across a mesh.
+
+    Attributes:
+      blnh: Mesh split for query and encoded tensors with the shape of
+        [batch_size, seq_len, num_heads/num_groups, dim_per_head].
+      blh: Mesh split for key, value, and encoded tensors with the shape of
+        [batch_size, seq_len, dim_per_head].
+      bld: Mesh split for output after post projection with the shape of
+        [batch_size, seq_len, model_dim].
+ """ + blnh: SplitDimsMapping = None + blh: SplitDimsMapping = None + bld: SplitDimsMapping = None + + def setup(self) -> None: + wp = self.weight_split_dims_mapping + assert self.input_dim, 'input_dim is {}'.format(self.input_dim) + assert self.hidden_dim, 'hidden_dim is {}'.format(self.hidden_dim) + + assert not self.dconv_qkv + assert not self.combine_qkv + + dim_per_head = self.dim_per_head + if dim_per_head is None: + dim_per_head = self.hidden_dim // self.num_heads + assert ( + dim_per_head * self.num_heads == self.hidden_dim + ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}' + + if self.mesh_shape is not None: + assert self.weight_split_dims_mapping is not None + assert self.activation_split_dims_mapping is not None + + if isinstance(self.input_dim, dict): + key_input_dim = self.input_dim['key'] + value_input_dim = self.input_dim['value'] + query_input_dim = self.input_dim['query'] + assert key_input_dim, f'key_input_dim is {key_input_dim}' + assert query_input_dim, f'query_input_dim is {query_input_dim}' + else: + key_input_dim = self.input_dim + value_input_dim = self.input_dim + query_input_dim = self.input_dim + + def project_input(input_dim, dim_per_head, num_heads): + proj_p = self.proj_tpl.clone().set( + input_dim=input_dim, + num_heads=num_heads, + dim_per_head=dim_per_head, + use_bias=self.use_bias, + ) + proj_p.weight_split_dims_mapping.wt = wp.proj + return proj_p + + def project_input_kv(input_dim, dim_per_head): + if self.num_kv_heads == 1: + proj_p = self.headless_proj_tpl.clone().set( + input_dim=input_dim, output_dim=dim_per_head, use_bias=self.use_bias + ) + proj_p.weight_split_dims_mapping.wt = wp.proj_headless + return proj_p + else: + assert self.num_heads % self.num_kv_heads == 0 + return project_input(input_dim, dim_per_head, self.num_kv_heads) + + dim_per_head = self.dim_per_head + if dim_per_head is None: + dim_per_head = self.hidden_dim // self.num_heads + assert ( + dim_per_head * self.num_heads == self.hidden_dim + ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}' + self.create_child('value', project_input_kv(value_input_dim, dim_per_head)) + + self.create_child('key', project_input_kv(key_input_dim, dim_per_head)) + self.create_child( + 'query', project_input(query_input_dim, dim_per_head, self.num_heads) + ) + + if self.use_rotary_position_emb: + pos_emb_p = pax_fiddle.Config(embedding_softmax.RotaryPositionalEmbedding) + pos_emb_p.embedding_dims = dim_per_head + self.create_child('rotary_position_emb', pos_emb_p) + + if self.relative_bias_tpl is not None: + relative_bias_p = self.relative_bias_tpl.clone() + relative_bias_p.num_heads = self.num_heads + self.create_child('relative_bias', relative_bias_p) + + self.create_child( + 'atten_dropout', + self.dropout_tpl.clone().set(keep_prob=1.0 - self.atten_dropout_prob), + ) + + # Setting is_output_projection=True to set the projection direction + # from hidden dim to input dim. Output projection follows query_input_dim. + post_proj_p = self.proj_tpl.clone().set( + input_dim=query_input_dim, + num_heads=self.num_heads, + dim_per_head=dim_per_head, + is_output_projection=True, + use_bias=self.use_bias, + use_nhd_shape=self.output_proj_use_nhd_shape, + ) + post_proj_p.weight_split_dims_mapping.wt = wp.proj + + self.create_child('post', post_proj_p) + self.create_child('qk_einsum', self.qk_einsum_tpl.clone()) + self.create_child('pv_einsum', self.pv_einsum_tpl.clone()) + + def _shard_bnh(self, x: JTensor) -> JTensor: + """Shards tensors of shape [b, n, h]. 
+
+    Single-step decoder outputs are of shape [b, n, h].
+
+    Args:
+      x: A tensor of shape [b, n, h].
+
+    Returns:
+      x with proper sharding annotations.
+    """
+    ap = self.activation_split_dims_mapping
+    if self.mesh_axis_names is None:
+      return x
+    if ap.blnh is None:
+      return x
+    assert len(ap.blnh) == 4
+    bnh = [ap.blnh[0], ap.blnh[2], ap.blnh[3]]
+    return base_layer.maybe_shard(x, bnh, self.mesh_axis_names)
+
+  def _shard_blnh(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, l, n, h]."""
+    ap = self.activation_split_dims_mapping
+    return base_layer.maybe_shard(x, ap.blnh, self.mesh_axis_names)
+
+  def _shard_blh(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, l, h]."""
+    ap = self.activation_split_dims_mapping
+    shard = None
+    if ap.blh is not None:
+      shard = ap.blh
+    elif ap.blnh is not None:
+      shard = [ap.blnh[0], ap.blnh[1], ap.blnh[3]]
+    return base_layer.maybe_shard(x, shard, self.mesh_axis_names)
+
+  def _shard_bld(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, l, d]."""
+    ap = self.activation_split_dims_mapping
+    return base_layer.maybe_shard(x, ap.bld, self.mesh_axis_names)
+
+  def _shard_bd(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, d]."""
+    ap = self.activation_split_dims_mapping
+    if self.mesh_axis_names is None:
+      return x
+    if ap.bld is None:
+      return x
+    assert len(ap.bld) == 3
+    bd = [ap.bld[0], ap.bld[2]]
+    return base_layer.maybe_shard(x, bd, self.mesh_axis_names)
+
+  def _scale_query(self, query: JTensor) -> JTensor:
+    """Scales the query vector."""
+    if self.scale_query_by_dim_per_head and self.dim_per_head is not None:
+      dim_per_head = self.dim_per_head
+    else:
+      dim_per_head = self.hidden_dim // self.num_heads
+
+    query *= dim_per_head**-0.5
+
+    return query
+
+  def _cap_logits(self, logits: JTensor) -> JTensor:
+    """When enabled, caps the logits by p.atten_logit_cap with tanh."""
+    if not self.atten_logit_cap or self.atten_logit_cap <= 0.0:
+      return logits
+    cap = jnp.array(self.atten_logit_cap, dtype=logits.dtype)
+    # Note that since this caps the negative side as well, caller
+    # must defer the pad-with-very-negative-logits logic to after
+    # this function returns.
+    logits = cap * jnp.tanh(logits / cap)
+    return logits
+
+  def _log_softmax_with_extra_logit(self, logits: JTensor) -> JTensor:
+    """Computes log softmax with extra logit.
+
+    self.attention_extra_logit is a user-defined float value that
+    helps to stabilize logit values so that they don't drift too much from it.
+
+    Args:
+      logits: Input logit tensor.
+
+    Returns:
+      Log softmax with extra logit value.
+    """
+    # Applies stop_gradient to max_logit instead of logits.
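+    # With the extra logit z, the softmax normalizer becomes
+    # sum_j exp(logit_j) + exp(z); subtracting max_logit before exponentiating
+    # keeps the computation numerically stable.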
+    max_logit = jnp.max(jax.lax.stop_gradient(logits), axis=-1, keepdims=True)
+    extra_logit = self.attention_extra_logit
+    if extra_logit is not None:
+      extra_logit = jnp.asarray(extra_logit, dtype=max_logit.dtype)
+      max_logit = jnp.maximum(max_logit, extra_logit)
+    exp_x = jnp.exp(logits - max_logit)
+    sum_exp_x = jnp.sum(exp_x, axis=-1, keepdims=True)
+    if extra_logit is not None:
+      sum_exp_x += jnp.exp(extra_logit - max_logit)
+    return logits - jnp.log(sum_exp_x) - max_logit
+
+  def _atten_logits(self, query: JTensor, key: JTensor) -> JTensor:
+    """Compute logits from query and key."""
+    query = query.transpose(0, 2, 1, 3)
+    key = key.transpose(0, 2, 1, 3)
+    logits = self.qk_einsum('BNTH,BGSH->BGTS', query, key)
+    return logits
+
+  def _dot_atten(
+      self,
+      query: JTensor,
+      key: JTensor,
+      value: JTensor,
+      atten_mask: JTensor,
+      relative_bias: Optional[JTensor] = None) -> Tuple[JTensor, JTensor]:
+    """Main attention function.
+
+    Args:
+      query: JTensor of shape [B, T, N, H].
+      key: JTensor of shape [B, S, G, H].
+      value: JTensor of shape [B, S, G, H].
+      atten_mask: JTensor of shape [1/B, 1, 1/T, S] which is a mask that is
+        applied to prevent attention between unwanted pairs. This has already
+        been converted into large negative logits. Note that the first and
+        third dimensions allow size 1 if the mask is shared by every item in
+        the batch or every token in the target sequence.
+      relative_bias: Relative bias of shape [B, G, T, S].
+
+    Returns:
+      encoded: JTensor of shape [B, T, G, H].
+      atten_probs: JTensor of shape [B, G, T, S].
+    """
+    b, t, n, h = query.shape
+    _, s, g, _ = key.shape
+    base_layer.assert_has_shape(key, [b, s, g, h])
+    base_layer.assert_has_shape(value, [b, s, g, -1])
+    # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
+    # since each target token is prohibited from attending to the same set of
+    # source tokens. In this case tiling is inefficient and unnecessary.
+    # If there is no padding mask, and only causal mask then the shape can be
+    # [1, 1, T, S].
+    base_layer.assert_has_shape(atten_mask, [-1, 1, -1, s])
+    asserts.in_set(atten_mask.shape[2], [1, t])
+    asserts.in_set(atten_mask.shape[0], [1, b])
+    query = self._scale_query(query)
+    logits = self._atten_logits(query, key)
+    if relative_bias is not None:
+      # The relative_bias has shape [1, g, t, s] or [b, g, t, s].
+      base_layer.assert_has_shape(relative_bias, [-1, g, t, s])
+      logits += relative_bias
+    logits = checkpoint_name(logits, 'logits')
+    logits = self._cap_logits(logits)
+    # Attention softmax is always carried out in fp32.
+    logits = logits.astype(jnp.float32)
+    # Apply attention masking.
+    padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+    if self.attention_extra_logit is None:
+      probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)
+    else:
+      probs = jnp.exp(self._log_softmax_with_extra_logit(padded_logits)).astype(
+          key.dtype)
+    # Apply attention dropout.
+    probs = self.atten_dropout(probs)
+    # Compute the attention context.
+    encoded = self.pv_einsum('BGTS,BSGH->BGTH', probs, value)
+    encoded = encoded.transpose(0, 2, 1, 3)
+    encoded = checkpoint_name(encoded, 'context')
+    encoded = self._shard_blnh(encoded)
+    return encoded, probs
+
+  def _context_for_kv_vmap(self):
+    # Transpose the sharding on num_heads to None, so that the inner n // nk
+    # dim is not sharded. The outer nk dim is sharded on the vmap dim, which is
+    # not visible to the layer code.
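+    # Setting mesh_axes_transpose={n_sharding: None} below makes sharding
+    # annotations that reference the num_heads axis resolve to unsharded
+    # while the vmapped per-kv-head computation runs.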
+    if self.activation_split_dims_mapping.blnh:
+      n_sharding = self.activation_split_dims_mapping.blnh[2]
+    else:
+      n_sharding = None
+    assert n_sharding is None or isinstance(n_sharding, str)
+    if base_layer.JaxContext.has_context():
+      new_context_params = base_layer.cur_jax_context().hparams.clone()
+    else:
+      new_context_params = base_layer.JaxContext.HParams()
+    if not self.is_initializing() and n_sharding is not None:
+      new_context_params.mesh_axes_transpose = {n_sharding: None}
+    return base_layer.JaxContext.new_context(hparams=new_context_params)
+
+  def __call__(
+      self,
+      query_vec: JTensor,
+      key_vec: JTensor,
+      value_vec: JTensor,
+      atten_mask: JTensor,
+      query_segment_pos: Optional[JTensor] = None,
+      key_segment_pos: Optional[JTensor] = None) -> Tuple[JTensor, JTensor]:
+    """Computes the value vector given the current query output.
+
+    Args:
+      query_vec: JTensor of shape [B, T, D].
+      key_vec: JTensor of shape [B, S, D].
+      value_vec: JTensor of shape [B, S, D].
+      atten_mask: JTensor of shape [1/B, 1, 1/T, S] which is a mask that is
+        applied to prevent attention between unwanted pairs. This has already
+        been converted into large negative logits. Note that the first and
+        third dimensions allow size 1 if the mask is shared by every item in
+        the batch or every token in the target sequence.
+      query_segment_pos: JTensor of shape [B, T].
+      key_segment_pos: JTensor of shape [B, S].
+
+    Returns:
+      encoded: JTensor of shape [B, T, D].
+      atten_probs: JTensor of shape [B, G, T, S].
+    """
+    # Make sure the weight gradient matmul computes with D replicated, which
+    # will be a regular reduce-scatter pattern on the result. This helps with
+    # scheduling MegaScale ops.
+    def _rep_d(x):
+      return base_layer.maybe_shard(
+          x,
+          [None] * x.ndim,
+          self.mesh_axis_names,
+          unconstrained_dims=range(x.ndim - 1),
+      )
+
+    query_vec, key_vec, value_vec = [
+        _rep_d(x) for x in [query_vec, key_vec, value_vec]
+    ]
+
+    # Project inputs to query, key and value. query_proj has shape
+    # [B, T, N, H]; key_proj and value_proj have shape [B, S, G, H]
+    # ([B, S, H] when num_kv_heads == 1).
+    query_proj = self.query(query_vec)
+    key_proj = self.key(key_vec)
+    value_proj = self.value(value_vec)
+
+    self._fprop_update_decode_state('key_state', key_proj)
+    self._fprop_update_decode_state('value_state', value_proj)
+
+    # Apply rotary position embeddings.
+    # Paper: https://arxiv.org/abs/2104.09864.
+    if self.use_rotary_position_emb:
+      query_proj = self.rotary_position_emb(query_proj, query_segment_pos)
+      key_shape = key_proj.shape
+      # [B, S, H] -> [B, S, N(1), H]
+      if self.num_kv_heads == 1:
+        key_proj = jnp.expand_dims(key_proj, axis=-2)
+      key_proj = self.rotary_position_emb(key_proj, key_segment_pos)
+      if self.num_kv_heads == 1:
+        key_proj = jnp.reshape(key_proj, key_shape)
+      self._fprop_update_decode_state('key_post_rotary_pos_emb', key_proj)
+
+    # Apply relative bias.
+    # Paper: https://aclanthology.org/N18-2074.pdf.
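+    # The relative bias, when configured, is added to the attention logits
+    # inside _dot_atten.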
+    if self.relative_bias_tpl:
+      relative_bias = self.relative_bias(query_segment_pos, key_segment_pos)
+    else:
+      relative_bias = None
+
+    query_proj = self._shard_blnh(query_proj)
+    if self.num_kv_heads == 1:
+      key_proj = self._shard_blh(key_proj)
+      value_proj = self._shard_blh(value_proj)
+      encoded, atten_probs = self._dot_atten(
+          query_proj,
+          key_proj,
+          value_proj,
+          atten_mask,
+          relative_bias,
+      )
+    else:
+      key_proj = self._shard_blnh(key_proj)
+      value_proj = self._shard_blnh(value_proj)
+      # num_heads must be an integer multiple of num_kv_heads.
+      assert query_proj.shape[2] % key_proj.shape[2] == 0
+      encoded, atten_probs = self._dot_atten(
+          query_proj,
+          key_proj,
+          value_proj,
+          atten_mask,
+          relative_bias,
+      )
+      # encoded is [B, T, G, H] and atten_probs is [B, G, T, S], matching the
+      # documented return shapes, so no further reshape is needed here.
+      encoded = self._shard_blnh(encoded)
+
+    # Post projection.
+    encoded = self.post(encoded)
+    encoded = self._shard_bld(encoded)
+    encoded = checkpoint_name(encoded, 'out_proj')
+
+    return encoded, atten_probs
+
+  def init_states(self, target_batch_size: int, target_max_length: int) -> None:
+    """Initializes cache for autoregressive cached decoding.
+
+    Args:
+      target_batch_size: The batch size of the target to be decoded.
+      target_max_length: The sequence length of the target to be decoded.
+
+    Returns:
+      None.
+    """
+    raise NotImplementedError(type(self))
+
+
 class MultiQueryDotProductAttentionLPB(MultiQueryDotProductAttention):
   # TODO(pax-dev): Implement a single base class for all LPB type models.
   """Multi-query dot-product attention with lazy prefix broadcast.
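A minimal configuration sketch for the new layer, following the
pax_fiddle.Config pattern already used elsewhere in this file; the sizes and
the gqa_tpl name below are hypothetical, and num_heads must be divisible by
num_kv_heads:

  # Hypothetical sizes for illustration only.
  gqa_tpl = pax_fiddle.Config(
      GroupedQueryDotProductAttention,
      input_dim=1024,
      hidden_dim=1024,
      num_heads=8,
      num_kv_heads=2,
  )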