diff --git a/lit_llama/adapter.py b/lit_llama/adapter.py
index f57ee970..8725516b 100644
--- a/lit_llama/adapter.py
+++ b/lit_llama/adapter.py
@@ -87,8 +87,8 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: torch.Tensor,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
         adapter_kv_cache: Optional[KVCache] = None,
@@ -142,7 +142,7 @@ def forward(
 
         # efficient attention using Flash Attention CUDA kernels
         #  ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)  # (B, nh, T, hs)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None)  # (B, nh, T, hs)
 
         # "Adapters are applied to the topmost layers to better tune the language
         # representations with higher-level semantics".
@@ -203,14 +203,14 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: torch.Tensor,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
         adapter_kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]:
         h, new_kv_cache, new_adapter_kv_cache = self.attn(
-            self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache, adapter_kv_cache
+            self.rms_1(x), rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache
         )
         x = x + h
         x = x + self.mlp(self.rms_2(x))
@@ -253,6 +253,7 @@ def forward(
         self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
         B, T = idx.size()
+        use_kv_cache = input_pos is not None
 
         block_size = self.config.block_size
         if max_seq_length is None:
@@ -263,23 +264,26 @@ def forward(
 
         if self.rope_cache is None:
             self.rope_cache = self.build_rope_cache(idx)  # (block_size, head_size / 2, 2)
-        if self.mask_cache is None:
+        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
+        # for the kv-cache support (only during inference), we only create it in that situation
+        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
+        if use_kv_cache and self.mask_cache is None:
             self.mask_cache = self.build_mask_cache(idx)  # (1, 1, block_size, block_size)
 
-        if input_pos is not None:
+        if use_kv_cache:
             rope = self.rope_cache.index_select(0, input_pos)
             mask = self.mask_cache.index_select(2, input_pos)
             mask = mask[:, :, :, :max_seq_length]
         else:
             rope = self.rope_cache[:T]
-            mask = self.mask_cache[:, :, :T, :T]
+            mask = None
 
         # forward the model itself
         x = self.transformer.wte(idx)  # token embeddings of shape (B, T, n_embd)
 
-        if input_pos is None:  # proxy for use_cache=False
+        if not use_kv_cache:
             for block in self.transformer.h:
-                x, *_ = block(x, rope, mask, max_seq_length)
+                x, *_ = block(x, rope, max_seq_length)
         else:
             if not self.kv_caches:
                 head_size = self.config.n_embd // self.config.n_head
@@ -292,7 +296,7 @@ def forward(
                 self.adapter_kv_caches = [None for _ in range(self.config.n_layer)]
             for i, block in enumerate(self.transformer.h):
                 x, self.kv_caches[i], self.adapter_kv_caches[i] = block(
-                    x, rope, mask, max_seq_length, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
+                    x, rope, max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i]
                 )
 
         x = self.transformer.ln_f(x)  # (B, T, n_embd)
diff --git a/lit_llama/model.py b/lit_llama/model.py
index 4d0637ec..0d2dcde4 100644
--- a/lit_llama/model.py
+++ b/lit_llama/model.py
@@ -75,6 +75,7 @@ def forward(
         self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]:
         B, T = idx.size()
+        use_kv_cache = input_pos is not None
 
         block_size = self.config.block_size
         if max_seq_length is None:
@@ -85,23 +86,26 @@ def forward(
 
         if self.rope_cache is None:
             self.rope_cache = self.build_rope_cache(idx)
-        if self.mask_cache is None:
+        # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
+        # for the kv-cache support (only during inference), we only create it in that situation
+        # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
+        if use_kv_cache and self.mask_cache is None:
             self.mask_cache = self.build_mask_cache(idx)
 
-        if input_pos is not None:
+        if use_kv_cache:
             rope = self.rope_cache.index_select(0, input_pos)
             mask = self.mask_cache.index_select(2, input_pos)
             mask = mask[:, :, :, :max_seq_length]
         else:
             rope = self.rope_cache[:T]
-            mask = self.mask_cache[:, :, :T, :T]
+            mask = None
 
         # forward the model itself
         x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
 
-        if input_pos is None:  # proxy for use_cache=False
+        if not use_kv_cache:
             for block in self.transformer.h:
-                x, _ = block(x, rope, mask, max_seq_length)
+                x, _ = block(x, rope, max_seq_length)
         else:
             if not self.kv_caches:
                 head_size = self.config.n_embd // self.config.n_head
@@ -111,7 +115,7 @@ def forward(
                     for _ in range(self.config.n_layer)
                 ]
             for i, block in enumerate(self.transformer.h):
-                x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
+                x, self.kv_caches[i] = block(x, rope, max_seq_length, mask, input_pos, self.kv_caches[i])
 
         x = self.transformer.ln_f(x)
@@ -155,12 +159,12 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: MaskCache,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
-        h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
+        h, new_kv_cache = self.attn(self.rms_1(x), rope, max_seq_length, mask, input_pos, kv_cache)
         x = x + h
         x = x + self.mlp(self.rms_2(x))
         return x, new_kv_cache
@@ -184,8 +188,8 @@ def forward(
         self,
         x: torch.Tensor,
         rope: RoPECache,
-        mask: MaskCache,
         max_seq_length: int,
+        mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
@@ -225,7 +229,7 @@ def forward(
         #  y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
 
         # efficient attention using Flash Attention CUDA kernels
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None)
 
         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
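Note: the change in both files above relies on `F.scaled_dot_product_attention` producing the same result whether it is handed an explicit boolean causal mask or asked to apply causality itself via `is_causal=True`, while only the latter keeps the fused kernels eligible for dispatch. A minimal sketch of that equivalence (shapes and random tensors below are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 4, 8, 16
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# explicit boolean lower-triangular mask, shaped like the one build_mask_cache returns
mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).view(1, 1, T, T)
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# no attn_mask at all: the kernel applies the causal mask internally, which keeps
# the flash / memory-efficient implementations available for selection
y_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

torch.testing.assert_close(y_masked, y_causal)

This is why the training / full-prompt path can pass `mask=None`, while the kv-cache inference path still builds the explicit mask from `build_mask_cache`.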
diff --git a/tests/test_model.py b/tests/test_model.py
index 3abc4843..8508d54d 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -82,7 +82,7 @@ def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
     kv_cache_shape = (batch_size, n_head, block_size, head_size)
     ours_kv_cache = torch.zeros(kv_cache_shape), torch.zeros(kv_cache_shape)
     (llama_block_out, ours_kv_cache) = llama_model.transformer.h[0](
-        llama_embed, llama_rope, llama_mask, seq_len, torch.arange(block_size), ours_kv_cache
+        llama_embed, llama_rope, seq_len, llama_mask, torch.arange(block_size), ours_kv_cache
     )
     ours_k_cache = ours_kv_cache[0].permute(0, 2, 1, 3)
     ours_v_cache = ours_kv_cache[1].permute(0, 2, 1, 3)
@@ -92,7 +92,7 @@ def test_to_orig_llama(lit_llama, orig_llama, kv_cache) -> None:
     orig_llama_block_out = orig_llama_model.layers[0](
         orig_llama_embed, 0, orig_llama_model.freqs_cis[:seq_len], orig_llama_mask
     )
-    (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, llama_mask, seq_len)
+    (llama_block_out, _) = llama_model.transformer.h[0](llama_embed, llama_rope, seq_len, llama_mask)
     assert torch.allclose(orig_llama_block_out, llama_block_out)
 
     expected = orig_llama_model(token_sample, 0)
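For the kv-cache path the mask is still materialized, because a token decoded at some position attends only to the cache entries at or before that position, which `is_causal` cannot express when the query length is 1. A rough, self-contained illustration of how `input_pos` slices the cached mask (the sizes here are made up; only the indexing mirrors the diff):

import torch

block_size, max_seq_length = 8, 8

# (1, 1, block_size, block_size) boolean causal mask, as build_mask_cache would produce
mask_cache = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)).view(1, 1, block_size, block_size)

# decoding the token at position 5: a one-element position index, as a generation loop might pass
input_pos = torch.tensor([5])

mask = mask_cache.index_select(2, input_pos)  # (1, 1, 1, block_size): row 5 of the causal mask
mask = mask[:, :, :, :max_seq_length]         # trim the key dimension to the active cache length
print(mask.squeeze())                         # True for positions 0..5, False for 6..7

Because this selected row is no longer a plain lower triangle from the kernel's point of view, the explicit mask (and therefore the mask cache) is only created when `input_pos` is given, i.e. when `use_kv_cache` is true.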