diff --git a/example_text_completion.py b/example_text_completion.py
index fc084b019..da89d443e 100755
--- a/example_text_completion.py
+++ b/example_text_completion.py
@@ -25,6 +25,7 @@ def main(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if not USE_CUDA:
         server = xp.start_server(9012, only_on_master=False)
@@ -34,6 +35,7 @@ def main(
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
         dynamo=dynamo,
+        spmd=spmd,
     )
 
     prompts = [
@@ -77,12 +79,13 @@ def _fn(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if USE_CUDA:
         os.environ['WORLD_SIZE'] = torch.cuda.device_count()
         os.environ['RANK'] = idx
         os.environ['LOCAL_RANK'] = idx
-    main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
+    main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)
 
 
 def mp_main(
@@ -95,6 +98,7 @@ def mp_main(
     max_gen_len: int = 64,
     max_batch_size: int = 4,
     dynamo: bool = True,
+    spmd: bool = True,
 ):
     if mp:
         if USE_CUDA:
@@ -103,9 +107,9 @@ def mp_main(
         else:
             kwargs = {}
         xmp.spawn(_fn,
-                  args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo), **kwargs)
+                  args=(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd), **kwargs)
     else:
-        main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo)
+        main(ckpt_dir, tokenizer_path, temperature, top_p, max_seq_len, max_gen_len, max_batch_size, dynamo, spmd)
 
 
 if __name__ == "__main__":
diff --git a/llama/generation.py b/llama/generation.py
index 043f188c7..fadd4fa44 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -20,6 +20,9 @@
 # Some how xla init will slow down the CUDA speed.
 if not USE_CUDA:
     import torch_xla.core.xla_model as xm
+    import torch_xla.experimental.xla_sharding as xs
+    from torch_xla import runtime as xr
+    import numpy as np
 
 Role = Literal["system", "user", "assistant"]
 
@@ -60,6 +63,7 @@ def build(
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
         dynamo: bool = True,
+        spmd: bool = True,
     ) -> "Llama":
         # if not model_parallel_is_initialized():
         #     if model_parallel_size is None:
@@ -118,14 +122,47 @@ def build(
         model = model.to(device)
 
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
 
-        return Llama(model, tokenizer, device, dynamo)
+        return Llama(model, tokenizer, device, dynamo, spmd)
 
-    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True):
+    def __init__(self, model: Transformer, tokenizer: Tokenizer, device: torch.device, dynamo: bool = True, spmd: bool = True):
         self.model = model
         self.tokenizer = tokenizer
         self.device = device
-        self._generate_one_token_fn = self._generate_one_token
+
+        if spmd:
+            num_devices = xr.global_runtime_device_count() # updated way to get device count
+            device_ids = np.arange(num_devices)
+            x_dim = 2 # hard-coded for v5
+            yz_dim = 4 # hard-coded for v5
+
+            # manually shard the kv cache
+            four_d_mesh = xs.Mesh(device_ids, (1, 1, x_dim, yz_dim))
+            for layer in model.layers:
+                xs.mark_sharding(layer.attention.cache_k, four_d_mesh, (0, 1, 2, None))
+                xs.mark_sharding(layer.attention.cache_v, four_d_mesh, (0, 1, 2, None))
+
+            col_mesh = xs.Mesh(device_ids, (1, num_devices))
+            row_mesh = xs.Mesh(device_ids, (num_devices, 1))
+            two_d_mesh = xs.Mesh(device_ids, (x_dim, yz_dim))
+            two_d_mesh_transpose = xs.Mesh(device_ids, (yz_dim, x_dim))
+
+            for name, layer in model.named_modules():
+                if 'tok_embeddings' in name:
+                    xs.mark_sharding(layer.weight, row_mesh, (0, 1))
+                if 'attention.' in name:
+                    if 'wo' in name:
+                        xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
+                    else:
+                        xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
+                if 'feed_forward.' in name:
+                    if 'w2' in name:
+                        xs.mark_sharding(layer.weight, two_d_mesh_transpose, (0, 1))
+                    else:
+                        xs.mark_sharding(layer.weight, two_d_mesh, (0, 1))
+                if 'output' in name:
+                    xs.mark_sharding(layer.weight, col_mesh, (0, 1))
+
         if dynamo:
             if USE_CUDA:
                 # Inductor errors out when compiles _generate_one_token_fn.
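
For context, and not part of the patch above: a minimal sketch of the xs.Mesh / xs.mark_sharding pattern the diff applies inside Llama.__init__, shown on a single standalone linear layer instead of the Llama model. It assumes a torch_xla build that exposes torch_xla.experimental.xla_sharding and torch_xla.runtime, that SPMD mode is enabled for that build, and that the process sees more than one XLA device; the layer shape is arbitrary and chosen only for illustration.

# Hypothetical standalone sketch of the sharding pattern used in the patch.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla import runtime as xr

num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)

# Column mesh, mirroring col_mesh in the patch: tensor dim 0 is mapped to a
# mesh axis of size 1 (replicated), dim 1 is split across all devices.
col_mesh = xs.Mesh(device_ids, (1, num_devices))

# An arbitrary weight to shard; in the patch this role is played by the
# model's output projection weight.
layer = torch.nn.Linear(4096, 32000, bias=False).to(xm.xla_device())
xs.mark_sharding(layer.weight, col_mesh, (0, 1))

The 4D cache mesh and the transposed 2D meshes in the patch follow the same call pattern; only the mesh shape and the partition spec differ per tensor.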