Allow user to specify length of prompt #13

Status: Open. Wants to merge 29 commits into base: stable.

Commits (29, showing changes from all commits):
01408fb  Added argument for prompt generation of fixed token length (AlexWertheim, May 8, 2023)
e2591a7  Commented out old prompt (AlexWertheim, May 8, 2023)
ef0452e  Cast Tuple to List when creating prompt string (AlexWertheim, May 8, 2023)
b721b83  Fix list cast, correct size (AlexWertheim, May 8, 2023)
8b27dfd  Make `max_gen_len` an exposed parameter (AlexWertheim, May 10, 2023)
10d9c0b  Reintroduced max_prompt_size (AlexWertheim, May 10, 2023)
984fb68  Modified how prompts is generated (AlexWertheim, May 10, 2023)
12a2c53  bucketize_prompt_len (Liyang90, May 11, 2023)
6a7c6f1  update (Liyang90, May 11, 2023)
f58733f  update (Liyang90, May 11, 2023)
c54802b  update (Liyang90, May 12, 2023)
ee3a349  tmp test (Liyang90, May 12, 2023)
7b28736  tmp test (Liyang90, May 12, 2023)
4029650  clean up (Liyang90, May 12, 2023)
74e120c  adjust scale factor (Liyang90, May 12, 2023)
f61383e  Merge pull request #15 from pytorch-tpu/liyanglu/bucketized_prompt_len (Liyang90, May 12, 2023)
94f19e9  turn temperature and top_p into tensors (Liyang90, May 19, 2023)
d0cc999  tmp test (Liyang90, May 19, 2023)
644c88b  tmp update (Liyang90, May 19, 2023)
a50045c  update (Liyang90, May 19, 2023)
a185dda  tmp experiment (Liyang90, May 19, 2023)
799ed7d  update (Liyang90, May 19, 2023)
36f17d8  tmp test (Liyang90, May 19, 2023)
7220c02  update (Liyang90, May 19, 2023)
ddb7a5e  recover tmp changes (Liyang90, May 19, 2023)
8ab9f48  add comment (Liyang90, May 19, 2023)
d2fb888  Merge pull request #22 from pytorch-tpu/liyanglu/tensorfy_temp_top_p (Liyang90, May 19, 2023)
516351f  minor update (Liyang90, May 19, 2023)
9222169  Merge pull request #23 from pytorch-tpu/liyanglu/tensorfy_temp_top_p (Liyang90, May 19, 2023)

example_xla.py  (27 changes: 17 additions & 10 deletions)

@@ -81,12 +81,14 @@ def main(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     rank, world_size = setup_model_parallel()
     if rank > 0:
@@ -96,9 +98,10 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )
 
-    prompts = [
+    prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
+    # prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt
-        "I believe the meaning of life is",
+        # "I believe the meaning of life is",
         # "Simply put, the theory of relativity states that ",
         # "Building a website can be done in 10 simple steps:\n",
         # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
@@ -122,11 +125,11 @@ def main(
 #plush girafe => girafe peluche
 #
 #cheese =>""",
-    ]
+    # ]
     for _ in range(2):
         with torch.no_grad():
             results = generator.generate(
-                prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+                prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
             )
 
     for result in results:
@@ -139,31 +142,35 @@ def _fn(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
-    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)
 
 def mp_main(
     mp: bool,
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,

Review comment on this line:

Can we add a comment to define max_seq_len, prompt_len, and max_gen_len, to clarify them for the user in plain English?

     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,

Review thread on this line:

Collaborator:
Do we know why the default is 6? Also, I believe that prompt_len has something to do with max_batch_size: https://github.com/pytorch-tpu/llama/blob/stable/llama/generation.py#L57. So I'm not sure how this could work with bs=1... On the other hand, I'm not sure this is even the right solution.

Collaborator (author):
I set the default to 6 because I wanted the default to have the same number of input tokens as our previous prompt, "I believe the meaning of life is". But I miscounted (that prompt has 7 words), and there isn't necessarily a 1-to-1 mapping between words and tokens, so it's still wrong regardless. We can change the default if the current one is not right; any suggestions for alternatives?

If I'm reading the code right, max_batch_size relates to the total number of prompts, not the number of tokens in each prompt. The goal of prompt_len is to let the user specify a variable number of input tokens in a single prompt. (It is reasonable to point out that this does not currently support multiple prompts.) I'm not sure I understand your comments about bs = 1, or about whether this is the right solution (the right solution to what, exactly?). Could you please clarify?

Collaborator:
Okay, now I get it. But from the discussion in gchat, do you still need this approach?

Collaborator (author):
If you mean our discussion about max_seq_len and max_gen_len, that is something quite separate, right? This PR allows the user to modify the length of the input prompt; max_seq_len controls the size allocated for the output (and, in our repo, the total number of tokens generated), and max_gen_len controls the number of tokens displayed. I think we still need this so the user can modify the input. Please let me know if you had some other discussion in mind.
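
For concreteness, a minimal sketch of how the three lengths interact under this PR. The helper name token_budget and the example numbers are invented for illustration; the actual budgeting in llama/generation.py below is total_len = min(params.max_seq_len, max_gen_len + max_prompt_size).

```python
# Illustrative only: mirrors the token budgeting in llama/generation.py.
def token_budget(max_seq_len: int, prompt_len: int, max_gen_len: int) -> int:
    """Total number of token positions the decode loop will fill.

    prompt_len tokens come from the fixed-length input prompt, up to
    max_gen_len tokens are generated after it, and max_seq_len caps the
    sum, since that is the size allocated for the token buffer.
    """
    assert 1 <= prompt_len < max_seq_len
    return min(max_seq_len, prompt_len + max_gen_len)

print(token_budget(max_seq_len=2048, prompt_len=6, max_gen_len=256))     # 262
print(token_budget(max_seq_len=2048, prompt_len=1900, max_gen_len=256))  # 2048 (capped)
```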

+    max_gen_len: int = 256,
 ):
     if mp:
-        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads))
+        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len))
     else:
-        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)
 
 
 if __name__ == "__main__":
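
A note on the new prompts line above: generator.tokenizer.decode([8]*prompt_len) builds a synthetic prompt by decoding prompt_len copies of a single token id, so the prompt's token count is controlled by prompt_len rather than by hand-counting words. A toy sketch of the idea follows; the ToyTokenizer class is invented for illustration, while the real code uses llama's SentencePiece tokenizer, where decode followed by encode may not round-trip exactly and a BOS token is prepended.

```python
# Toy stand-in for llama's tokenizer, illustrating why decoding N copies of a
# fixed token id gives a prompt that re-encodes to (roughly) N tokens.
from typing import List

class ToyTokenizer:
    vocab = ["<pad>", "I", "believe", "the", "meaning", "of", "life", "is", "token"]

    def encode(self, text: str) -> List[int]:
        return [self.vocab.index(w) for w in text.split()]

    def decode(self, ids: List[int]) -> str:
        return " ".join(self.vocab[i] for i in ids)

prompt_len = 6
tok = ToyTokenizer()
prompt = tok.decode([8] * prompt_len)   # "token token token token token token"
print(len(tok.encode(prompt)))          # 6, independent of the prompt's wording
```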

llama/generation.py  (65 changes: 49 additions & 16 deletions)

@@ -20,11 +20,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
             backend="torchxla_trace_once", fullgraph=True)
 
     def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                            input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p):
+                            input_pos_tensor, output_pos_tensor, cache_kvs,
+                            temperature_tensor, top_p_tensor, with_temp):
         logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs)
-        if temperature > 0:
-            probs = torch.softmax(logits / temperature, dim=-1)
-            next_token = sample_top_p(probs, top_p)
+        if with_temp:
+            probs = torch.softmax(logits / temperature_tensor, dim=-1)
+            next_token = sample_top_p(probs, top_p_tensor)
         else:
             next_token = torch.argmax(logits, dim=-1)
         next_token = next_token.reshape(-1)

@@ -58,33 +59,65 @@ def generate(
 
         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
 
-        total_len = params.max_seq_len
+        min_prompt_size = min([len(t) for t in prompt_tokens])
+        max_prompt_size = max([len(t) for t in prompt_tokens])
+        assert min_prompt_size >= 1 and max_prompt_size < params.max_seq_len
 
-        tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long()
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+        tokens = torch.full((params.max_batch_size, params.max_seq_len), self.tokenizer.pad_id).long()
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         device = xm.xla_device()
         tokens = tokens.to(device)
         input_text_mask = tokens != self.tokenizer.pad_id
 
-        start_pos = 1
-        cur_pos_tensor = torch.tensor(start_pos).to(device)
-        input_pos_tensor = torch.arange(0, start_pos).to(device)
-        output_pos_tensor = cur_pos_tensor - 1
-        input_tokens = tokens.index_select(1, input_pos_tensor)
+        # Passing tensors instead of floats into self._generate_one_token_fn,
+        # so that different values would not trigger compilations of new graphs
+        temperature_tensor = torch.tensor(float(temperature)).to(device)
+        top_p_tensor = torch.tensor(float(top_p)).to(device)
+        with_temp = temperature > 0
 
         cache_kvs = self.model.cache_kvs
-        xm.mark_step(wait=True)
+        xm.mark_step()
 
         decoding_start_time = time.time()
-        for _ in range(start_pos, total_len):
+        prev_pos = 0
+        scale_factor = 8
+        while prev_pos < min_prompt_size:
+            section_len = 1
+            while prev_pos + section_len * scale_factor <= min_prompt_size:
+                section_len *= scale_factor
+            cur_pos = prev_pos + section_len
+            print(f"Processing prompt pos [{prev_pos}, {cur_pos}), section length {section_len}")
+            cur_pos_tensor = torch.tensor(cur_pos).to(device)
+            input_pos_tensor = torch.arange(prev_pos, cur_pos).to(device)
+            output_pos_tensor = cur_pos_tensor - 1
+            input_tokens = tokens.index_select(1, input_pos_tensor)
+            xm.mark_step()
+
+            tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
+                = self._generate_one_token_fn(
+                    tokens, input_tokens, input_text_mask, cur_pos_tensor,
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
+                )
+            xm.mark_step()
+
+            prev_pos = cur_pos
+
+        assert cur_pos_tensor.item() == prev_pos + 1
+        for _ in range(prev_pos + 1, total_len):
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()
         self.model.cache_kvs = cache_kvs
-        print(f"Decoded in {time.time() - decoding_start_time:.5f} seconds")
+        print(f"Processed prompts with {min_prompt_size} to {max_prompt_size} tokens, and generated {total_len - max_prompt_size} tokens")
+        print(f"Totally decoded {total_len - 1} tokens in {time.time() - decoding_start_time:.5f} seconds")
 
         decoded = []
         for i, t in enumerate(tokens.tolist()):

@@ -109,7 +142,7 @@ def generate(
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
+    mask = (probs_sum - probs_sort) > p
     probs_sort = torch.where(mask, 0.0, probs_sort)
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
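
To make the bucketized prompt processing in generate above easier to follow, here is a standalone sketch of just the sectioning arithmetic. The function name prompt_sections is invented; the real loop also runs _generate_one_token_fn on each [prev_pos, cur_pos) slice of the prompt.

```python
# Reproduces only the index arithmetic of the bucketized prompt loop above.
def prompt_sections(min_prompt_size: int, scale_factor: int = 8):
    sections = []
    prev_pos = 0
    while prev_pos < min_prompt_size:
        # Largest power of scale_factor that still fits before min_prompt_size.
        section_len = 1
        while prev_pos + section_len * scale_factor <= min_prompt_size:
            section_len *= scale_factor
        cur_pos = prev_pos + section_len
        sections.append((prev_pos, cur_pos, section_len))
        prev_pos = cur_pos
    return sections

print(prompt_sections(6))
# [(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 4, 1), (4, 5, 1), (5, 6, 1)]
print(prompt_sections(100))
# [(0, 64, 64), (64, 72, 8), (72, 80, 8), (80, 88, 8), (88, 96, 8),
#  (96, 97, 1), (97, 98, 1), (98, 99, 1), (99, 100, 1)]
```

Because every section length is a power of scale_factor, the compiled graph only ever sees a small set of distinct input widths (1, 8, 64, ...), which appears to be the point of the bucketization: it bounds recompilation while still consuming long prompts in large chunks.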
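Relatedly, the reason temperature and top_p are passed as 0-d tensors (with with_temp kept as a plain bool) can be sketched as below. This is illustrative plain PyTorch, not code from the PR: with a tracing backend such as torch_xla, a Python float tends to be baked into the traced graph as a constant, so every new value would trigger a fresh compilation, whereas a tensor is a graph input whose value can change freely.

```python
# Illustrative sketch of the "pass scalars as tensors" pattern used above.
import torch

def scale_logits(logits: torch.Tensor, temperature_tensor: torch.Tensor,
                 with_temp: bool) -> torch.Tensor:
    # with_temp stays a Python bool: it changes control flow, so at most it
    # selects between two traced graphs rather than one per temperature value.
    if with_temp:
        return torch.softmax(logits / temperature_tensor, dim=-1)
    return logits

logits = torch.randn(4, 32000)
for temp in (0.8, 0.6, 1.0):
    # Under a tracing backend, the same graph could be reused for every
    # temperature here, because only the tensor's value changes.
    probs = scale_logits(logits, torch.tensor(temp), with_temp=temp > 0)
```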