diff --git a/llama/generation.py b/llama/generation.py index 77c87ba17..78bbeda35 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -60,7 +60,8 @@ def generate( min_prompt_size = min([len(t) for t in prompt_tokens]) max_prompt_size = max([len(t) for t in prompt_tokens]) - assert min_prompt_size >= 1 and max_prompt_size < params.max_seq_len + assert min_prompt_size >= 1, f"Prompt size must be >= 1" + assert max_prompt_size < params.max_seq_len, f"Prompt size {max_prompt_size} exceeds max sequence length of {params.max_seq_len}" total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)