@@ -31,6 +31,7 @@ def generate(
     prompt: torch.Tensor,
     max_returned_tokens: int,
     *,
+    prompt_chunksize: int = 1,
     temperature: float = 1.0,
     top_k: Optional[int] = None,
     top_p: float = 1.0,
@@ -60,35 +61,60 @@ def generate(
             or https://huyenchip.com/2024/01/16/sampling.html#top_p
         stop_tokens: If specified, stop generating any more token once one of this list is generated.
     """
-    from litgpt.generate.base import generate_fn
-    return generate_fn(
-        include_prompt=False,
-        include_eos=False,
-        model=model,
-        prompt=prompt,
-        max_returned_tokens=max_returned_tokens,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        stop_tokens=stop_tokens
+    from litgpt.generate.base import batched_generate_fn
+
+    return map(
+        lambda lst: lst[0],
+        batched_generate_fn(
+            model=model,
+            prompts=[prompt],
+            max_returned_tokens=max_returned_tokens,
+            prompt_chunksize=prompt_chunksize,
+            sample_args=dict(
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            ),
+            stop_tokens=stop_tokens,
+            include_prompt=False,
+            include_eos=False,
+        )
     )
 
 
-def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def process_prompt(
+    prompt: str,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     prompt = prompt_style.apply(prompt=prompt)
     encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
 
     if max_new_tokens is None:
         max_returned_tokens = model.max_seq_length
     else:
-        first_turn = model.mask_cache is None
         max_returned_tokens = encoded_prompt.size(0) + max_new_tokens
-        if first_turn or max_returned_tokens > model.max_seq_length:
+        msl = model.max_seq_length
+        if max_returned_tokens > msl or model.config.block_size == msl:
             model.max_seq_length = max_returned_tokens
-            model.set_kv_cache(batch_size=1, device=fabric.device)
 
     y: Iterator[torch.Tensor] = generate(
-        model, encoded_prompt, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
+        model=model,
+        prompt=encoded_prompt,
+        max_returned_tokens=max_returned_tokens,
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        stop_tokens=stop_tokens,
     )
     token_generator: Iterator[str] = tokenizer.decode_stream(y, device=fabric.device)
 
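The rewritten generate() above delegates to batched_generate_fn with a one-element prompt list and unwraps each yielded batch entry via map(lambda lst: lst[0], ...). A minimal sketch of that unwrapping pattern, with a toy generator standing in for batched_generate_fn (its yield-one-list-per-step interface is an assumption read off the call site in this diff):

from typing import Iterator, List

def toy_batched_generate(prompts: List[str]) -> Iterator[List[str]]:
    # Stand-in for batched_generate_fn: each step yields one item per prompt.
    for step in range(3):
        yield [f"{p}-tok{step}" for p in prompts]

# Same unwrapping as in generate(): keep only the entry for the single prompt.
single_stream = map(lambda lst: lst[0], toy_batched_generate(["hello"]))
print(list(single_stream))  # ['hello-tok0', 'hello-tok1', 'hello-tok2']
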
@@ -103,8 +129,7 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
 
     t = time.perf_counter() - t0
 
-    for block in model.transformer.h:
-        block.attn.kv_cache.reset_parameters()
+    model.clear_kv_cache()
     fabric.print(
         f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
         f" {tokens_generated} tokens",
@@ -113,7 +138,19 @@ def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature,
     fabric.print()
 
 
-def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens):
+def interact(
+    multiline: bool,
+    model: GPT,
+    tokenizer,
+    prompt_style,
+    fabric,
+    max_new_tokens: int,
+    prompt_chunksize: int,
+    temperature: float,
+    top_k: Optional[int],
+    top_p: float,
+    stop_tokens: Tuple[List[int], ...],
+):
     while True:
         try:
             if not multiline:
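The interact() signature above (like process_prompt() earlier) now annotates stop_tokens as Tuple[List[int], ...], i.e. a tuple of token-id sequences, any one of which ends generation when it appears at the tail of the output. A small self-contained illustration of that data shape; the matching helper is a toy stand-in, not litgpt code:

from typing import List, Tuple

def hit_stop_sequence(generated: List[int], stop_tokens: Tuple[List[int], ...]) -> bool:
    # True once the tail of the generated ids equals any stop sequence.
    return any(seq and generated[-len(seq):] == seq for seq in stop_tokens)

stop_tokens: Tuple[List[int], ...] = ([2], [13, 2])   # e.g. <eos>, newline + <eos>
print(hit_stop_sequence([5, 9, 13, 2], stop_tokens))  # True
print(hit_stop_sequence([5, 9, 13], stop_tokens))     # False
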
@@ -135,14 +172,27 @@ def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, max
         if not prompt or prompt in ("!quit", "!exit"):
             break
 
-        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, max_new_tokens, top_k, top_p, stop_tokens)
+        process_prompt(
+            prompt=prompt,
+            model=model,
+            tokenizer=tokenizer,
+            prompt_style=prompt_style,
+            fabric=fabric,
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            prompt_chunksize=prompt_chunksize,
+            top_k=top_k,
+            top_p=top_p,
+            stop_tokens=stop_tokens,
+        )
 
 
 @torch.inference_mode()
 def main(
     checkpoint_dir: Path,
     *,
     max_new_tokens: int = 50,
+    prompt_chunksize: int = 1,
     top_k: Optional[int] = 50,
     top_p: float = 1.0,
     temperature: float = 0.8,
@@ -158,6 +208,11 @@ def main(
         checkpoint_dir: A local path to a directory containing the model weights or a valid model name.
             You can get a list of valid model names via the `litgpt download list` command line argument.
         max_new_tokens: The number of generation steps to take.
+        prompt_chunksize: If even the shortest prompt is longer than the KV
+            cache, prompts are processed in chunks of this size in the
+            prefill phase. Once the shortest has been processed to the
+            end, we proceed with chunk size 1.
+            Defaults to 1, but larger values are recommended for long prompts.
         top_k: The number of top most probable tokens to consider in the sampling process.
         top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
             In top-p sampling, the next token is sampled from the highest probability tokens
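The new prompt_chunksize docstring above describes chunked prefill: when a prompt is longer than the KV cache, it is fed to the model in fixed-size chunks rather than all at once. A rough, self-contained sketch of that idea only; process_chunk below is a hypothetical stand-in for however the model actually ingests prompt tokens, not the implementation added by this commit:

import torch

def chunked_prefill(prompt: torch.Tensor, prompt_chunksize: int, process_chunk) -> None:
    # Feed the prompt prompt_chunksize tokens at a time instead of in one shot.
    for start in range(0, prompt.numel(), prompt_chunksize):
        process_chunk(prompt[start : start + prompt_chunksize])

# Dummy "model" that just records how many tokens each chunk carried.
seen = []
chunked_prefill(torch.arange(10), prompt_chunksize=4, process_chunk=lambda c: seen.append(c.numel()))
print(seen)  # [4, 4, 2]
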
@@ -252,8 +307,9 @@ def main(
         tokenizer=tokenizer,
         prompt_style=prompt_style,
         fabric=fabric,
-        temperature=temperature,
         max_new_tokens=(None if compile else max_new_tokens),
+        prompt_chunksize=prompt_chunksize,
+        temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         stop_tokens=stop_tokens
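For reference, a hypothetical call of the updated generate() from the top of this diff, showing where the new keyword-only prompt_chunksize slots in. model, encoded_prompt, tokenizer, and fabric are assumed to already exist (set up as in process_prompt() above), and the concrete values are illustrative only:

y = generate(
    model=model,
    prompt=encoded_prompt,
    max_returned_tokens=encoded_prompt.size(0) + 128,
    prompt_chunksize=16,   # new: prefill long prompts 16 tokens at a time
    temperature=0.8,
    top_k=50,
    top_p=1.0,
    stop_tokens=([tokenizer.eos_id],),
)
for text in tokenizer.decode_stream(y, device=fabric.device):
    print(text, end="", flush=True)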