diff --git a/example_xla.py b/example_xla.py
index dd111e58b..4cbb4f5d4 100644
--- a/example_xla.py
+++ b/example_xla.py
@@ -81,12 +81,14 @@ def main(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     rank, world_size = setup_model_parallel()
     if rank > 0:
@@ -96,9 +98,10 @@ def main(
         ckpt_dir, tokenizer_path, rank, world_size, max_seq_len, max_batch_size, dim, n_layers, n_heads
     )

-    prompts = [
+    prompts = [generator.tokenizer.decode([8]*prompt_len) for _ in range(max_batch_size)]
+    # prompts = [
         # For these prompts, the expected answer is the natural continuation of the prompt
-        "I believe the meaning of life is",
+        # "I believe the meaning of life is",
         # "Simply put, the theory of relativity states that ",
         # "Building a website can be done in 10 simple steps:\n",
         # Few shot prompts: https://huggingface.co/blog/few-shot-learning-gpt-neo-and-inference-api
@@ -122,11 +125,11 @@ def main(
 #plush girafe => girafe peluche
 #
 #cheese =>""",
-    ]
+    # ]
     for _ in range(2):
         with torch.no_grad():
             results = generator.generate(
-                prompts, max_gen_len=256, temperature=temperature, top_p=top_p
+                prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
             )

         for result in results:
@@ -139,31 +142,35 @@ def _fn(
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
-    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+    main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)

 def mp_main(
     mp: bool,
     tokenizer_path: str,
     temperature: float = 0.8,
     top_p: float = 0.95,
-    max_seq_len: int = 512,
+    max_seq_len: int = 2048,
     max_batch_size: int = 32,
     ckpt_dir: str = '',
     dim: int = 4096,
     n_layers: int = 32,
     n_heads: int = 32,
+    prompt_len: int = 6,
+    max_gen_len: int = 256,
 ):
     if mp:
-        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads))
+        xmp.spawn(_fn, args=(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len))
     else:
-        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads)
+        main(tokenizer_path, temperature, top_p, max_seq_len, max_batch_size, ckpt_dir, dim, n_layers, n_heads, prompt_len, max_gen_len)


 if __name__ == "__main__":
diff --git a/llama/generation.py b/llama/generation.py
index b0ea81e2d..925f82223 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -20,11 +20,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
                                                      backend="torchxla_trace_once", fullgraph=True)

     def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                            input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p):
+                            input_pos_tensor, output_pos_tensor, cache_kvs,
+                            temperature_tensor, top_p_tensor, with_temp):
         logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs)
-        if temperature > 0:
-            probs = torch.softmax(logits / temperature, dim=-1)
-            next_token = sample_top_p(probs, top_p)
+        if with_temp:
+            probs = torch.softmax(logits / temperature_tensor, dim=-1)
+            next_token = sample_top_p(probs, top_p_tensor)
         else:
             next_token = torch.argmax(logits, dim=-1)
         next_token = next_token.reshape(-1)
@@ -58,33 +59,65 @@ def generate(

         prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

-        total_len = params.max_seq_len
+        min_prompt_size = min([len(t) for t in prompt_tokens])
+        max_prompt_size = max([len(t) for t in prompt_tokens])
+        assert min_prompt_size >= 1 and max_prompt_size < params.max_seq_len

-        tokens = torch.full((params.max_batch_size, total_len), self.tokenizer.pad_id).long()
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+        tokens = torch.full((params.max_batch_size, params.max_seq_len), self.tokenizer.pad_id).long()
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t).long()
         device = xm.xla_device()
         tokens = tokens.to(device)
         input_text_mask = tokens != self.tokenizer.pad_id

-        start_pos = 1
-        cur_pos_tensor = torch.tensor(start_pos).to(device)
-        input_pos_tensor = torch.arange(0, start_pos).to(device)
-        output_pos_tensor = cur_pos_tensor - 1
-        input_tokens = tokens.index_select(1, input_pos_tensor)
+        # Passing tensors instead of floats into self._generate_one_token_fn,
+        # so that different values would not trigger compilations of new graphs
+        temperature_tensor = torch.tensor(float(temperature)).to(device)
+        top_p_tensor = torch.tensor(float(top_p)).to(device)
+        with_temp = temperature > 0
+
         cache_kvs = self.model.cache_kvs
-        xm.mark_step(wait=True)
+        xm.mark_step()

         decoding_start_time = time.time()
-        for _ in range(start_pos, total_len):
+        prev_pos = 0
+        scale_factor = 8
+        while prev_pos < min_prompt_size:
+            section_len = 1
+            while prev_pos + section_len * scale_factor <= min_prompt_size:
+                section_len *= scale_factor
+            cur_pos = prev_pos + section_len
+            print(f"Processing prompt pos [{prev_pos}, {cur_pos}), section length {section_len}")
+            cur_pos_tensor = torch.tensor(cur_pos).to(device)
+            input_pos_tensor = torch.arange(prev_pos, cur_pos).to(device)
+            output_pos_tensor = cur_pos_tensor - 1
+            input_tokens = tokens.index_select(1, input_pos_tensor)
+            xm.mark_step()
+
+            tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
+                = self._generate_one_token_fn(
+                    tokens, input_tokens, input_text_mask, cur_pos_tensor,
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
+                )
+            xm.mark_step()
+
+            prev_pos = cur_pos
+
+        assert cur_pos_tensor.item() == prev_pos + 1
+        for _ in range(prev_pos + 1, total_len):
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()
         self.model.cache_kvs = cache_kvs
-        print(f"Decoded in {time.time() - decoding_start_time:.5f} seconds")
+        print(f"Processed prompts with {min_prompt_size} to {max_prompt_size} tokens, and generated {total_len - max_prompt_size} tokens")
+        print(f"Totally decoded {total_len - 1} tokens in {time.time() - decoding_start_time:.5f} seconds")

         decoded = []
         for i, t in enumerate(tokens.tolist()):
@@ -109,7 +142,7 @@ def generate(
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
+    mask = (probs_sum - probs_sort) > p
     probs_sort = torch.where(mask, 0.0, probs_sort)
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
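The comment added in generate() ("Passing tensors instead of floats into self._generate_one_token_fn, so that different values would not trigger compilations of new graphs") is the reasoning behind the new temperature_tensor/top_p_tensor/with_temp arguments. A minimal standalone sketch of that idea, not part of this patch: a compiled function can specialize on the value of a Python scalar argument, so sweeping temperature as a float may force a retrace per value, while a 0-d tensor argument pins only shape and dtype. The names _sample and compiled_sample below are illustrative only.

import torch

def _sample(logits: torch.Tensor, temperature_tensor: torch.Tensor) -> torch.Tensor:
    # temperature arrives as a 0-d tensor, so its value is not baked into the traced graph
    probs = torch.softmax(logits / temperature_tensor, dim=-1)
    return torch.argmax(probs, dim=-1)

compiled_sample = torch.compile(_sample, fullgraph=True)

logits = torch.randn(4, 32000)
for temp in (0.6, 0.8, 1.0):
    # each call can reuse the same compiled graph; passing temp as a raw float
    # could instead be treated as a trace-time constant and recompile per value
    next_token = compiled_sample(logits, torch.tensor(temp))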
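The new prompt-processing loop consumes the prompt in sections whose length is always a power of scale_factor, so input_tokens only ever takes a few distinct widths (1, 8, 64, ...) before the one-token-at-a-time decode loop takes over, and XLA compiles correspondingly few graphs. A small illustration of that schedule; the prompt_sections helper is hypothetical and only mirrors the section-length choice in the patched generate().

def prompt_sections(min_prompt_size: int, scale_factor: int = 8):
    # Take the largest power of scale_factor that still fits before min_prompt_size.
    sections = []
    prev_pos = 0
    while prev_pos < min_prompt_size:
        section_len = 1
        while prev_pos + section_len * scale_factor <= min_prompt_size:
            section_len *= scale_factor
        sections.append((prev_pos, prev_pos + section_len))
        prev_pos += section_len
    return sections

# A shortest prompt of 100 tokens is covered with two distinct section widths (64 and 8)
# plus a few width-1 steps at the tail:
# [(0, 64), (64, 72), (72, 80), (80, 88), (88, 96), (96, 97), (97, 98), (98, 99), (99, 100)]
print(prompt_sections(100))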