1 file changed, 3 insertions(+), 3 deletions(-)

@@ -232,7 +232,7 @@ def batched_generate_fn(
232232 """
233233 batch_size = len (prompts )
234234 assert batch_size > 0 , "No prompts are given"
235- assert prompt_chunksize > 0 , "prompt_chunksize must be positive"
235+ assert prompt_chunksize >= 1 , "prompt_chunksize must be positive"
236236 prompt_size = []
237237 device = prompts [0 ].device
238238 prompt_dtype = prompts [0 ].dtype
@@ -266,7 +266,7 @@ def batched_generate_fn(
     max_prefill_length = model.kv_cache_max_prefill_length()
     if max_prefill_length is None:
         max_prefill_length = min_prompt_size
-    token_pos = min([min_prompt_size, max_prefill_length])
+    token_pos = min(min_prompt_size, max_prefill_length)
     start = 0
     while True:
         inputs = torch.cat(
@@ -275,7 +275,7 @@ def batched_generate_fn(
         )
         # We may need the last time slice of `all_logits` below:
         all_logits = model(inputs, input_pos=start)
-        if token_pos == min_prompt_size:
+        if token_pos >= min_prompt_size:
             break
         start = token_pos
         # Note that `max_tokens_forward` can change during the course of
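The three hunks above touch the chunked prefill loop: the first prefill chunk is bounded by the shortest prompt and the KV cache's maximum prefill length, and the loop must stop once the shortest prompt has been fully prefilled. Below is a minimal standalone sketch of that loop shape, with no model calls. The helper `prefill_positions` and its fixed `prompt_chunksize` step are illustrative assumptions, not the repository's API; the real function derives the step size from the KV cache (`max_tokens_forward`), which is why a defensive `>=` termination test is safer than `==`.

# Minimal sketch of the chunked prefill loop shape (standalone, no model).
# Assumption: the prompt is consumed in fixed chunks of `prompt_chunksize`
# tokens; the real code derives the step from the KV cache instead.
def prefill_positions(min_prompt_size: int, max_prefill_length: int, prompt_chunksize: int):
    """Yield (start, token_pos) slices covering the shortest prompt."""
    assert prompt_chunksize >= 1, "prompt_chunksize must be positive"
    # First chunk is bounded by the cache's maximum prefill length.
    token_pos = min(min_prompt_size, max_prefill_length)
    start = 0
    while True:
        # The real function would call model(prompts[:, start:token_pos], input_pos=start) here.
        yield start, token_pos
        # `>=` (not `==`) terminates even if a later step ever overshoots
        # min_prompt_size, instead of looping past the end of the prompt.
        if token_pos >= min_prompt_size:
            break
        start = token_pos
        token_pos = min(token_pos + prompt_chunksize, min_prompt_size)

# Example: a 10-token shortest prompt with max prefill length 4 and chunk
# size 4 yields the slices (0, 4), (4, 8), (8, 10).
print(list(prefill_positions(10, 4, 4)))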