Allow user to specify length of prompt #13
base: stable
Changes from all commits
@@ -81,12 +81,14 @@ def main(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     rank, world_size = setup_model_parallel()
     if rank > 0:
@@ -96,9 +98,10 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )

-    prompts = [
+    prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
+    # prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt
-        "I believe the meaning of life is",
+        # "I believe the meaning of life is",
         # "Simply put, the theory of relativity states that ",
         # "Building a website can be done in 10 simple steps:\n",
         # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
@@ -122,11 +125,11 @@ def main(
         #plush girafe => girafe peluche
         #
         #cheese =>""",
-    ]
+    # ]
     for _ in range(2):
         with torch.no_grad():
             results = generator.generate(
-                prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+                prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
             )

         for result in results:
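For readers following the prompts change above: the new list comprehension swaps the hand-written example prompts for synthetic prompts built to a fixed token length. Below is a minimal standalone sketch of the idea; it assumes the repo's SentencePiece-backed Tokenizer class in llama/tokenizer.py and a local tokenizer.model path, both of which are assumptions for illustration rather than part of this diff.

# Sketch: build max_batch_size identical prompts, each decoded from prompt_len
# copies of an arbitrary token id (8, as in the diff), so every prompt in the
# batch starts from a known, fixed number of tokens.
from llama.tokenizer import Tokenizer  # module path assumed from the repo layout

prompt_len = 6
max_batch_size = 32
tok = Tokenizer(model_path="tokenizer.model")  # assumption: local tokenizer file

prompts = [tok.decode([8] * prompt_len) for _ in range(max_batch_size)]
# Caveat: re-encoding the decoded string is not guaranteed to round-trip to
# exactly prompt_len tokens for every token id, so this approximates
# "a prompt of N tokens" rather than guaranteeing it.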
@@ -139,31 +142,35 @@ def _fn(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
-    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)

 def mp_main(
     mp: bool,
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     if mp:
-        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads))
+        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len))
     else:
-        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)


 if __name__ == "__main__":

Review thread on the new prompt_len: int = 6, default in mp_main:

Reviewer: Do we know why the default is 6? Also, I believe that prompt_len has something to do with max_batch_size: https://github.com/pytorch-tpu/llama/blob/stable/llama/generation.py#L57. So I'm not sure how this could work with bs=1... On the other hand, I'm not sure if this is even the right solution.

Author: I set the default to 6 because I wanted the default to be the same number of input tokens as our previous prompt "I believe the meaning of life is", only I miscounted (that has 7 words), and there's not a 1-to-1 mapping between words and tokens necessarily, so it's still wrong regardless. We can change the default if the current one is not right - any suggestions on alternatives? If I'm reading the code right,

Reviewer: Okay, now I get it. But from the discussion in gchat, do you still need this approach?

Author: If you mean our discussion about
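As a side note on the word-count vs. token-count point raised in the thread above, the actual token length of the old default prompt can be checked directly. This is a hedged sketch, not part of the PR; it assumes the Tokenizer class from llama/tokenizer.py with an encode(s, bos, eos) signature and a local tokenizer.model file.

# Count how many tokens the previous default prompt actually encodes to;
# that count is the natural candidate for a prompt_len default.
from llama.tokenizer import Tokenizer  # module path assumed from the repo layout

tok = Tokenizer(model_path="tokenizer.model")  # assumption: local tokenizer file
ids = tok.encode("I believe the meaning of life is", bos=True, eos=False)
print(len(ids))  # token count; the 7-word count is only a rough proxy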
Reviewer: Can we add a comment to define max_seq_len, prompt_len, and max_gen_len, to clarify for the user in plain English?
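One possible way to address this request is sketched below. The wording is a suggestion based on how these parameters are typically used in the llama example code, not wording from the PR itself.

# Suggested plain-English definitions (proposal only):
max_seq_len: int = 2048,   # upper bound on total sequence length (prompt tokens + generated tokens)
prompt_len: int = 6,       # number of tokens in each synthetic prompt built from the tokenizer
max_gen_len: int = 256,    # maximum number of new tokens to generate for each prompt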