[Gen] Use flash_attn_with_kvcache in generation
tridao committed Sep 7, 2023
1 parent a1576ad commit a86442f
Showing 5 changed files with 134 additions and 36 deletions.
3 changes: 2 additions & 1 deletion flash_attn/layers/rotary.py
@@ -146,7 +146,8 @@ def forward(
# Call 1 kernel instead of 2 kernels
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
# dimensions, we get the same tensor
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
# qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
apply_rotary(
qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
)
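As context for this hunk, a quick standalone check (not part of the commit) that the new `reshape` produces the same tensor as the `rearrange` it replaces, assuming `qkv` has shape `(batch, seqlen, 3, nheads, headdim)` as in this code path:

```python
import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 5, 4, 8
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)

# Old path: einops rearrange merging the (2, nheads) dims of the q/k slice.
qk_rearrange = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
# New path: plain reshape with explicit sizes.
qk_reshape = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)

assert qk_reshape.shape == (batch, seqlen, 2 * nheads, headdim)
assert torch.equal(qk_rearrange, qk_reshape)  # same values, same layout
```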
112 changes: 96 additions & 16 deletions flash_attn/modules/mha.py
@@ -15,10 +15,12 @@
flash_attn_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache,
)
except ImportError:
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
flash_attn_with_kvcache = None

try:
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
@@ -556,6 +558,35 @@ def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
else False,
)

def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses sequence_len_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
return flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
)
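For readers unfamiliar with the new kernel, here is a rough, self-contained sketch (not repository code) of the call this fast path makes during decoding; batch size, head count, and sequence sizes are arbitrary assumptions. The packed cache layout `(batch, max_seqlen, 2, nheads, headdim)` matches `allocate_inference_cache` in `flash_attn/utils/generation.py`, and passing the new k/v alongside `cache_seqlens` lets the kernel append them to the cache in place while computing attention.

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 256
device, dtype = "cuda", torch.float16

q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)      # query for the current step
kv = torch.randn(batch, 1, 2, nheads, headdim, device=device, dtype=dtype)  # new k/v for the current step
kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, headdim, device=device, dtype=dtype)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device=device)  # tokens already in the cache

out = flash_attn_with_kvcache(
    q,
    kv_cache[:, :, 0],  # k_cache view: (batch, max_seqlen, nheads, headdim)
    kv_cache[:, :, 1],  # v_cache view
    kv[:, :, 0],        # new k, appended at position cache_seqlens
    kv[:, :, 1],        # new v
    cache_seqlens=cache_seqlens,
    causal=True,
)
print(out.shape)  # torch.Size([2, 1, 8, 64])
```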

def forward(
self,
x,
@@ -605,10 +636,19 @@ def forward(
if self.use_flash_attn
else {"key_padding_mask": key_padding_mask, **kwargs}
)
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
)
rotary_max_seqlen = (
inference_params.max_sequence_len if inference_params is not None else None
)
batch, seqlen = x.shape[:2]
if not self.cross_attn and self.num_heads_kv == self.num_heads:
assert x_kv is None and mixer_subset is None
if not self.return_residual:
@@ -619,7 +659,8 @@ def forward(
qkv = rearrange(
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
).contiguous()
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
# qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
if (
inference_params is None
or inference_params.sequence_len_offset == 0
@@ -635,9 +676,9 @@ def forward(
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
@@ -659,8 +700,10 @@ def forward(
qkv, x = self.Wqkv(x)
q = qkv[..., : self.num_heads * self.head_dim]
kv = qkv[..., self.num_heads * self.head_dim :]
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
# q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
q = q.reshape(batch, seqlen, -1, self.head_dim)
# kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
if self.dwconv:
q = rearrange(
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -685,11 +728,11 @@ def forward(
self.inner_cross_attn, q, kv, **kwargs
)
else:
kv = self._update_kv_cache(kv, inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
# out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
out = self.out_proj(context.reshape(batch, seqlen, -1))
return out if not self.return_residual else (out, x)
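As a shape sanity check for the `num_heads_kv != num_heads` branch above (a sketch, not repository code; the head counts are arbitrary assumptions), the projection output is split into query heads and packed k/v heads before the reshapes:

```python
import torch

batch, seqlen, num_heads, num_heads_kv, head_dim = 2, 5, 8, 2, 64
# For grouped-query attention the Wqkv output packs q first, then k and v heads.
qkv = torch.randn(batch, seqlen, (num_heads + 2 * num_heads_kv) * head_dim)

q = qkv[..., : num_heads * head_dim].reshape(batch, seqlen, -1, head_dim)
kv = qkv[..., num_heads * head_dim :].reshape(batch, seqlen, 2, -1, head_dim)

print(q.shape)   # torch.Size([2, 5, 8, 64])    -> (batch, seqlen, num_heads, head_dim)
print(kv.shape)  # torch.Size([2, 5, 2, 2, 64]) -> (batch, seqlen, 2, num_heads_kv, head_dim)
```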


@@ -846,6 +889,36 @@ def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
else False,
)

def _update_kvcache_attention(self, q, kv, inference_params):
"""Write kv to inference_params, then do attention """
if (
inference_params.sequence_len_offset == 0
or flash_attn_with_kvcache is None
or not self.use_flash_attn
):
# TODO: this only uses sequence_len_offset and not lengths_per_sample.
kv = self._update_kv_cache(kv, inference_params)
return self.inner_cross_attn(q, kv)
else:
batch = q.shape[0]
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
cache_seqlens = (
inference_params.lengths_per_sample[:batch]
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
context = flash_attn_with_kvcache(
q,
kv_cache[:, :, 0],
kv_cache[:, :, 1],
kv[:, :, 0],
kv[:, :, 1],
cache_seqlens=cache_seqlens,
softmax_scale=self.inner_cross_attn.softmax_scale,
causal=self.inner_cross_attn.causal,
)
return context

def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
@@ -857,7 +930,15 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
qkv = self.Wqkv(x)
if seqlen is not None:
qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
seqlen_offset = (
0
if inference_params is None
else (
inference_params.lengths_per_sample
if inference_params.lengths_per_sample is not None
else inference_params.sequence_len_offset
)
)
rotary_max_seqlen = (
inference_params.max_sequence_len if inference_params is not None else None
)
@@ -878,9 +959,9 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
q = qkv[:, :, 0]
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(
qkv[:, :, 0], qkv[:, :, 1:], inference_params
)
else:
context = self._apply_rotary_single_query_attention(qkv, inference_params)
else:
@@ -912,8 +993,7 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
self.inner_cross_attn, q, kv, **kwargs
)
else:
kv = self._update_kv_cache(kv, inference_params)
context = self.inner_cross_attn(q, kv)
context = self._update_kvcache_attention(q, kv, inference_params)
else:
context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
context = rearrange(context, "b s h d -> b s (h d)")
33 changes: 24 additions & 9 deletions flash_attn/utils/generation.py
@@ -118,7 +118,6 @@ def decode(
batch_size, seqlen_og = input_ids.shape
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
if cg:
assert fused_ft_kernel
if not hasattr(model, "_decoding_cache"):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
@@ -128,11 +127,13 @@
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
inference_params.lengths_per_sample.zero_()
else:
inference_params = InferenceParams(
max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
@@ -167,7 +168,8 @@ def sample_tokens(logits, inference_params):
token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
token = teacher_outputs[:, inference_params.sequence_len_offset]
return rearrange(token, "b -> b 1")
# return rearrange(token, "b -> b 1")
return token.unsqueeze(1)

def should_stop(current_token, inference_params):
if inference_params.sequence_len_offset == 0:
@@ -197,9 +199,7 @@ def should_stop(current_token, inference_params):
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
)
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
@@ -298,7 +298,6 @@ def decode_speculative(
assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
if cg:
assert fused_ft_kernel
if not hasattr(model_draft, "_decoding_cache"):
model_draft._decoding_cache = None
model_draft._decoding_cache = update_graph_cache(
@@ -308,6 +307,7 @@
seqlen_og,
max_length,
tensor_parallel=tensor_parallel,
fused_ft_kernel=fused_ft_kernel,
)
inference_params_draft = model_draft._decoding_cache.inference_params
inference_params_draft.max_sequence_len = max_length
@@ -606,19 +606,23 @@ def allocate_inference_cache(
layers: Union[int, Sequence],
device,
dtype=torch.float16,
fused_ft_kernel=False,
):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
if isinstance(layers, int):
layers = range(layers)
return {
i: (
torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype),
)
if fused_ft_kernel
else torch.empty(kv_cache_shape, device=device, dtype=dtype)
for i in layers
}
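A hedged usage sketch of the updated helper (not from the commit; the parameter names before `layers` are inferred from the cache shapes above). With the default `fused_ft_kernel=False`, each layer now gets a single packed tensor laid out for `flash_attn_with_kvcache`:

```python
import torch
from flash_attn.utils.generation import allocate_inference_cache

cache = allocate_inference_cache(
    2,            # max_batch_size
    256,          # max_seqlen
    16,           # nheads
    64,           # headdim
    layers=4,
    device="cuda",
    dtype=torch.float16,
)
print(cache[0].shape)  # torch.Size([2, 256, 2, 16, 64]) == (batch, max_seqlen, 2, nheads, headdim)

# With fused_ft_kernel=True the previous split layout is kept per layer:
# k_cache (batch, nheads, headdim // 8, max_seqlen, 8) and v_cache (batch, nheads, max_seqlen, headdim).
```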

@@ -651,7 +655,15 @@ class DecodingCGCache:

@torch.inference_mode()
def update_graph_cache(
model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
model,
cache,
batch_size,
seqlen_og,
max_seqlen,
tensor_parallel=1,
dtype=None,
n_warmups=2,
fused_ft_kernel=False,
):
if cache is None:
cache = DecodingCGCache()
@@ -671,7 +683,9 @@ def update_graph_cache(
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
if hasattr(model, "allocate_inference_cache"):
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
inf_cache = model.allocate_inference_cache(
batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
)
else:
headdim = getattr(
model.config,
@@ -686,14 +700,15 @@ def update_graph_cache(
model.config.num_hidden_layers,
device,
dtype,
fused_ft_kernel=fused_ft_kernel,
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_sequence_len=max_seqlen,
max_batch_size=batch_size,
sequence_len_offset=seqlen_og,
key_value_memory_dict=inf_cache,
fused_ft_kernel=True,
fused_ft_kernel=fused_ft_kernel,
lengths_per_sample=lengths_per_sample,
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
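Roughly, the point of threading `lengths_per_sample` through here is that a captured CUDA graph cannot change Python-level integers between replays, so the per-sample cache length has to live in a preallocated tensor that is updated in place. A minimal sketch of the state `update_graph_cache` builds (illustrative values, not the library's actual decode loop; the key/value dict would normally come from `allocate_inference_cache`):

```python
import torch
from flash_attn.utils.generation import InferenceParams

batch_size, seqlen_og, max_length = 2, 17, 32
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device="cuda")

inference_params = InferenceParams(
    max_sequence_len=max_length,
    max_batch_size=batch_size,
    sequence_len_offset=seqlen_og,
    key_value_memory_dict={},            # normally filled by allocate_inference_cache
    fused_ft_kernel=False,
    lengths_per_sample=lengths_per_sample,
)

# Advancing generation means updating this state in place rather than rebuilding it,
# so a replayed graph sees the new offsets through the same tensor storage.
inference_params.lengths_per_sample += 1
inference_params.sequence_len_offset += 1
```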