From a86442f0f35c135c8ed8d7af760b1bd6a832ec07 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 7 Sep 2023 08:24:43 -0700
Subject: [PATCH] [Gen] Use flash_attn_with_kvcache in generation

---
 flash_attn/layers/rotary.py    |   3 +-
 flash_attn/modules/mha.py      | 112 ++++++++++++++++++++++++++++-----
 flash_attn/utils/generation.py |  33 +++++++---
 tests/models/test_baichuan.py  |  16 ++---
 tests/models/test_gpt.py       |   6 +-
 5 files changed, 134 insertions(+), 36 deletions(-)

diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index e081770a2..71259d020 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -146,7 +146,8 @@ def forward(
             # Call 1 kernel instead of 2 kernels
             # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
             apply_rotary(
                 qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
             )
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 4cb6eaf85..3d6b70790 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -15,10 +15,12 @@
         flash_attn_qkvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
+        flash_attn_with_kvcache,
     )
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
+    flash_attn_with_kvcache = None
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
@@ -556,6 +558,35 @@ def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
             else False,
         )
 
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            return flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+
     def forward(
         self,
         x,
@@ -605,10 +636,19 @@ def forward(
             if self.use_flash_attn
             else {"key_padding_mask": key_padding_mask, **kwargs}
         )
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
+        batch, seqlen = x.shape[:2]
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
@@ -619,7 +659,8 @@ def forward(
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                 ).contiguous()
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            # qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            qkv = qkv.reshape(batch, seqlen, 3, self.num_heads, self.head_dim)
             if (
                 inference_params is None
                 or inference_params.sequence_len_offset == 0
@@ -635,9 +676,9 @@ def forward(
                     else:
                         context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                 else:
-                    q = qkv[:, :, 0]
-                    kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -659,8 +700,10 @@ def forward(
                     qkv, x = self.Wqkv(x)
                 q = qkv[..., : self.num_heads * self.head_dim]
                 kv = qkv[..., self.num_heads * self.head_dim :]
-            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            # q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            q = q.reshape(batch, seqlen, -1, self.head_dim)
+            # kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            kv = kv.reshape(batch, seqlen, 2, -1, self.head_dim)
             if self.dwconv:
                 q = rearrange(
                     self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -685,11 +728,11 @@ def forward(
                             self.inner_cross_attn, q, kv, **kwargs
                         )
                 else:
-                    kv = self._update_kv_cache(kv, inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        # out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        out = self.out_proj(context.reshape(batch, seqlen, -1))
         return out if not self.return_residual else (out, x)
 
 
@@ -846,6 +889,36 @@ def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
             else False,
         )
 
+    def _update_kvcache_attention(self, q, kv, inference_params):
+        """Write kv to inference_params, then do attention"""
+        if (
+            inference_params.sequence_len_offset == 0
+            or flash_attn_with_kvcache is None
+            or not self.use_flash_attn
+        ):
+            # TODO: this only uses sequence_len_offset and not lengths_per_sample.
+            kv = self._update_kv_cache(kv, inference_params)
+            return self.inner_cross_attn(q, kv)
+        else:
+            batch = q.shape[0]
+            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
+            cache_seqlens = (
+                inference_params.lengths_per_sample[:batch]
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+            context = flash_attn_with_kvcache(
+                q,
+                kv_cache[:, :, 0],
+                kv_cache[:, :, 1],
+                kv[:, :, 0],
+                kv[:, :, 1],
+                cache_seqlens=cache_seqlens,
+                softmax_scale=self.inner_cross_attn.softmax_scale,
+                causal=self.inner_cross_attn.causal,
+            )
+            return context
+
     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
         """
         Arguments:
@@ -857,7 +930,15 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
         qkv = self.Wqkv(x)
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
-        seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        seqlen_offset = (
+            0
+            if inference_params is None
+            else (
+                inference_params.lengths_per_sample
+                if inference_params.lengths_per_sample is not None
+                else inference_params.sequence_len_offset
+            )
+        )
         rotary_max_seqlen = (
             inference_params.max_sequence_len if inference_params is not None else None
         )
@@ -878,9 +959,9 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
                     else:
                         context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                 else:
-                    q = qkv[:, :, 0]
-                    kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(
+                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
+                    )
             else:
                 context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
@@ -912,8 +993,7 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
                             self.inner_cross_attn, q, kv, **kwargs
                         )
                 else:
-                    kv = self._update_kv_cache(kv, inference_params)
-                    context = self.inner_cross_attn(q, kv)
+                    context = self._update_kvcache_attention(q, kv, inference_params)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
         context = rearrange(context, "b s h d -> b s (h d)")
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index a89c4146c..0cba9abf6 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -118,7 +118,6 @@ def decode(
     batch_size, seqlen_og = input_ids.shape
     teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model, "_decoding_cache"):
             model._decoding_cache = None
         model._decoding_cache = update_graph_cache(
@@ -128,11 +127,13 @@ def decode(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params = model._decoding_cache.inference_params
         inference_params.max_sequence_len = max_length
         inference_params.max_batch_size = batch_size
         inference_params.sequence_len_offset = 0
+        inference_params.lengths_per_sample.zero_()
     else:
         inference_params = InferenceParams(
             max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
@@ -167,7 +168,8 @@ def sample_tokens(logits, inference_params):
             token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
             token = teacher_outputs[:, inference_params.sequence_len_offset]
-        return rearrange(token, "b -> b 1")
+        # return rearrange(token, "b -> b 1")
+        return token.unsqueeze(1)
 
     def should_stop(current_token, inference_params):
         if inference_params.sequence_len_offset == 0:
@@ -197,9 +199,7 @@ def should_stop(current_token, inference_params):
         torch.cuda.synchronize()
         print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
-    return output_cls(
-        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
-    )
+    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
 
 
 def sample_speculative(logits, logits_draft, tokens_draft, top_k=1, top_p=0.0, temperature=1.0):
@@ -298,7 +298,6 @@ def decode_speculative(
     assert batch_size == 1, "Speculative decoding implementation only supports batch_size=1"
     assert eos_token_id is None, "Speculative decoding implementation doesn't support eos_token_id"
     if cg:
-        assert fused_ft_kernel
         if not hasattr(model_draft, "_decoding_cache"):
             model_draft._decoding_cache = None
         model_draft._decoding_cache = update_graph_cache(
@@ -308,6 +307,7 @@ def decode_speculative(
             seqlen_og,
             max_length,
             tensor_parallel=tensor_parallel,
+            fused_ft_kernel=fused_ft_kernel,
         )
         inference_params_draft = model_draft._decoding_cache.inference_params
         inference_params_draft.max_sequence_len = max_length
@@ -606,12 +606,14 @@ def allocate_inference_cache(
     layers: Union[int, Sequence],
     device,
     dtype=torch.float16,
+    fused_ft_kernel=False,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0
     k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
     v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
+    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
     return {
@@ -619,6 +621,8 @@ def allocate_inference_cache(
             torch.empty(k_cache_shape, device=device, dtype=dtype),
             torch.empty(v_cache_shape, device=device, dtype=dtype),
         )
+        if fused_ft_kernel
+        else torch.empty(kv_cache_shape, device=device, dtype=dtype)
         for i in layers
     }
 
@@ -651,7 +655,15 @@ class DecodingCGCache:
 
 @torch.inference_mode()
 def update_graph_cache(
-    model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, dtype=None, n_warmups=2
+    model,
+    cache,
+    batch_size,
+    seqlen_og,
+    max_seqlen,
+    tensor_parallel=1,
+    dtype=None,
+    n_warmups=2,
+    fused_ft_kernel=False,
 ):
     if cache is None:
         cache = DecodingCGCache()
@@ -671,7 +683,9 @@ def update_graph_cache(
         cache.device, cache.dtype = device, dtype
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        if hasattr(model, "allocate_inference_cache"):
-            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+            inf_cache = model.allocate_inference_cache(
+                batch_size, max_seqlen, dtype, fused_ft_kernel=fused_ft_kernel
+            )
         else:
             headdim = getattr(
                 model.config,
@@ -686,6 +700,7 @@ def update_graph_cache(
                 model.config.num_hidden_layers,
                 device,
                 dtype,
+                fused_ft_kernel=fused_ft_kernel,
            )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
@@ -693,7 +708,7 @@ def update_graph_cache(
             max_batch_size=batch_size,
             sequence_len_offset=seqlen_og,
             key_value_memory_dict=inf_cache,
-            fused_ft_kernel=True,
+            fused_ft_kernel=fused_ft_kernel,
             lengths_per_sample=lengths_per_sample,
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py
index 6658c1c7b..3818f30bf 100644
--- a/tests/models/test_baichuan.py
+++ b/tests/models/test_baichuan.py
@@ -217,8 +217,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
+@pytest.mark.parametrize("fused_ft_kernel", [False, True])
 @pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
-def test_baichuan_generation(model_name):
+def test_baichuan_generation(model_name, fused_ft_kernel):
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
@@ -276,6 +277,7 @@ def test_baichuan_generation(model_name):
     model.load_state_dict(pretrained_state_dict)
     model.eval()
 
+    model(input_ids)  # Warm up
     print("Without CUDA graph")
     torch.cuda.synchronize()
     start = time.time()
@@ -283,7 +285,7 @@ def test_baichuan_generation(model_name):
         input_ids=input_ids,
         max_length=max_length,
         eos_token_id=eos_token_id,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
         enable_timing=True,
@@ -295,7 +297,7 @@ def test_baichuan_generation(model_name):
     # Capture graph outside the timing loop
     batch_size, seqlen_og = input_ids.shape
     model._decoding_cache = update_graph_cache(
-        model, None, batch_size, seqlen_og, max_length
+        model, None, batch_size, seqlen_og, max_length, fused_ft_kernel=fused_ft_kernel
     )
     print("With CUDA graph")
     torch.cuda.synchronize()
@@ -303,7 +305,7 @@ def test_baichuan_generation(model_name):
     out_cg = model.generate(
         input_ids=input_ids,
         max_length=max_length,
-        fused_ft_kernel=True,
+        fused_ft_kernel=fused_ft_kernel,
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
@@ -346,7 +348,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    config.use_flash_attn = False
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
     config.fused_dropout_add_ln = False
@@ -393,7 +395,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
@@ -411,7 +412,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
         max_length=max_length,
         tensor_parallel=world_size,
         vocab_size=config.vocab_size,
-        fused_ft_kernel=True,
         cg=True,
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
@@ -458,6 +458,6 @@ def test_baichuan_parallel_generation(model_name, world_size):
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f"HF fp16 logits max diff: {hf_error}")
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
-    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
     assert torch.equal(logits_cg, logits)
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index a9a827df8..e72a9b93b 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -135,7 +135,7 @@ def test_gpt2_optimized(model_name):
 
 
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
-# @pytest.mark.parametrize('fused_ft_kernel', [False])
+# @pytest.mark.parametrize('fused_ft_kernel', [True])
 @pytest.mark.parametrize("optimized", [False, True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize("rotary", [False, True])
@@ -209,7 +209,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
-    if fused_ft_kernel or config.use_flash_attn:
+    if fused_ft_kernel or getattr(config, "use_flash_attn", False):
         out_cg = model.generate(
             input_ids=input_ids,
             max_length=max_length,
@@ -220,6 +220,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
             enable_timing=True,
         )
         print(out_cg.sequences)
+        assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
 
     if not rotary:
         out_hf = model_hf.generate(
@@ -282,6 +283,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
 @pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
 # @pytest.mark.parametrize('rotary', [None])
 @pytest.mark.parametrize("fused_ft_kernel", [False, True])
+# @pytest.mark.parametrize("fused_ft_kernel", [False])
 @pytest.mark.parametrize("model_name", ["gpt2"])
 def test_gpt2_generation_cg(model_name, fused_ft_kernel, rotary, seqlen, maxlen):
     """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
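
Note (not part of the patch): the snippet below is a minimal, stand-alone sketch of the decode-time path this change enables, using the same flash_attn_with_kvcache call shape as MHA._update_kvcache_attention above. The tensor sizes are made-up illustrative values, and it assumes a CUDA device plus a flash-attn build that provides flash_attn_with_kvcache.

    # Minimal sketch: one single-token decode step against a preallocated KV cache.
    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, headdim, max_seqlen = 2, 16, 64, 256  # illustrative sizes
    device, dtype = "cuda", torch.float16

    # Cache laid out like allocate_inference_cache's kv_cache_shape:
    # (batch, max_seqlen, 2, nheads, headdim), with k/v taken as slices along dim 2,
    # the same way _update_kvcache_attention indexes key_value_memory_dict.
    # In real use the prompt's K/V would already have been written into it.
    kv_cache = torch.zeros(batch, max_seqlen, 2, nheads, headdim, device=device, dtype=dtype)

    # Number of tokens already present in the cache for each sequence
    # (what lengths_per_sample / sequence_len_offset provide during generation).
    cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device=device)

    # Current step: one new query/key/value per sequence.
    q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    k_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
    v_new = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)

    # Appends k_new/v_new into the cache at position cache_seqlens and attends over
    # the cached prefix plus the new token in a single kernel call.
    out = flash_attn_with_kvcache(
        q,
        kv_cache[:, :, 0],
        kv_cache[:, :, 1],
        k_new,
        v_new,
        cache_seqlens=cache_seqlens,
        causal=True,
    )
    # out: (batch, 1, nheads, headdim)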