
Commit 014c094

repeat kv for GQA (#462)
1 parent 7fb78a5 commit 014c094

1 file changed (+4, -0)

mamba_ssm/modules/mha.py

Lines changed: 4 additions & 0 deletions
@@ -173,6 +173,8 @@ def _update_kvcache_attention(self, q, kv, inference_params):
         # TODO: this only uses seqlen_offset and not lengths_per_sample.
         kv = self._update_kv_cache(kv, inference_params)
         k, v = kv.unbind(dim=-3)
+        k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
+        v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
         return F.scaled_dot_product_attention(
             q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
         ).transpose(1, 2)
@@ -275,6 +277,8 @@ def forward(self, x, inference_params=None):
         )
         if inference_params is None:
             k, v = kv.unbind(dim=-3)
+            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
+            v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
             context = F.scaled_dot_product_attention(
                 q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
             ).transpose(1, 2)
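
For context: with grouped-query attention (GQA), the module has fewer key/value heads (num_heads_kv) than query heads (num_heads), and F.scaled_dot_product_attention expects matching head counts, so the added lines duplicate each kv head across its query-head group before attention. Below is a minimal standalone sketch of that expansion; the shapes (batch=2, seqlen=16, num_heads=8, num_heads_kv=2, head_dim=64) are hypothetical, not taken from the module.

import torch
import torch.nn.functional as F

batch, seqlen, num_heads, num_heads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, num_heads, head_dim)
k = torch.randn(batch, seqlen, num_heads_kv, head_dim)
v = torch.randn(batch, seqlen, num_heads_kv, head_dim)

# Each group of num_heads // num_heads_kv query heads shares one kv head.
# repeat_interleave along the head dimension (dim=2) duplicates the kv heads
# so k and v end up with num_heads heads, matching q.
k = torch.repeat_interleave(k, dim=2, repeats=num_heads // num_heads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=num_heads // num_heads_kv)

# SDPA operates on (batch, heads, seqlen, head_dim), hence the transposes.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)
print(out.shape)  # torch.Size([2, 16, 8, 64])

Note that repeat_interleave materializes the duplicated kv heads in memory; it is the simplest fix for the head-count mismatch, at the cost of extra memory versus a kernel with native GQA support.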
