99"""
1010
1111from dataclasses import dataclass
12- from typing import Any , Dict , Optional , Tuple
12+ from typing import Any , Dict , Optional , Tuple , List
1313
1414import torch
1515import torch .nn as nn
1919from litgpt .model import GPT as BaseModel
2020from litgpt .model import Block as BaseBlock
2121from litgpt .model import CausalSelfAttention as BaseCausalSelfAttention
22+ from litgpt .kvcache .base import KVCache , KeysAndValues , DefaultKeysAndValues
2223
2324
2425@dataclass
@@ -29,20 +30,33 @@ class Config(BaseConfig):
 
 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(self, config: Config) -> None:
+    def __init__(
+        self,
+        config: Config,
+        kv_cache: Optional[List[KVCache]] = None
+    ) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config
 
+        if kv_cache is not None:
+            if len(kv_cache) != config.n_layer:
+                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
+            for kvc in kv_cache:
+                self._check_kv_cache(config, kvc)
+            self._default_kv_cache = False
+        else:
+            kv_cache = [None] * config.n_layer
+            self._default_kv_cache = True
         self.lm_head = nn.Linear(
             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
         )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx)
-                    for block_idx in range(config.n_layer)
+                    Block(config, block_idx, kv_cache=kvc)
+                    for block_idx, kvc in enumerate(kv_cache)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
@@ -62,17 +76,27 @@ def _init_weights(self, module: nn.Module) -> None:
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""
 
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         if block_idx >= config.adapter_start_layer:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
@@ -82,11 +106,16 @@ def __init__(self, config: Config, block_idx: int) -> None:
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
     def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
+        self,
+        q: torch.Tensor,
+        k_and_v: KeysAndValues,
+        mask: Optional[torch.Tensor] = None,
+        is_causal: bool = True,
+        return_scores: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y, scores = super().scaled_dot_product_attention(q, k_and_v, mask, is_causal, return_scores)
         if self.block_idx < self.config.adapter_start_layer:
-            return y
+            return y, scores
 
         aT = self.config.adapter_prompt_length
         if self.adapter_kv_cache is not None:
@@ -110,8 +139,14 @@ def scaled_dot_product_attention(
 
         T = q.size(2)
         amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+        ay, _ = super().scaled_dot_product_attention(
+            q=q,
+            k_and_v=a_k_and_v,
+            mask=amask,
+            is_causal=False,
+        )
+        return y + self.gating_factor * ay, scores
 
     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):
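# Usage sketch (illustrative only, assumption-laden): how the new per-layer
# `kv_cache` argument could be wired up. `MyKVCache` is a hypothetical concrete
# subclass of `litgpt.kvcache.base.KVCache` and is not defined in this commit;
# `GPT` and `Config` are the classes patched above, and `Config.from_name` is
# the constructor inherited from the base litgpt Config.

config = Config.from_name("pythia-160m")
kv_caches = [MyKVCache() for _ in range(config.n_layer)]  # hypothetical cache, one per Block
model = GPT(config, kv_cache=kv_caches)                   # list length must equal config.n_layer

# Omitting `kv_cache` (or passing None) keeps the previous behaviour: every
# Block is created with kv_cache=None and the model records `_default_kv_cache = True`.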