 import torch.nn as nn
 from typing_extensions import Self

+from litgpt.attention import DefaultKeysAndValues, MultiHeadSelfAttention
 from litgpt.config import Config as BaseConfig
+from litgpt.kvcache.base import KVCache
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
@@ -29,21 +31,28 @@ class Config(BaseConfig):

 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, **mha_kwargs) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config

-        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
+        self.lm_head = nn.Linear(
+            config.n_embd,
+            config.padded_vocab_size,
+            bias=config.lm_head_bias,
+        )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
-        self.mask_cache: Optional[torch.Tensor] = None
+        self.mha = MultiHeadSelfAttention(config, **mha_kwargs)
         self.max_seq_length = self.config.block_size
+        self._start_of_layer_hook = config.start_of_layer_hook
+        # Have dense KV caches been created by `set_kv_cache`?
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
@@ -57,56 +66,79 @@ def _init_weights(self, module: nn.Module) -> None:


 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""

-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        if block_idx >= config.adapter_start_layer:
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            block_idx=block_idx,
+            kv_cache=kv_cache,
+        )
+        self._extend_forward = block_idx >= config.adapter_start_layer
+        if self._extend_forward:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
             # gate for adaption
             self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1))
             # kv cache for inference
             self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

-    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
+    def _transform_output(
+        self,
+        y: torch.Tensor,
+        query: torch.Tensor,
+        mha: MultiHeadSelfAttention,
     ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
-        if self.block_idx < self.config.adapter_start_layer:
-            return y
-
-        aT = self.config.adapter_prompt_length
-        if self.adapter_kv_cache is not None:
-            # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
-            # are the same every call
-            ak, av = self.adapter_kv_cache
-        else:
-            prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
-            aqkv = self.qkv(prefix)
-            q_per_kv = self.config.n_head // self.config.n_query_groups
-            aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
-            aqkv = aqkv.permute(0, 2, 3, 1, 4)
-            _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
-            if self.config.n_query_groups != 1:
-                # for MHA this is a no-op
-                ak = ak.repeat_interleave(q_per_kv, dim=2)
-                av = av.repeat_interleave(q_per_kv, dim=2)
-            ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
-            av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
-            self.adapter_kv_cache = (ak, av)
-
-        T = q.size(2)
-        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        if self._extend_forward:
+            B, T, _ = y.shape
+            y = y.view(B, T, self.config.n_head, self.config.head_size)
+            aT = self.config.adapter_prompt_length
+            if self.adapter_kv_cache is not None:
+                # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av
+                # are the same every call
+                ak, av = self.adapter_kv_cache
+            else:
+                prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
+                aqkv = self.qkv(prefix)
+                q_per_kv = self.config.n_head // self.config.n_query_groups
+                aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
+                aqkv = aqkv.permute(0, 2, 3, 1, 4)
+                _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2)
+                if self.config.n_query_groups != 1:
+                    # for MHA this is a no-op
+                    ak = ak.repeat_interleave(q_per_kv, dim=2)
+                    av = av.repeat_interleave(q_per_kv, dim=2)
+                ak = ak.view(1, -1, aT, self.config.head_size)  # (1, nh_ak, aT, hs)
+                av = av.view(1, -1, aT, self.config.head_size)  # (1, nh_av, aT, hs)
+                self.adapter_kv_cache = (ak, av)
+
+            amask = torch.ones(T, aT, dtype=torch.bool, device=query.device)
+            a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+            ay, _ = mha.scaled_dot_product_attention(
+                query=query,
+                k_and_v=a_k_and_v,
+                mask=amask,
+            )
+            y = (y + self.gating_factor * ay).view(B, T, -1)
+
+        return y

     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):
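
For readers following the diff: the new `_transform_output` hook blends the base attention output `y` with the adapter attention output `ay` through a per-head gate that is initialized to zero, so a freshly added adapter leaves the pretrained model's behaviour unchanged at the start of fine-tuning. Below is a minimal, self-contained sketch of just that gating step in plain PyTorch; the tensor names and shapes mirror the diff, but the snippet itself is illustrative and not part of the litgpt API.

import torch

# Shapes as in the diff: y and ay are (B, T, n_head, head_size); the gate is
# (1, 1, n_head, 1) and starts at zero.
B, T, n_head, head_size = 2, 8, 4, 16
y = torch.randn(B, T, n_head, head_size)    # base attention output
ay = torch.randn(B, T, n_head, head_size)   # attention over the adapter prompt
gating_factor = torch.nn.Parameter(torch.zeros(1, 1, n_head, 1))

# Gated blend, then flatten heads back to (B, T, n_embd), as in _transform_output.
blended = (y + gating_factor * ay).view(B, T, -1)

# With the zero-initialized gate the adapter contributes nothing initially.
assert torch.equal(blended, y.reshape(B, T, -1))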