diff --git a/.gitignore b/.gitignore index bc6a38d79..4a8414e5a 100644 --- a/.gitignore +++ b/.gitignore @@ -169,7 +169,12 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +# XCode **/*.xcodeproj/* + +# Aider .aider* exo/tinychat/images/*.png +.vscode/ +build/ diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 1020fdbc3..1b07ad19e 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -81,6 +81,9 @@ def generate_completion( }], } + if DEBUG >= 3: + print(f"completion: {completion}") + if not stream: completion["usage"] = { "prompt_tokens": len(tokenizer.encode(prompt)), diff --git a/exo/helpers.py b/exo/helpers.py index ff0205f00..634ffed2c 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -42,7 +42,6 @@ def get_system_info(): return "Linux" return "Non-Mac, non-Linux system" - def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int: used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports") diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index b62000371..3c86cc681 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -55,6 +55,7 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen "mlx": "MLXDynamicShardInferenceEngine", "tinygrad": "TinygradDynamicShardInferenceEngine", "dummy": "DummyInferenceEngine", + "torch": "TorchDynamicShardInferenceEngine" } @@ -71,6 +72,10 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDown tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) return TinygradDynamicShardInferenceEngine(shard_downloader) + elif inference_engine_name == "torch": + from exo.inference.torch.sharded_inference_engine import TorchDynamicShardInferenceEngine + + return TorchDynamicShardInferenceEngine(shard_downloader) elif inference_engine_name == "dummy": from exo.inference.dummy_inference_engine import DummyInferenceEngine return DummyInferenceEngine() diff --git a/exo/inference/torch/.gitignore b/exo/inference/torch/.gitignore new file mode 100644 index 000000000..6d76c24de --- /dev/null +++ b/exo/inference/torch/.gitignore @@ -0,0 +1,2 @@ +data/ +model/archive/ diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md new file mode 100644 index 000000000..9cc6af50d --- /dev/null +++ b/exo/inference/torch/README.md @@ -0,0 +1,103 @@ +# PyTorch inference engine + +## Devs +- [Vincent Castro](https://x.com/t0kenl1mit) + +## ENV Vars +```bash +# Use the original max position embeddings amount, if present in the rope_scaling - default is False +TORCH_USE_ORG_SEQ = True or False + +# Use cache - default is True +TORCH_USE_CACHE = True or False +``` + +## Notes/Issues +### 10/10/2024 +- To select a pytorch device via environment variables, set the variable TORCH_DEVICE + - XLA is currently not installed and will need to be added to inference.py, looking into doing this on a TPU VM + - With pytorch, CUDA and ROCm are the same so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373) + - Looking into adding mobile device support properly +- If device is not CPU the data type defaults to float32 else float16. + +### 10/13/2024 +Still working on split model development (see test_split_model.py). 
The split itself seems to work, but transformers still loads more into RAM and GPU memory as larger models are loaded, causing OOM errors. Will research and add to the next update. Tests are added and still in development.
+
+### 10/21/2024
+Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure PyTorch implementation of llama3, since using transformers won't work for exo. Using some code from Meta but also making use of torchtune.
+
+### 10/27/2024
+Still working on the llama3 model, but noting that a better KVCache needs to be investigated.
+
+### 11/17/2024
+The sharded llama model is now working; the next step is the inference engine. Still testing on the small Llama 3.2 1B but will try larger models.
+
+### 01/16/2025
+Torchtune has replaced Hugging Face transformers except for the tokenizer. Inference on Meta Llama 3.2 1B seems okay, but my GPU runs into a wall quickly. Will be trying a bigger VM and a LAN server to split up the model.
+
+## Tech
+```bash
+# Laptop/PC
+Distributor ID: Ubuntu
+Description: Ubuntu 24.04.1 LTS
+Release: 24.04
+Codename: noble
+CUDA Version: 12.4
+Nvidia Driver Version: 550.107.02
+
+CPU: 11th Gen Intel® Core™ i7-11800H × 16
+RAM: 16GB
+GPU 1: Nvidia GeForce RTX 3060 6GB Laptop
+```
+```bash
+# Server
+Distributor ID: Pop
+Description: Pop!_OS 22.04 LTS
+Release: 22.04
+Codename: jammy
+CUDA Version: 12.4
+Nvidia Driver Version: 550.90.07
+
+GPU 1: NVIDIA T1000 8GB
+GPU 2: NVIDIA Quadro M2000 4GB
+GPU 3: NVIDIA Quadro M2000 4GB
+GPU 4: NVIDIA Quadro P400 2GB
+GPU 5: NVIDIA Quadro P400 2GB
+```
+
+## Current Model
+
+WIP PyTorch llama model
+
+```
+# Llama-3.2-1B-Instruct #
+
+ShardedLlamaModel(
+  (model): ShardTransformerDecoder(
+    (tok_embeddings): Embedding(128256, 2048)
+    (layers): ModuleList(
+      (0-15): 16 x TransformerSelfAttentionLayer(
+        (attn): MultiHeadAttention(
+          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
+          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
+          (output_proj): Linear(in_features=2048, out_features=2048, bias=False)
+          (pos_embeddings): Llama3ScaledRoPE()
+        )
+        (mlp): FeedForward(
+          (w1): Linear(in_features=2048, out_features=8192, bias=False)
+          (w2): Linear(in_features=8192, out_features=2048, bias=False)
+          (w3): Linear(in_features=2048, out_features=8192, bias=False)
+          (activation): SiLU()
+        )
+        (sa_norm): RMSNorm()
+        (mlp_norm): RMSNorm()
+        (sa_scale): Identity()
+        (mlp_scale): Identity()
+      )
+    )
+    (norm): RMSNorm()
+  )
+)
+
+```
diff --git a/exo/inference/torch/__init__.py b/exo/inference/torch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/exo/inference/torch/models/__init__.py b/exo/inference/torch/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/exo/inference/torch/models/general_mha.py b/exo/inference/torch/models/general_mha.py
new file mode 100644
index 000000000..8a9be51a1
--- /dev/null
+++ b/exo/inference/torch/models/general_mha.py
@@ -0,0 +1,252 @@
+"""
+GeneralMHA class
+Return transformer model with MHA
+"""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torchtune.modules as ttm
+
+from torchtune.modules import RMSNorm
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings
+from torchtune.modules import
RotaryPositionalEmbeddings +from exo.inference.shard import Shard +from exo.inference.torch.models.llm_utils import ( + layer_mlp, + ShardTransformerDecoder +) + +from exo.helpers import DEBUG + +def GeneralMHA( + config: dict, + shard: Shard +): + use_tied = False + attn_bias = config.get("attn_bias", False) + output_bias = config.get("attn_bias", False) + + if "llama" in shard.model_id or "Llama" in shard.model_id: + # rope scaling config + rope = Llama3ScaledRoPE( + dim=config["head_dim"], + max_seq_len=config["max_seq_len"], + base=config["rope_base"], + scale_factor=config["rope_scaling_factor"], + ) + + # tied needed for 3.2 llama models + if "3.2" in shard.model_id: + use_tied = True + elif "qwen" in shard.model_id or "Qwen" in shard.model_id: + # rope scaling config + rope = Qwen2RotaryPositionalEmbeddings( + dim=config["head_dim"], + max_seq_len=config["max_seq_len"], + base=config["rope_base"] + ) + attn_bias = True + output_bias = False + + # tied needed for 0.5B qwen models + if "0.5B" in shard.model_id or "0.5b" in shard.model_id: + use_tied = True + else: + rope = RotaryPositionalEmbeddings( + dim=config["head_dim"], + max_seq_len=config["max_seq_len"], + base=config["rope_base"] + ) + + if DEBUG >= 4: + print(f"model_id: {shard.model_id}") + print(f"rope: {rope}") + print(f"attn_bias: {attn_bias}") + print(f"output_bias: {output_bias}") + print(f"use_tied: {use_tied}") + + # hack to align sharded weights with layers + # fill unused layer positions with None + layers = [None for _ in range(shard.n_layers)] + + # build layers + for i in range(shard.start_layer, shard.end_layer + 1): + self_attn = ttm.MultiHeadAttention( + embed_dim=config["embed_dim"], + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + head_dim=config["head_dim"], + q_proj=nn.Linear( + config["embed_dim"], + config["num_heads"]*config["head_dim"], + bias=attn_bias, + ), + k_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"]*config["head_dim"], + bias=attn_bias, + ), + v_proj=nn.Linear( + config["embed_dim"], + config["num_kv_heads"]*config["head_dim"], + bias=attn_bias, + ), + output_proj=nn.Linear( + config["embed_dim"], + config["embed_dim"], + bias=output_bias, + ), + max_seq_len=config["max_seq_len"], + attn_dropout=config["attn_dropout"], + pos_embeddings=rope, + ) + + mlp = layer_mlp( + dim=config["embed_dim"], + hidden_dim=config["intermediate_dim"], + ) + + layer = ttm.TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + mlp_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]), + ) + + layers[i] = layer + + layers = nn.ModuleList(layers) + + tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"]) + if use_tied: + output_proj = ttm.TiedLinear(tok_embeddings) + else: + output_proj = nn.Linear(config["embed_dim"], config["vocab_size"], bias=False) + + norm = RMSNorm(config["embed_dim"], eps=config["norm_eps"]) + + return ShardTransformerDecoder( + tok_embeddings=tok_embeddings, + shard=shard, + layers=layers, + max_seq_len=config["max_seq_len"], + num_heads=config["num_heads"], + head_dim=config["head_dim"], + norm=norm, + output=output_proj, + num_layers=config["num_layers"], + ) + +class ShardedGeneralModel(nn.Module): + def __init__( + self, + config: dict, + shard: Shard, + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float16, + use_cache: Optional[bool] = False, + max_generated_tokens: int = 1024, + ): + super(ShardedGeneralModel, self).__init__() + 
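+    # keep the shard and config, then build the sharded GeneralMHA model on the requested device and dtype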
+ self.shard = shard + self.config = config + self.dtype = dtype + self.device = device if device is not None else torch.device("cpu") + self.max_seq_len = self.config["max_seq_len"] + self.use_cache = use_cache + + self.model = GeneralMHA( + config, + self.shard + ).to( + dtype=self.dtype, + device=self.device + ) + + if DEBUG >= 4: + print("ShardedGeneralModel called") + print(f"self.model {self.model}") + + # keep track of current position in generation + self.max_generated_tokens = max_generated_tokens + + def generate( + self, + tokens: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + hidden_state: Optional[torch.Tensor] = None, + curr_pos: Optional[int] = 0 + ) -> Tuple[ + Optional[torch.Tensor], + torch.Tensor, + ]: + """ + Generate logits and/or hidden_states from llama model + + Args + tokens (torch.Tensor) - tokens from prompt tokenization and generation + hidden_state (torch.Tensor, optional) - hidden state from last activated hidden layer, if any + """ + if DEBUG >= 4: + print("generate called") + print(f"tokens: {tokens}") + if mask is not None: + print(f"mask: {mask.size()}") + print(f"input_pos: {input_pos.size()}") + print(f"hidden_state: {hidden_state}") + print(f"curr_pos: {curr_pos}") + print(f"cached? {self.model.caches_are_enabled()}") + + model_hs = None + model_logits = None + + self.model.output_hidden_states = [self.shard.end_layer] + + if curr_pos > 0: + if self.model.caches_are_enabled(): + input_pos = input_pos[:, curr_pos].contiguous() + mask = mask[:, curr_pos, None, :].contiguous() + else: + input_pos = input_pos[:, :curr_pos + 1] + mask = mask[:, :curr_pos + 1, :curr_pos + 1] + else: + _, tklng = tokens.size() + + if self.model.caches_are_enabled(): + mask = mask[:, :tklng] + else: + mask = mask[:, :tklng, :tklng] + + input_pos = input_pos[:, :tklng].squeeze() + + if DEBUG >= 4: + print("model_input") + if tokens is not None: + print(f"tokens: {tokens}") + if hidden_state is not None: + print(f"hidden_state: {hidden_state}") + print(f"mask: {mask}") + print(f"input_pos: {input_pos}") + + + model_output = self.model( + tokens=tokens, + mask=mask, + input_pos=input_pos, + hidden_state=hidden_state, + dtype=self.dtype + ) + + if self.shard.is_last_layer(): + model_logits = model_output + else: + model_hs = model_output + + if DEBUG >= 4: + print(f"model_hs\n{model_hs}\nmodel_logits\n{model_logits}") + + return model_hs, model_logits diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py new file mode 100644 index 000000000..0d3212c16 --- /dev/null +++ b/exo/inference/torch/models/llm_utils.py @@ -0,0 +1,520 @@ +""" +Utility methods used by LLMs +""" +import os +import re +import json +from pathlib import Path +from typing import Any, Dict, Optional, Union, List, Callable + +import torch +import torch.nn as nn +from torchtune.modules.attention_utils import _MaskType +from torchtune.modules import FeedForward, TransformerDecoder +from torchtune.models.convert_weights import hf_to_tune + +from safetensors.torch import load_file as load_safetensors + +from exo.helpers import DEBUG +from exo.inference.shard import Shard + +# dtype string to dtype from huggingface type config.json +HF_PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64, +} + + +def load_model_config(model_config_path: Path) -> dict: + """ + Loads the config.json of the model + + Args: 
+ model_path (Path): local path to model config json + + Returns: + dict: The config as a dictionary + """ + model_config = {} + with open(model_config_path, "r") as f: + base_config = json.load(f) + + model_config = { + "rope_scaling": base_config.get("rope_scaling"), + "embed_dim": base_config["hidden_size"], + "num_heads": base_config["num_attention_heads"], + "head_dim": base_config.get( + "head_dim", + base_config["hidden_size"] // base_config["num_attention_heads"], + ), # Assuming embed_dim = hidden_size + "num_kv_heads": base_config["num_key_value_heads"], + "max_seq_len": base_config["max_position_embeddings"], + "intermediate_dim": base_config["intermediate_size"], + "attn_dropout": base_config.get("attention_dropout", 0.0), + "norm_eps": base_config["rms_norm_eps"], + "rope_base": base_config["rope_theta"], + "vocab_size": base_config["vocab_size"], + "num_layers": base_config["num_hidden_layers"], + "attn_bias": base_config.get("attention_bias", False), + "hidden_act": base_config.get("hidden_act", "silu"), + "torch_dtype": HF_PRECISION_STR_TO_DTYPE.get( + base_config.get("torch_dtype", "float16"), + torch.float16 + ) + } + + if model_config.get("rope_scaling", None) is not None: + model_config["rope_scaling_factor"] = model_config["rope_scaling"].get("rope_factor", 32) + + use_org_seq = bool(os.getenv("TORCH_USE_ORG_SEQ", "False").lower() == "true") + if use_org_seq and model_config.get("rope_scaling", None) is not None: + model_config["max_seq_len"] = model_config["rope_scaling"]["original_max_position_embeddings"] + + + + return model_config + + +def check_weights(model, state_dict): + """ + Verifies that the weights from the state dictionary are properly loaded into the model. + """ + model_state_dict = model.state_dict() + for name, param in model_state_dict.items(): + if name in state_dict: + loaded_param = state_dict[name] + if param.shape != loaded_param.shape: + print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}") + else: + print(f"{name}: loaded correctly") + + for name in state_dict: + if name not in model_state_dict: + print(f"Unexpected weight {name} found in state_dict") + +def load_weights_torch(cache_dir: Path, model: Any, config: Dict): + # Load weights from safetensors files in the cache directory + safetensors_files = list(cache_dir.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + + # Load weights from each found safetensors file + full_state_dict = {} + for safetensor_file in safetensors_files: + state_dict = load_safetensors(safetensor_file) + + if full_state_dict is not None: + full_state_dict = full_state_dict | state_dict + else: + full_state_dict = state_dict + + converted_sd = hf_to_tune( + state_dict=full_state_dict, + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + dim=config["embed_dim"], + head_dim=config["head_dim"] + ) + + model.load_state_dict(converted_sd, strict=True) + + print("\n--- checking weights ----\n") + check_weights(model, converted_sd) + +def _permute(t, n_heads: int, head_dim: int, dim: int): + """ + Reshape weight for torchtune + """ + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + +def load_model_weights_torchtune( + cache_dir: Path, + shard: Shard, + model: Any, + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None +): + """ + Loads weights from huggingface and changes it to match 
torchtune naming structure + """ + if head_dim is None: + head_dim = dim // num_heads + + model_state_dict = model.state_dict() + if DEBUG >= 8: + for name, _ in model_state_dict.items(): + print(f"name {name}") + # Load weights from safetensors files in the cache directory + safetensors_files = list(cache_dir.glob("*.safetensors")) + if not safetensors_files: + raise FileNotFoundError("No safetensors files found in the cache directory.") + + # Load weights from each found safetensors file + + full_state_dict = {} + for safetensor_file in safetensors_files: + state_dict = load_safetensors(safetensor_file) + + if full_state_dict is not None: + full_state_dict = full_state_dict | state_dict + else: + full_state_dict = state_dict + + # remap to work with our model + remapped_state_dict = {} + + is_llama = True if "llama" in shard.model_id or "Llama" in shard.model_id else False + + if DEBUG >= 8 and is_llama: + print("loading llama type weights") + elif DEBUG >= 8 and not is_llama: + print("loading weights") + + for key, value in full_state_dict.items(): + # load layer by shard + for layer_num in range(shard.start_layer, shard.end_layer + 1): + # change input layer norm to sa_norm for torchtune + re_iln = re.findall(rf"model.layers\.{layer_num}\.(input_layernorm)\.weight", key) + if len(re_iln) != 0: + new_key = f"model.layers.{layer_num}.sa_norm.scale" + remapped_state_dict[new_key] = value + if DEBUG >= 8: + print(f"{key} == {new_key}") + + # change post attention layernorm to mlp_norm for torchtune + re_pal = re.findall(rf"model.layers\.{layer_num}\.(post_attention_layernorm)\.weight", key) + if len(re_pal) != 0: + new_key = f"model.layers.{layer_num}.mlp_norm.scale" + remapped_state_dict[new_key] = value + if DEBUG >= 8: + print(f"{key} == {new_key}") + + # change self_attn to attn + # along with changing o_proj to output_proj + re_attn = re.findall(rf"model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)", key) + if len(re_attn) != 0 and re_attn[0][0] == "self_attn": + if re_attn[0][1] == "k_proj" and is_llama: + value = _permute( + t=value, + n_heads=num_kv_heads, + head_dim=head_dim, + dim=dim + ) + + new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + remapped_state_dict[new_key] = value + elif re_attn[0][1] == "q_proj" and is_llama: + value = _permute( + t=value, + n_heads=num_heads, + head_dim=head_dim, + dim=dim + ) + new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + remapped_state_dict[new_key] = value + + elif re_attn[0][1] == "o_proj": + new_key = f"model.layers.{layer_num}.attn.output_proj.weight" + remapped_state_dict[new_key] = value + else: + new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}" + remapped_state_dict[new_key] = value + if DEBUG >= 8: + print(f"{key} == {new_key}") + + # set mlp weights + re_mlp = re.findall(rf"model\.layers\.{layer_num}.mlp.(\w+)\.(\w+)", key) + if len(re_mlp) != 0: + proj_name = re_mlp[0][0] + if proj_name == "up_proj": + proj_name = "w3" + elif proj_name == "down_proj": + proj_name = "w2" + elif proj_name == "gate_proj": + proj_name = "w1" + new_key = f"model.layers.{layer_num}.mlp.{proj_name}.weight" + remapped_state_dict[new_key] = value + if DEBUG >= 8: + print(f"{key} == {new_key}") + + # saving embed for paired weights + if key == "model.embed_tokens.weight": + remapped_state_dict["model.tok_embeddings.weight"] = value + if DEBUG >= 8: + print("model.embed_tokens.weight == model.tok_embeddings.weight") + + if key == "model.norm.weight": + 
remapped_state_dict["model.norm.scale"] = value + + if key == "lm_head.weight": + remapped_state_dict["model.output.weight"] = value + # saving embed for paired weights + if key == "model.embed_tokens.weight": + remapped_state_dict["model.tok_embeddings.weight"] = value + if DEBUG >= 8: + print("model.embed_tokens.weight == model.tok_embeddings.weight") + + if key == "model.norm.weight": + remapped_state_dict["model.norm.scale"] = value + + if key == "lm_head.weight": + remapped_state_dict["model.output.weight"] = value + + if not remapped_state_dict: + model.load_state_dict(full_state_dict, strict=True) + else: + if DEBUG >= 8: + print("\nRemapped state dict\n") + for rsdk in remapped_state_dict.keys(): + print(f"-- {rsdk}") + + # load new weight map + model.load_state_dict(remapped_state_dict, strict=False) + + if DEBUG >= 8: + print("\n--- checking weights ----\n") + check_weights(model, remapped_state_dict) + +class ShardTransformerDecoder(TransformerDecoder): + """ + ShardTransformerDecorder + Custom version of torchtune TransformerDecoder to allow for + sharding of models and passing of hidden layers between shards + """ + def __init__( + self, + *, + shard: Shard, + tok_embeddings: nn.Embedding, + layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + max_seq_len: int, + num_heads: int, + head_dim: int, + norm: nn.Module, + output: Union[nn.Linear, Callable], + num_layers: Optional[int] = None, + output_hidden_states: Optional[List[int]] = None, + ): + super().__init__( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=norm, + output=output, + num_layers=num_layers, + output_hidden_states=output_hidden_states, + ) + + self.shard = shard + + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + modified version for shard + + assume just decoder layers + """ + if decoder_max_seq_len is not None: + self.decoder_max_cache_seq_len = decoder_max_seq_len + else: + self.decoder_max_cache_seq_len = self.max_seq_len + + for layer in self.layers: + if layer is not None: + layer.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=self.encoder_max_cache_seq_len, + decoder_max_seq_len=self.decoder_max_cache_seq_len, + ) + + def caches_are_enabled(self) -> bool: + """ + modified version for shard + """ + if self.layers[0] is not None: + return self.layers[0].caches_are_enabled() + else: + for layer in self.layers: + if layer is not None: + return layer.caches_are_enabled() + + return False + + def reset_caches(self): + torch.cuda.empty_cache() + + for layer in self.layers: + if layer is not None: + layer.reset_cache() + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + hidden_state: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float16 + ) -> Union[torch.Tensor, List[torch.Tensor]]: + # Determine the type of input and shape + if DEBUG >= 4: + print("forward called") + if tokens is not None: + print(f"tokens [{tokens.shape}]: {tokens}") + print(f"mask: {mask}") + print(f"input_pos: {input_pos}") + + if hidden_state is not None: + print(f"hidden_state [{hidden_state.shape}]: {hidden_state}") + + if hidden_state is not None: + h = hidden_state # Use directly as hidden states + else: + seq_len = tokens.shape[1] + + self._validate_inputs( + seq_len, + mask=mask, + input_pos=input_pos, + ) + + 
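+      # no hidden state was passed from a previous shard, so embed the prompt tokens here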
fl_tokens = tokens.clone() + h = self.tok_embeddings(fl_tokens).to(dtype=dtype) # Apply token tok_embeddings + + # Initialize a list to capture hidden states if requested + # for captured hidden states + hidden = [] + curr_layers = [self.layers[i] for i in range(self.shard.start_layer, self.shard.end_layer + 1)] + for i, layer in enumerate(curr_layers): + if DEBUG >= 8: + print(f"\nhidden layer in H[{self.shard.start_layer+i}]\n{h}") + print(f"\nmask\n{mask}\ninput_pos\n{input_pos}") + print(f"\noutput_hidden_states\n{self.output_hidden_states}\n") + + if i in self.output_hidden_states: + hidden.append(h) + + # Process through each transformer layer + h = layer( + h, + mask=mask, + input_pos=input_pos, + ) + + if DEBUG >= 8: + print(f"\nhidden layer out H[{self.shard.start_layer+i}]->H[{self.shard.start_layer+i+1}]\n{h}\n") + + if self.shard.is_last_layer(): + # Apply normalization + h = self.norm(h) + + # Handle chunked output if needed + output = self.output(h).float() + + if DEBUG >= 4: + print(f"\n\noutput {output}\n\n") + + return output + else: + if DEBUG >= 4: + print(f"\n\nhidden output {hidden[-1]}\n\n") + + return hidden[-1] + +class MultiLayerPreceptron(nn.Module): + def __init__(self, input_dim, hidden_dim, activation="silu", use_bias=False): + """ + General MLP (Multi-Layer Perceptron) module. + + Args: + input_dim (int): Dimensionality of the input. + hidden_dims (int): Hidden layer/intermediate dimensions. + output_dim (int): Dimensionality of the output. + activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', etc.). + use_bias (bool): Use bias with linearization + """ + super(MultiLayerPreceptron, self).__init__() + + # Activation function mapping + activations = {"relu": nn.ReLU(), "gelu": nn.GELU(), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), "leaky_relu": nn.LeakyReLU(0.2), "silu": nn.SiLU()} + + # Ensure valid activation + if activation not in activations: + raise ValueError(f"Invalid activation: {activation}. 
Choose from {list(activations.keys())}") + + # Construct MLP layers + self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias) + self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias) + self.act_fn = activations[activation] + + def forward(self, x) -> torch.Tensor: + return self.down_proj(self.act_fn(self.gate_proj(x))*self.up_proj(x)) + +class ShardInferenceState: + def __init__( + self, + tokens: Optional[torch.tensor] = None, + input_pos: Optional[torch.tensor] = None, + mask: Optional[torch.tensor] = None, + curr_pos: int = 0, + device: torch.device = torch.device("cpu") + ): + self.tokens = tokens + self.input_pos = input_pos + self.mask = mask + self.curr_pos = curr_pos + self.device = device + + def from_dict(self, state_dict): + """ + Data is stored as torch tensors on needed devices + """ + self.tokens = torch.tensor(state_dict["tokens"]).to(self.device) + self.input_pos = torch.tensor(state_dict["input_pos"]).to(self.device) + self.mask = torch.tensor(state_dict["mask"]).to(self.device) + self.curr_pos = state_dict["curr_pos"] + + def to_dict(self) -> dict: + return { + "tokens": self.tokens.numpy(force=True).tolist(), + "input_pos": self.input_pos.numpy(force=True).tolist(), + "mask": self.mask.numpy(force=True).tolist(), + "curr_pos": self.curr_pos + } + + def __str__(self) -> str: + return f""" + tokens: {self.tokens} + input_pos: {self.input_pos} + mask: {self.mask} + curr_pos: {self.curr_pos} + """ + +def layer_mlp(dim: int, hidden_dim: int) -> FeedForward: + """ + Generalized MLP layer + Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_1/_component_builders.py#L124 + Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_component_builders.py#L127C1-L134C82 + """ + gate_proj = nn.Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) \ No newline at end of file diff --git a/exo/inference/torch/sharded_inference_engine.py b/exo/inference/torch/sharded_inference_engine.py new file mode 100644 index 000000000..a21e2de77 --- /dev/null +++ b/exo/inference/torch/sharded_inference_engine.py @@ -0,0 +1,425 @@ +""" +TorchDynamicShardInferenceEngine +Sharded inference engine using PyTorch based torchtune models +""" + +import os +import functools +from concurrent.futures import ThreadPoolExecutor +import asyncio +import uuid +import re +from typing import Optional + +import numpy as np +import torch +import torchtune.generation as ttg +from transformers import AutoTokenizer + +from exo.inference.inference_engine import InferenceEngine +from exo.download.shard_download import ShardDownloader +from exo.inference.shard import Shard +from exo.inference.tokenizers import _resolve_tokenizer +from exo.helpers import DEBUG +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune, + ShardInferenceState +) + +from exo.inference.torch.models.general_mha import ShardedGeneralModel + +# from torchtune generate recipe +# https://github.com/pytorch/torchtune/blob/main/recipes/configs/generation.yaml#L40 +TEMP = 0.6 +TOP_K = 35 + +class TorchDynamicShardInferenceEngine(InferenceEngine): + """ + Pytorch based inferece engine for sharded models + """ + def __init__(self, shard_downloader: ShardDownloader): + self.shard = None + self.shard_downloader = 
shard_downloader + self.sharded_model = None + self.request_id = None + self.executor = ThreadPoolExecutor(max_workers=1) + self.uuid = str(uuid.uuid4()) + self.model_path = None + self.model_config = None + self.state = None + self.oom_cnt = 0 + + # cache settings + self.use_cache = bool(os.getenv("TORCH_USE_CACHE", "True").lower() == "true") + self.cache_setup = False + + # device settings + if os.environ.get("TORCH_DEVICE"): + self.device = torch.device(os.environ["TORCH_DEVICE"]) + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + + # rng setup for sampling + self.rng = torch.Generator(device=self.device) + self.rng.manual_seed(1234) + + def setup_cache(self, batch_size: int=1, total_response_length: int=1024): + # setup cache + # this is needed for a primary node that gets the initial encoding + if not self.sharded_model.model.caches_are_enabled() and self.use_cache: + with self.device: + self.sharded_model.model.setup_caches( + batch_size, + self.model_config["torch_dtype"], + decoder_max_seq_len=total_response_length + ) + + self.cache_setup = True + + + def clear_model(self): + """ + Clear out model and shard + A way to avoid OOM issues + + All prompts are stored in VRAM + while inference engine is up and using the same + model class instance, this will clear it for each prompt. + + OOM issue might occur in longer chats/contexts depending on your machine. + """ + if self.sharded_model.model.caches_are_enabled(): + self.sharded_model.model.reset_caches() + + del self.sharded_model + self.sharded_model = None + + if self.device == torch.device("cuda"): + torch.cuda.empty_cache() + + self.shard = None + self.state = None + + async def encode(self, shard: Shard, prompt: str) -> np.ndarray: + if DEBUG >= 4: + print("encode called") + print(f"shard: {shard}") + print(f"prompt: {prompt}") + + if self.sharded_model is not None: + print("CLEARING SHARD AND MODEL - ENCODING") + self.clear_model() + + await self.ensure_shard(shard) + + def encode_wrapper() -> np.ndarray: + """ + Encode the tensors from prompt along with the + initial input_pos and mask + """ + tokens = self.tokenizer.encode( + prompt, + return_tensors="pt" + ) + + # move to proper device, default is CPU + if tokens.device != self.device: + tokens = tokens.to(device=self.device) + + if DEBUG >= 4: + print("encoded_wrapper called") + print(f"tokens: {tokens}") + + # if going past max, just take from max onward + if len(tokens) > self.sharded_model.max_generated_tokens: + max_gen_tokens = self.sharded_model.max_generated_tokens + tokens = tokens[-max_gen_tokens:] + + self.state.tokens = tokens + + bsz, tklng = tokens.size() + total_response_length = tklng + self.sharded_model.max_generated_tokens + + self.setup_cache(bsz, total_response_length) + + # setup max sequence length + if not self.sharded_model.model.caches_are_enabled(): + max_seq_len = total_response_length + else: + max_seq_len = self.sharded_model.model.decoder_max_cache_seq_len + + # set pad_id + if hasattr(self.tokenizer, "pad_id"): + pad_id = self.tokenizer.pad_id + elif hasattr(self.tokenizer, "pad_token_id"): + print(f"pad_token_id: {self.tokenizer.pad_token_id}") + if self.tokenizer.pad_token_id is not None: + pad_id = self.tokenizer.pad_token_id + else: + pad_id = 0 + else: + pad_id = 0 + + padding_masks = tokens != pad_id + if not padding_masks.all(): + padding_masks = 
torch.nn.functional.pad( + padding_masks, + (0, self.sharded_model.max_generated_tokens), + value=True, + ) + + self.state.mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len) + + self.state.input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + self.state.mask = torch.tril(torch.ones( + total_response_length, + max_seq_len, + dtype=torch.bool, + device=self.device, + )).unsqueeze(0) + + self.state.input_pos = torch.arange(0, total_response_length, device=self.device).unsqueeze(0) + + return tokens + + return await asyncio.get_running_loop().run_in_executor( + self.executor, + functools.partial(encode_wrapper), + ) + + async def decode(self, shard: Shard, tokens: np.ndarray) -> str: + if DEBUG >= 4: + print("decode called") + print(f"shard: {shard}") + print(f"tokens: {tokens}") + + await self.ensure_shard(shard) + + return await asyncio.get_running_loop().run_in_executor( + self.executor, + functools.partial(self.tokenizer.decode, tokens.tolist()), + ) + + async def sample(self, x: np.ndarray, temp=TEMP, top_k=TOP_K) -> np.ndarray: + if DEBUG >= 4: + print("sample called") + print(f"x: {x}") + print(f"temp: {temp}") + print(f"top_k: {top_k}") + print(self.device) + + logits = torch.tensor(x).to(self.device) + + def sample_wrapper(): + q = torch.empty((logits.size(0), self.sharded_model.model.tok_embeddings.num_embeddings), device=logits.device).exponential_(1, generator=self.rng) + + tokens = ttg.sample(logits.clone(), temperature=temp, top_k=top_k, q=q.to(self.device)) + + if DEBUG >= 4: + print(f"tokens: {tokens}") + + return tokens.numpy(force=True) + + return await asyncio.get_running_loop().run_in_executor(self.executor, functools.partial(sample_wrapper)) + + async def infer_tensor( + self, + request_id: str, + shard: Shard, + input_data: np.ndarray, + inference_state: Optional[dict] = None + ) -> tuple[np.ndarray, Optional[dict]]: + + await self.ensure_shard(shard) + + # ensure shard + if DEBUG >= 4: + print("infer_tensor called") + print(f"shard: {shard}") + print(f"input_data: {input_data}") + + if inference_state.get("tokens") is not None: + self.state.from_dict(inference_state) + + self.request_id = request_id if not self.request_id else self.request_id + + hidden_state = None + input_tensor = None + if input_data.ndim == 3: + hidden_state = torch.tensor(input_data).to( + device=self.device, + dtype=self.model_config["torch_dtype"] + ) + elif input_data.ndim == 2: + input_tensor = torch.tensor(input_data).to( + device=self.device + ) + + if self.use_cache and not self.cache_setup: + if input_tensor is not None: + bsz, tklng = input_tensor.size() + self.setup_cache( + bsz, + tklng + self.sharded_model.max_generated_tokens + ) + else: + bsz, tklng = self.state.tokens.size() + self.setup_cache( + bsz, + tklng + self.sharded_model.max_generated_tokens + ) + + def infer_wrapper(): + if DEBUG >= 4: + print(f"infer_wrapper called [{self.oom_cnt} OOM]") + print(f"self.state.tokens: {self.state.tokens}") + print(f"hidden_state: {hidden_state}") + + model_cache = self.sharded_model.model.caches_are_enabled() + + if self.state.tokens is not None: + if input_data.ndim == 2 and input_tensor.size(-1) == 1: + self.state.tokens = torch.cat([ + self.state.tokens.to(self.device), + input_tensor.clone() + ], dim=-1).to(self.device) + else: + self.state.tokens = input_tensor.clone() + + try: + in_tokens = self.state.tokens.clone().to( + device=self.device + ) + + in_input_pos = self.state.input_pos.clone().to( + device=self.device + ) + + 
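+        # the cached attention mask is likewise cloned onto the engine device before generation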
in_mask = self.state.mask.clone().to( + device=self.device + ) + + if hidden_state is not None: + model_hs, model_logits = self.sharded_model.generate( + tokens=in_tokens, + hidden_state=hidden_state, + input_pos=in_input_pos, + mask=in_mask, + curr_pos=self.state.curr_pos + ) + else: + if not model_cache: + model_hs, model_logits = self.sharded_model.generate( + tokens=in_tokens, + input_pos=in_input_pos, + mask=in_mask, + curr_pos=self.state.curr_pos + ) + else: + model_hs, model_logits = self.sharded_model.generate( + tokens=input_tensor, + input_pos=in_input_pos, + mask=in_mask, + curr_pos=self.state.curr_pos + ) + except torch.cuda.OutOfMemoryError: + print(f"OOM on cuda, clearing model and stopping") + self.oom_cnt += 1 + self.clear_model() + return + except Exception as err: + print(f"infer_tensor err\n{err}") + raise + + if model_hs is not None: + # numpy current no support for bf16 + if model_hs.dtype == torch.bfloat16: + model_hs = model_hs.float() + + if DEBUG >= 4: + print("sending hidden states") + print(f"model_hs: {model_hs.size()}") + print(f"state.tokens: {self.state.tokens}") + print(f"state.input_pos: {self.state.input_pos.size()}") + print(f"state.mask: {self.state.mask.size()}") + + return ( + model_hs.numpy(force=True), + self.state.to_dict(), + ) + + if self.state.curr_pos == 0: + self.state.curr_pos = self.state.tokens.size(-1) + else: + self.state.curr_pos += 1 + + # numpy current no support for bf16 + if model_logits.dtype == torch.bfloat16: + model_logits = model_logits.float() + + return ( + model_logits[:, -1].numpy(force=True), + self.state.to_dict(), + ) + + return await asyncio.get_running_loop().run_in_executor(self.executor, infer_wrapper) + + async def ensure_shard(self, shard: Shard): + if DEBUG >= 4: + print("shard ensured\n") + print(f"shard: {shard}") + print(f"class shard: {self.shard}") + print(f"uuid: {self.uuid}") + + # reset model after last layer to fix OOM + if self.shard == shard: + return + + self.shard = shard + + # Using CPU to store inference state + self.state = ShardInferenceState() + + # download model safetensors and shard + + self.model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__) + self.model_config = load_model_config(self.model_path/"config.json") + + # self.tokenizer = await _resolve_tokenizer(model_path) + self.tokenizer = await _resolve_tokenizer(self.model_path) + + def start_model(): + if DEBUG >= 4: + print("start_model called") + + self.sharded_model = ShardedGeneralModel( + config=self.model_config, + shard=shard, + device=self.device, + dtype=self.model_config["torch_dtype"], + use_cache=self.use_cache + ) + + load_model_weights_torchtune( + cache_dir=self.model_path, + shard=self.shard, + model=self.sharded_model, + num_heads=self.model_config["num_heads"], + num_kv_heads=self.model_config["num_kv_heads"], + dim=self.model_config["embed_dim"], + head_dim=self.model_config["head_dim"] + ) + + await asyncio.get_running_loop().run_in_executor( + self.executor, + functools.partial(start_model), + ) + + async def load_checkpoint(self, shard: Shard, path: str): + await self.ensure_shard(shard) diff --git a/exo/inference/torch/tests/__init__.py b/exo/inference/torch/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py new file mode 100644 index 000000000..cf1f06179 --- /dev/null +++ b/exo/inference/torch/tests/test_inference_engine.py @@ -0,0 +1,54 @@ +""" +Test 
inference engine and model sharding +""" +import pytest +import asyncio + +from exo.inference.shard import Shard +from exo.inference.torch.sharded_inference_engine import TorchDynamicShardInferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader + +import numpy as np + +@pytest.mark.asyncio +async def test_inference_engine(): + prompt = "In a single word only, what is the last name of the current president of the USA?" + + shard = Shard( + model_id="llama-3.2-1b", + start_layer=0, + end_layer=8, + n_layers=16 + ) + + shard_2 = Shard( + model_id="llama-3.2-1b", + start_layer=9, + end_layer=15, + n_layers= 16 + ) + + inference_engine = TorchDynamicShardInferenceEngine(HFShardDownloader()) + + output_1 = await inference_engine.infer_prompt("test_id", shard, prompt) + print("\n------------inference_engine.infer_prompt output---------------\n") + print(output_1) + print("\n---------------------------\n") + + assert isinstance(output_1, np.ndarray), "Output should be numpy array" + + output_2 = await inference_engine.infer_tensor("test_id", shard, output_1) + print("\n------------inference_engine.infer_tensor output---------------\n") + print(output_2) + print("\n---------------------------\n") + + assert isinstance(output_2, np.ndarray), "Output should be numpy array" + +if __name__ == '__main__': + try: + print("\n\n -------- TEST llama-3.2-1b -------- \n\n") + asyncio.run(test_inference_engine()) + except Exception as err: + print(f"\n!!!! TEST FAILED \n{err}\n") + + diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py new file mode 100644 index 000000000..ff1b62327 --- /dev/null +++ b/exo/inference/torch/tests/test_llama3_full.py @@ -0,0 +1,235 @@ +""" +Test of pytorch based llama3 models +full layer run +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg +from torchtune.models import llama3 +from torchtune.data import Message + +from transformers import AutoTokenizer + +from exo.inference.torch.models.general_mha import ShardedGeneralModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_weights_torch, + load_model_weights_torchtune +) + +MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct" +# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" +TEMP = 0.85 +TOP_K = 35 +MAX_NEW_TOKENS = 200 +RAND_SEED = 42 + + +def main(model, prompt: str, device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.bfloat16): + messages = [{ + "role": "assistant", + "content": "", + }, { + "role": "user", + "content": prompt, + }] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + tok_out = tokenizer([text], return_tensors="pt") + print(f"tok_out: {tok_out}") + tokens = tok_out.input_ids.to(device=device, dtype=torch.int) + + rng = torch.Generator(device=device) + rng.manual_seed(RAND_SEED) + + generated_tokens = tokens.clone() + + print(f"tokens: {tokens}") + + bsz, tokens_length = tokens.size() + + # using self.max_seq_len will take up alot of VRAM + total_response_length = tokens_length + MAX_NEW_TOKENS + + # setup cache + if not model.model.caches_are_enabled(): + with device: + model.model.setup_caches( + bsz, + dtype, + decoder_max_seq_len=total_response_length + ) + + if not model.model.caches_are_enabled(): + max_seq_len = total_response_length + else: + max_seq_len = model.model.decoder_max_cache_seq_len + + # masking for proper 
attention + + # select correct pad_id + if hasattr(tokenizer, "pad_id"): + pad_id = tokenizer.pad_id + elif hasattr(tokenizer, "pad_token_id"): + print(f"pad_token_id: {tokenizer.pad_token_id}") + if tokenizer.pad_token_id is not None: + pad_id = tokenizer.pad_token_id + else: + pad_id = 0 + else: + pad_id = 0 + + print(f"pad_id: {pad_id}") + + padding_masks = tokens != pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, MAX_NEW_TOKENS), + value=True, + ) + + mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + mask = torch.tril(torch.ones( + total_response_length, + max_seq_len, + dtype=torch.bool, + device=device, + )).unsqueeze(0) + + input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0) + + print(f"init mask: {mask}") + print(f"init input_pos: {input_pos}") + + curr_pos = 0 + + _, logits = model.generate( + tokens=tokens, + mask=mask, + input_pos=input_pos, + curr_pos=curr_pos + ) + + curr_pos = tokens_length + + q = torch.empty(( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q + ) + + print(f"tokens: {tokens}") + + for i in range(MAX_NEW_TOKENS - 1): + print(f"gen #{i+1}") + + if tokens.item() == tokenizer.eos_token_id: + print("stop token hit!") + break + + tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens + + _, logits = model.generate( + tokens=tokens, + input_pos=input_pos, + mask=mask, + curr_pos=curr_pos + ) + + curr_pos += 1 + + q = torch.empty( + ( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q, + ) + + print(f"tokens: {tokens}") + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + print(f"generated_tokens: {generated_tokens}") + + if not model.model.caches_are_enabled(): + tokens = generated_tokens.clone() + + print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n") + + +if __name__ == "__main__": + # prompt = "Hello, how are you?" + prompt = "Tell me a joke." + # prompt = "What is the meaning of exo?" + # prompt = "Tell me a short 4 line haiku" + # prompt = "In a single word only, what is the last name of the current president of the USA?" 
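+  # snapshot_download below resolves the local cache path for MODEL_NAME, downloading the weights from Hugging Face first if they are not already cached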
+ + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir/"config.json") + + print(f"current config\n{config}") + + # Setup shard + n_layers = int(config["num_layers"]) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=n_layers - 1, + n_layers=n_layers, + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(cache_dir) + + # Initialize LlamaModel with config and tokenizer + # device = torch.device("cuda") + dtype = torch.bfloat16 + device = torch.device("cpu") + shard_model_1 = ShardedGeneralModel( + config=config, + shard=shard_1, + device=device, + dtype=config["torch_dtype"], + use_cache=True, + max_generated_tokens=MAX_NEW_TOKENS, + ) + + print(f"\nshard_model_1: {shard_model_1}") + + # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + load_model_weights_torchtune( + cache_dir=cache_dir, + shard=shard_1, + model=shard_model_1, + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + dim=config["embed_dim"], + head_dim=config["head_dim"] + ) + + import time + main(shard_model_1, prompt, device, config["torch_dtype"]) diff --git a/exo/inference/torch/tests/test_mistral_full.py b/exo/inference/torch/tests/test_mistral_full.py new file mode 100644 index 000000000..54c0c8dee --- /dev/null +++ b/exo/inference/torch/tests/test_mistral_full.py @@ -0,0 +1,236 @@ +""" +Test of pytorch based mistral models +full layer run +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg + +from transformers import AutoTokenizer + +from exo.inference.torch.models.general_mha import ShardedGeneralModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune +) + +MODEL_NAME = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" +TEMP = 0.85 +TOP_K = 35 +MAX_NEW_TOKENS = 200 +RAND_SEED = 42 + + +def main( + model, + prompt: str, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.bfloat16 +): + messages = [{ + "role": "assistant", + "content": "", + }, { + "role": "user", + "content": prompt, + }] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + tok_out = tokenizer([text], return_tensors="pt") + print(f"tok_out: {tok_out}") + tokens = tok_out.input_ids.to(device=device, dtype=torch.int) + + rng = torch.Generator(device=device) + rng.manual_seed(RAND_SEED) + + generated_tokens = tokens.clone() + + print(f"tokens: {tokens}") + + bsz, tokens_length = tokens.size() + + # using self.max_seq_len will take up alot of VRAM + total_response_length = tokens_length + MAX_NEW_TOKENS + + # setup cache + if not model.model.caches_are_enabled(): + with device: + model.model.setup_caches( + bsz, + dtype, + decoder_max_seq_len=total_response_length + ) + + if not model.model.caches_are_enabled(): + max_seq_len = total_response_length + else: + max_seq_len = model.model.decoder_max_cache_seq_len + + # masking for proper attention + + # select correct pad_id + if hasattr(tokenizer, "pad_id"): + pad_id = tokenizer.pad_id + elif hasattr(tokenizer, "pad_token_id"): + print(f"pad_token_id: {tokenizer.pad_token_id}") + if tokenizer.pad_token_id is not None: + pad_id = tokenizer.pad_token_id + else: + pad_id = 0 + else: + pad_id = 0 + + print(f"pad_id: {pad_id}") + + 
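+  # mark non-pad positions; these drive the causal mask and position ids built below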
padding_masks = tokens != pad_id + if not padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, MAX_NEW_TOKENS), + value=True, + ) + + mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + mask = torch.tril(torch.ones( + total_response_length, + max_seq_len, + dtype=torch.bool, + device=device, + )).unsqueeze(0) + + input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0) + + print(f"init mask: {mask}") + print(f"init input_pos: {input_pos}") + + curr_pos = 0 + + _, logits = model.generate( + tokens=tokens, + mask=mask, + input_pos=input_pos, + curr_pos=curr_pos + ) + + curr_pos = tokens_length + + q = torch.empty(( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q + ) + + print(f"tokens: {tokens}") + + for i in range(MAX_NEW_TOKENS - 1): + print(f"gen #{i+1}") + + if tokens.item() == tokenizer.eos_token_id: + print("stop token hit!") + break + + tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens + + _, logits = model.generate( + tokens=tokens, + input_pos=input_pos, + mask=mask, + curr_pos=curr_pos + ) + + curr_pos += 1 + + q = torch.empty( + ( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q, + ) + + print(f"tokens: {tokens}") + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + print(f"generated_tokens: {generated_tokens}") + + if not model.model.caches_are_enabled(): + tokens = generated_tokens.clone() + + print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n") + + +if __name__ == "__main__": + # prompt = "Hello, how are you?" + prompt = "Tell me a joke." + # prompt = "What is the meaning of exo?" + # prompt = "Tell me a short 4 line haiku" + # prompt = "In a single word only, what is the last name of the current president of the USA?" 
+ + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir/"config.json") + + print(f"current config\n{config}") + + # Setup shard + n_layers = int(config["num_layers"]) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=n_layers - 1, + n_layers=n_layers, + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(cache_dir) + + # Initialize LlamaModel with config and tokenizer +# device = torch.device("cuda") + dtype = torch.bfloat16 + device = torch.device("cpu") + + shard_model_1 = ShardedGeneralModel( + config=config, + shard=shard_1, + device=device, + dtype=config["torch_dtype"], + use_cache=True, + max_generated_tokens=MAX_NEW_TOKENS, + ) + + print(f"\nshard_model_1: {shard_model_1}") + + # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + load_model_weights_torchtune( + cache_dir=cache_dir, + shard=shard_1, + model=shard_model_1, + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + dim=config["embed_dim"], + head_dim=config["head_dim"] + ) + + main(shard_model_1, prompt, device, config["torch_dtype"]) diff --git a/exo/inference/torch/tests/test_qwen_full.py b/exo/inference/torch/tests/test_qwen_full.py new file mode 100644 index 000000000..2e4f55d12 --- /dev/null +++ b/exo/inference/torch/tests/test_qwen_full.py @@ -0,0 +1,236 @@ +""" +Test of pytorch based qwen2 models +full layer run +""" + +from pathlib import Path +import torch +from huggingface_hub import snapshot_download + +import torchtune.generation as ttg + +from transformers import AutoTokenizer + +from exo.inference.torch.models.general_mha import ShardedGeneralModel +from exo.inference.shard import Shard + +from exo.inference.torch.models.llm_utils import ( + load_model_config, + load_model_weights_torchtune +) + +MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct" +TEMP = 0.85 +TOP_K = 35 +MAX_NEW_TOKENS = 200 +RAND_SEED = 42 + + +def main( + model, + prompt: str, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.bfloat16 +): + messages = [{ + "role": "assistant", + "content": "", + }, { + "role": "user", + "content": prompt, + }] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + tok_out = tokenizer([text], return_tensors="pt") + print(f"tok_out: {tok_out}") + tokens = tok_out.input_ids.to(device=device, dtype=torch.int) + + rng = torch.Generator(device=device) + rng.manual_seed(RAND_SEED) + + generated_tokens = tokens.clone() + + print(f"tokens: {tokens}") + + bsz, tokens_length = tokens.size() + + # using self.max_seq_len will take up alot of VRAM + total_response_length = tokens_length + MAX_NEW_TOKENS + + # setup cache + if not model.model.caches_are_enabled(): + with device: + model.model.setup_caches( + bsz, + dtype, + decoder_max_seq_len=total_response_length + ) + + if not model.model.caches_are_enabled(): + max_seq_len = total_response_length + else: + max_seq_len = model.model.decoder_max_cache_seq_len + + # masking for proper attention + + # select correct pad_id + if hasattr(tokenizer, "pad_id"): + pad_id = tokenizer.pad_id + elif hasattr(tokenizer, "pad_token_id"): + print(f"pad_token_id: {tokenizer.pad_token_id}") + if tokenizer.pad_token_id is not None: + pad_id = tokenizer.pad_token_id + else: + pad_id = 0 + else: + pad_id = 0 + + print(f"pad_id: {pad_id}") + + padding_masks = tokens != pad_id + if not 
padding_masks.all(): + padding_masks = torch.nn.functional.pad( + padding_masks, + (0, MAX_NEW_TOKENS), + value=True, + ) + + mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len) + + input_pos = ttg.get_position_ids_from_padding_mask(padding_masks) + else: + mask = torch.tril(torch.ones( + total_response_length, + max_seq_len, + dtype=torch.bool, + device=device, + )).unsqueeze(0) + + input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0) + + print(f"init mask: {mask}") + print(f"init input_pos: {input_pos}") + + curr_pos = 0 + + _, logits = model.generate( + tokens=tokens, + mask=mask, + input_pos=input_pos, + curr_pos=curr_pos + ) + + curr_pos = tokens_length + + q = torch.empty(( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q + ) + + print(f"tokens: {tokens}") + + for i in range(MAX_NEW_TOKENS - 1): + print(f"gen #{i+1}") + + if tokens.item() == tokenizer.eos_token_id: + print("stop token hit!") + break + + tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens + + _, logits = model.generate( + tokens=tokens, + input_pos=input_pos, + mask=mask, + curr_pos=curr_pos + ) + + curr_pos += 1 + + q = torch.empty( + ( + logits.size(0), + model.model.tok_embeddings.num_embeddings + ), device=logits.device).exponential_(1, generator=rng) + + tokens = ttg.sample( + logits=logits[:, -1].clone(), + temperature=TEMP, + top_k=TOP_K, + q=q, + ) + + print(f"tokens: {tokens}") + + generated_tokens = torch.cat([generated_tokens, tokens], dim=-1) + print(f"generated_tokens: {generated_tokens}") + + if not model.model.caches_are_enabled(): + tokens = generated_tokens.clone() + + print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n") + + +if __name__ == "__main__": + # prompt = "Hello, how are you?" + prompt = "Tell me a joke." + # prompt = "What is the meaning of exo?" + # prompt = "Tell me a short 4 line haiku" + # prompt = "In a single word only, what is the last name of the current president of the USA?" 
+ + # Get the path to the model files from the Hugging Face cache + cache_dir = Path(snapshot_download(MODEL_NAME)) + print(f"Cache directory: {cache_dir}") + + # Load model configuration + config = load_model_config(cache_dir/"config.json") + + print(f"current config\n{config}") + + # Setup shard + n_layers = int(config["num_layers"]) + shard_1 = Shard( + model_id=MODEL_NAME, + start_layer=0, + end_layer=n_layers - 1, + n_layers=n_layers, + ) + + # Initialize tokenizer + tokenizer = AutoTokenizer.from_pretrained(cache_dir) + + # Initialize LlamaModel with config and tokenizer +# device = torch.device("cuda") + dtype = torch.bfloat16 + device = torch.device("cpu") + + shard_model_1 = ShardedGeneralModel( + config=config, + shard=shard_1, + device=device, + dtype=config["torch_dtype"], + use_cache=True, + max_generated_tokens=MAX_NEW_TOKENS, + ) + + print(f"\nshard_model_1: {shard_model_1}") + + # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1) + load_model_weights_torchtune( + cache_dir=cache_dir, + shard=shard_1, + model=shard_model_1, + num_heads=config["num_heads"], + num_kv_heads=config["num_kv_heads"], + dim=config["embed_dim"], + head_dim=config["head_dim"] + ) + + main(shard_model_1, prompt, device, config["torch_dtype"]) diff --git a/exo/main.py b/exo/main.py index 51abe858e..38094366e 100644 --- a/exo/main.py +++ b/exo/main.py @@ -81,7 +81,7 @@ def configure_uvloop(): parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port") parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds") parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request") -parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)") +parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (torch, mlx, tinygrad, or dummy)") parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI") parser.add_argument("--run-model", type=str, help="Specify a model to run directly") parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?") @@ -103,6 +103,8 @@ def configure_uvloop(): inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad") print(f"Inference engine name after selection: {inference_engine_name}") +os.environ["EXO_INFER_ENGINE"] = inference_engine_name + inference_engine = get_inference_engine(inference_engine_name, shard_downloader) print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}") @@ -149,7 +151,9 @@ def configure_uvloop(): elif args.discovery_module == "manual": if not args.discovery_config_path: raise ValueError(f"--discovery-config-path is required when using manual discovery. 
Please provide a path to a config json file.") - discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) + discovery = ManualDiscovery( + args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities) + ) topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None node = Node( args.node_id, @@ -250,23 +254,26 @@ def on_token(_request_id, _tokens, _is_finished): finally: node.on_token.deregister(callback_id) + def clean_path(path): - """Clean and resolve path""" - if path.startswith("Optional("): - path = path.strip('Optional("').rstrip('")') - return os.path.expanduser(path) + """Clean and resolve path""" + if path.startswith("Optional("): + path = path.strip('Optional("').rstrip('")') + return os.path.expanduser(path) + async def hold_outstanding(node: Node): while node.outstanding_requests: await asyncio.sleep(.5) return + async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1): losses = [] tokens = [] for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size): _, _, lengths = batch - losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train))) + losses.append(np.sum(lengths*await node.enqueue_example(shard, *batch, train=train))) tokens.append(np.sum(lengths)) total_tokens = np.sum(tokens) total_loss = np.sum(losses) / total_tokens @@ -333,7 +340,7 @@ async def main(): def restore_cursor(): if platform.system() != "Windows": - os.system("tput cnorm") # Show cursor + os.system("tput cnorm") # Show cursor # Restore the cursor when the program exits atexit.register(restore_cursor) @@ -356,8 +363,7 @@ def handle_exit(): await run_model_cli(node, model_name, args.prompt) elif args.command == "eval" or args.command == 'train': model_name = args.model_name - dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item) - , loadline=lambda line: json.loads(line).get("text","")) + dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item), loadline=lambda line: json.loads(line).get("text", "")) if args.command == 'eval': if not model_name: print("Error: Much like a human, I can't evaluate anything without a model") diff --git a/exo/models.py b/exo/models.py index b64bcca10..6b69cdab2 100644 --- a/exo/models.py +++ b/exo/models.py @@ -6,8 +6,8 @@ "llama-3.3-70b": { "layers": 80, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit", + "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct", }, }, "llama-3.2-1b": { @@ -15,6 +15,7 @@ "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit", "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct" }, }, "llama-3.2-1b-8bit": { @@ -22,69 +23,93 @@ "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit", "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct", + "TorchDynamicShardInferenceEngine": 
"unsloth/Llama-3.2-1B-Instruct" }, }, "llama-3.2-3b": { "layers": 28, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit", + "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct" }, }, "llama-3.2-3b-8bit": { "layers": 28, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit", - "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit", + "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", }, }, "llama-3.2-3b-bf16": { "layers": 28, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct", - "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct", + "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct", }, }, "llama-3.1-8b": { "layers": 32, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", + "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated", + "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-8B-Instruct", }, }, "llama-3.1-70b": { "layers": 80, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit", + "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-70B-Instruct", }, }, "llama-3.1-70b-bf16": { "layers": 80, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", - "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct", + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED", + "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct", + "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-70B-Instruct", }, }, "llama-3-8b": { "layers": 32, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit", + "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", + "TorchDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R", }, }, "llama-3-70b": { "layers": 80, "repo": { - "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit", - "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit", + 
"TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", + "TorchDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", }, }, - "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, }, - "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, }, + "llama-3.1-405b": { + "layers": 126, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", + "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit", + }, + }, + "llama-3.1-405b-8bit": { + "layers": 126, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit",}, + }, ### mistral - "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, }, - "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, }, + "mistral-nemo": { + "layers": 40, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit",}, + }, + "mistral-large": { + "layers": 88, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit",}, + }, ### deepseek "deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, }, "deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, }, @@ -123,35 +148,143 @@ "deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, }, "deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, }, ### llava - "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, }, + "llava-1.5-7b-hf": { + "layers": 32, + "repo": {"MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf",}, + }, ### qwen - "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, }, - "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, }, - "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, }, - "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, }, - "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, }, - "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, }, - "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, }, - "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, }, - "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, }, - "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": 
"mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, }, - "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, }, - "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, }, - "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, }, - "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, }, + "qwen-2.5-0.5b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-0.5B-Instruct" + }, + }, + "qwen-2.5-1.5b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-1.5B-Instruct" + }, + }, + "qwen-2.5-coder-1.5b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-1.5B-Instruct" + }, + }, + "qwen-2.5-3b": { + "layers": 36, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-3B-Instruct" + }, + }, + "qwen-2.5-coder-3b": { + "layers": 36, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-3B-Instruct" + }, + }, + "qwen-2.5-7b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-7B-Instruct" + }, + }, + "qwen-2.5-coder-7b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-7B-Instruct" + }, + }, + "qwen-2.5-math-7b": { + "layers": 28, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Math-7B-Instruct" + }, + }, + "qwen-2.5-14b": { + "layers": 48, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-14B-Instruct" + }, + }, + "qwen-2.5-coder-14b": { + "layers": 48, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-14B-Instruct" + }, + }, + "qwen-2.5-32b": { + "layers": 64, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-32B-Instruct" + }, + }, + "qwen-2.5-coder-32b": { + "layers": 64, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-32B-Instruct" + }, + }, + "qwen-2.5-72b": { + "layers": 80, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-72B-Instruct" + }, + }, + "qwen-2.5-math-72b": { + "layers": 80, + "repo": { + "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", + "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Math-72B-Instruct" + }, + }, ### nemotron 
- "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, }, - "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, }, + "nemotron-70b": { + "layers": 80, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit",}, + }, + "nemotron-70b-bf16": { + "layers": 80, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16",}, + }, # gemma - "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, }, - "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, }, + "gemma2-9b": { + "layers": 42, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit",}, + }, + "gemma2-27b": { + "layers": 46, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit",}, + }, # stable diffusion - "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } }, + "stable-diffusion-2-1-base": {"layers": 31, "repo": {"MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base"}}, # phi - "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, }, - "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, }, + "phi-3.5-mini": { + "layers": 32, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit",}, + }, + "phi-4": { + "layers": 40, + "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit",}, + }, # dummy - "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, }, + "dummy": { + "layers": 8, + "repo": {"DummyInferenceEngine": "dummy",}, + }, } pretty_name = { @@ -255,19 +388,13 @@ def get_supported_models(supported_inference_engine_lists: Optional[List[List[st return list(model_cards.keys()) from exo.inference.inference_engine import inference_engine_classes - supported_inference_engine_lists = [ - [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list] - for engine_list in supported_inference_engine_lists - ] + supported_inference_engine_lists = [[inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list] + for engine_list in supported_inference_engine_lists] def has_any_engine(model_info: dict, engine_list: List[str]) -> bool: return any(engine in model_info.get("repo", {}) for engine in engine_list) def supports_all_engine_lists(model_info: dict) -> bool: - return all(has_any_engine(model_info, engine_list) - for engine_list in supported_inference_engine_lists) + return all(has_any_engine(model_info, engine_list) for engine_list in supported_inference_engine_lists) - return [ - model_id for model_id, model_info in model_cards.items() - if supports_all_engine_lists(model_info) - ] + return [model_id for model_id, model_info in model_cards.items() if supports_all_engine_lists(model_info)] diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 58ef7e71b..e7999637e 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -16,8 +16,10 @@ if 
platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": import mlx.core as mx + IS_APPLE = True else: import numpy as mx + IS_APPLE = False class GRPCPeerHandle(PeerHandle): @@ -123,6 +125,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O request_id=request_id, inference_state=None if inference_state is None else self.serialize_inference_state(inference_state) ) + response = await self.stub.SendTensor(request) if not response.tensor_data or not response.shape or not response.dtype: @@ -200,11 +203,12 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I proto_inference_state = node_service_pb2.InferenceState() other_data = {} for k, v in inference_state.items(): - if isinstance(v, mx.array): + mx_array_type = mx.array if IS_APPLE else mx.ndarray + if isinstance(v, mx_array_type): np_array = np.array(v) tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype)) proto_inference_state.tensor_data[k].CopyFrom(tensor_data) - elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v): + elif isinstance(v, list) and all(isinstance(item, mx_array_type) for item in v): tensor_list = node_service_pb2.TensorList() for tensor in v: np_array = np.array(tensor) diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 1281aa8ae..e86563702 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -1,3 +1,4 @@ +import os import numpy as np import json import asyncio @@ -17,6 +18,7 @@ from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.download.shard_download import ShardDownloader + class Node: def __init__( self, @@ -44,7 +46,7 @@ def __init__( self.buffered_inputs: Dict[str, List[np.ndarray]] = {} self.buffered_partials: Dict[str, List[np.ndarray]] = {} self.checkpoints: Dict[str, Dict[str, int]] = {} - + self.max_generate_tokens = max_generate_tokens self.topology_viz = topology_viz self.default_sample_temperature = default_sample_temperature @@ -101,7 +103,8 @@ def get_supported_inference_engines(self): supported_engine_names.append('mlx') supported_engine_names.append('tinygrad') else: - supported_engine_names.append('tinygrad') + supported_engine_names.append(os.environ.get("EXO_INFER_ENGINE", 'tinygrad')) + return supported_engine_names async def broadcast_supported_engines(self, supported_engines_names: List[str]): @@ -149,10 +152,9 @@ async def process_inference_result( self.outstanding_requests.pop(request_id) else: self.outstanding_requests[request_id] = "waiting" - asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state)) - - return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result + asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset=1), inference_state)) + return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result async def process_prompt( self, @@ -219,7 +221,7 @@ async def enqueue_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, request_id: Optional[str] = None, train: bool = False, @@ -232,7 +234,7 @@ async def enqueue_example( if request_id is None: request_id = str(uuid.uuid4()) self.outstanding_requests[request_id] = 
"waiting" - loss = await self.forward_example(shard, example, target, length, train, request_id, 0) + loss = await self.forward_example(shard, example, target, length, train, request_id, 0) return loss async def coordinate_save( @@ -263,7 +265,7 @@ async def process_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, train: bool = False, request_id: Optional[str] = None, @@ -308,7 +310,7 @@ async def _process_example( self, base_shard: Shard, example: np.ndarray, - target: np.ndarray, + target: np.ndarray, length: np.ndarray, train: bool = False, request_id: Optional[str] = None, @@ -327,7 +329,7 @@ async def _process_example( self.outstanding_requests[request_id] = "preprocessing" step, _ = await self.inference_engine.infer_tensor(request_id, shard, example) self.outstanding_requests[request_id] = "waiting" - loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1)) + loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1)) self.outstanding_requests[request_id] = "training" partial_loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient") self.outstanding_requests.pop(request_id) @@ -343,7 +345,7 @@ async def _process_example( self.outstanding_requests[request_id] = "preprocessing" step, _ = await self.inference_engine.infer_tensor(request_id, shard, example) self.outstanding_requests[request_id] = "waiting" - loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1)) + loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1)) self.outstanding_requests.pop(request_id) return loss except Exception as e: @@ -351,7 +353,7 @@ async def _process_example( print(f"Error processing example for shard {shard}: {e}") traceback.print_exc() return None - + async def process_tensor( self, base_shard: Shard, @@ -380,7 +382,7 @@ async def _process_tensor( try: self.outstanding_requests[request_id] = "processing" result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state) - ret = await self.process_inference_result(shard, result, request_id, inference_state) + ret = await self.process_inference_result(shard, result, request_id, inference_state) return ret except Exception as e: self.outstanding_requests.pop(request_id) @@ -428,7 +430,7 @@ async def forward_prompt( raise ValueError(f"Peer for {target_index} not found") if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}") await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state) - + async def forward_tensor( self, base_shard: Shard, @@ -458,7 +460,7 @@ def get_partition_index(self, offset: int = 0): current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) if current_partition_index is None: raise ValueError(f"No current partition found for node: {self.id}") - return (current_partition_index + offset) % len(partitions) + return (current_partition_index+offset) % len(partitions) def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard: if index is None: @@ -584,7 +586,7 @@ def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: def trigger_on_token_callbacks(self, request_id: 
str, tokens: List[int], is_finished: bool) -> None: if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}") self.on_token.trigger_all(request_id, tokens, is_finished) - + async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None: if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}") async def send_result_to_peer(peer): @@ -620,8 +622,8 @@ def current_topology(self) -> Topology: def handle_stable_diffusion(self, inference_state, result): if inference_state['is_step_finished']: - inference_state['step']+=1 - progress = [inference_state['step'],inference_state['total_steps']] + inference_state['step'] += 1 + progress = [inference_state['step'], inference_state['total_steps']] intermediate_result = result if progress[0] == progress[1]: intermediate_result = result diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py new file mode 100644 index 000000000..3a612a365 --- /dev/null +++ b/exo/orchestration/standard_node.py @@ -0,0 +1,469 @@ +import numpy as np +import json +import asyncio +import uuid +import time +import traceback +from typing import List, Dict, Optional, Tuple, Union, Set +from exo.networking import Discovery, PeerHandle, Server +from exo.inference.inference_engine import InferenceEngine, Shard +from .node import Node +from exo.topology.topology import Topology +from exo.topology.device_capabilities import device_capabilities +from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards +from exo import DEBUG +from exo.helpers import AsyncCallbackSystem +from exo.viz.topology_viz import TopologyViz +from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.inference.inference_engine import get_inference_engine, InferenceEngine +from exo.download.hf.hf_shard_download import HFShardDownloader + +class StandardNode(Node): + def __init__( + self, + _id: str, + server: Server, + inference_engine: InferenceEngine, + discovery: Discovery, + partitioning_strategy: PartitioningStrategy = None, + max_generate_tokens: int = 1024, + default_sample_temperature: float = 0.0, + topology_viz: Optional[TopologyViz] = None, + shard_downloader: Optional[HFShardDownloader] = None, + ): + self.id = _id + self.inference_engine = inference_engine + self.server = server + self.discovery = discovery + self.partitioning_strategy = partitioning_strategy + self.peers: List[PeerHandle] = {} + self.topology: Topology = Topology() + self.device_capabilities = device_capabilities() + self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {} + self.buffered_logits: Dict[str, List[np.ndarray]] = {} + self.buffered_inputs: Dict[str, List[np.ndarray]] = {} + self.max_generate_tokens = max_generate_tokens + self.topology_viz = topology_viz + self.default_sample_temperature = default_sample_temperature + self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]() + self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]() + self._on_opaque_status.register("node_status").on_next(self.on_node_status) + self.node_download_progress: Dict[str, RepoProgressEvent] = {} + self.topology_inference_engines_pool: List[List[str]] = [] + self.shard_downloader = shard_downloader + + async def start(self, wait_for_peers: int = 0) -> None: + await self.server.start() + await self.discovery.start() + await self.update_peers(wait_for_peers) + await self.collect_topology() + if DEBUG >= 2: print(f"Collected 
topology: {self.topology}") + asyncio.create_task(self.periodic_topology_collection(1.0)) + + async def stop(self) -> None: + await self.discovery.stop() + await self.server.stop() + + def on_node_status(self, request_id, opaque_status): + try: + status_data = json.loads(opaque_status) + if status_data.get("type", "") == "supported_inference_engines": + node_id = status_data.get("node_id") + engines = status_data.get("engines", []) + self.topology_inference_engines_pool.append(engines) + if status_data.get("type", "") == "node_status": + if status_data.get("status", "").startswith("start_"): + self.current_topology.active_node_id = status_data.get("node_id") + elif status_data.get("status", "").startswith("end_"): + if status_data.get("node_id") == self.current_topology.active_node_id: + self.current_topology.active_node_id = None + download_progress = None + if status_data.get("type", "") == "download_progress": + if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}") + download_progress = RepoProgressEvent.from_dict(status_data.get('progress')) + self.node_download_progress[status_data.get('node_id')] = download_progress + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress) + except Exception as e: + if DEBUG >= 1: print(f"Error updating visualization: {e}") + if DEBUG >= 1: traceback.print_exc() + + def get_supported_inference_engines(self): + supported_engine_names = [] + if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine': + supported_engine_names.append('mlx') + supported_engine_names.append('tinygrad') + elif self.inference_engine.__class__.__name__ == 'TorchDynamicShardInferenceEngine': + supported_engine_names.append('torch') + supported_engine_names.append('tinygrad') + else: + supported_engine_names.append('tinygrad') + return supported_engine_names + + async def broadcast_supported_engines(self, supported_engines_names: List[str]): + status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names}) + await self.broadcast_opaque_status("", status_message) + + def get_topology_inference_engines(self) -> List[List[str]]: + return self.topology_inference_engines_pool + + async def process_inference_result( + self, + shard, + result: np.ndarray, + request_id: Optional[str] = None, + ): + if request_id not in self.buffered_token_output: + self.buffered_token_output[request_id] = ([], False) + is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens + if shard.is_last_layer() and not is_finished: + token = await self.inference_engine.sample(result, temp=self.default_sample_temperature) + await self.inference_engine.ensure_shard(shard) + self.buffered_token_output[request_id][0].append(token.item()) + if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}") + is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id + forward = token.reshape(1, -1) + self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished) + asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished)) + else: + forward = result + + if is_finished: + self.buffered_token_output[request_id] = 
(self.buffered_token_output[request_id][0], True) + else: + asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1))) + + return np.array(self.buffered_token_output[request_id][0]) + + async def process_prompt( + self, + base_shard: Shard, + prompt: str, + request_id: Optional[str] = None, + ) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "request_id": request_id, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_prompt(base_shard, prompt, request_id) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_prompt", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "prompt": prompt, + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + shard = self.get_current_shard(base_shard) + + if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}") + if not shard.is_first_layer(): + if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}") + resp = await self.forward_prompt(shard, prompt, request_id, 0) + return None + else: + result = await self.inference_engine.infer_prompt(request_id, shard, prompt) + ret = await self.process_inference_result(shard, result, request_id) + return result + + async def process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + ) -> Optional[np.ndarray]: + shard = self.get_current_shard(base_shard) + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "start_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "tensor_size": tensor.size, + "tensor_shape": tensor.shape, + "request_id": request_id, + }), + ) + ) + start_time = time.perf_counter_ns() + resp = await self._process_tensor(shard, tensor, request_id) + end_time = time.perf_counter_ns() + elapsed_time_ns = end_time - start_time + asyncio.create_task( + self.broadcast_opaque_status( + request_id, + json.dumps({ + "type": "node_status", + "node_id": self.id, + "status": "end_process_tensor", + "base_shard": base_shard.to_dict(), + "shard": shard.to_dict(), + "request_id": request_id, + "elapsed_time_ns": elapsed_time_ns, + "result_size": resp.size if resp is not None else 0, + }), + ) + ) + return resp + + async def _process_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: Optional[str] = None, + ) -> Optional[np.ndarray]: + if request_id is None: + request_id = str(uuid.uuid4()) + shard = self.get_current_shard(base_shard) + + if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}") + try: + result = await self.inference_engine.infer_tensor(request_id, shard, tensor) + ret = await 
self.process_inference_result(shard, result, request_id) + return ret + except Exception as e: + print(f"Error processing tensor for shard {shard}: {e}") + traceback.print_exc() + return None + + async def forward_prompt( + self, + base_shard: Shard, + prompt: str, + request_id: str, + target_index: int, + ) -> None: + if DEBUG >= 1: print(f"target partition index: {target_index}") + target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id + next_shard = self.get_current_shard(base_shard, target_index) + if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}") + if target_id == self.id: + await self.process_prompt(next_shard, prompt, request_id) + else: + target_peer = next((p for p in self.peers if p.id() == target_id), None) + if not target_peer: + raise ValueError(f"Peer for {target_index} not found") + if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}") + await target_peer.send_prompt(next_shard, prompt, request_id=request_id) + + async def forward_tensor( + self, + base_shard: Shard, + tensor: np.ndarray, + request_id: str, + target_index: int, + ) -> None: + if DEBUG >= 1: print(f"target partition index: {target_index}") + target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id + next_shard = self.get_current_shard(base_shard, target_index) + if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}") + if target_id == self.id: + await self.process_tensor(next_shard, tensor, request_id) + else: + target_peer = next((p for p in self.peers if p.id() == target_id), None) + if not target_peer: + raise ValueError(f"Peer for {target_index} not found") + if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}") + await target_peer.send_tensor(next_shard, tensor, request_id=request_id) + + def get_partition_index(self, offset: int = 0): + if not self.partitioning_strategy: + if DEBUG >= 1: print("No partitioning strategy found. 
Skipping forward.") + return None + partitions = self.partitioning_strategy.partition(self.topology) + current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None) + if current_partition_index is None: + raise ValueError(f"No current partition found for node: {self.id}") + return (current_partition_index + offset) % len(partitions) + + def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard: + if index is None: + index = self.get_partition_index() + partitions = self.partitioning_strategy.partition(self.topology) + shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id) + return shards[index] + + async def update_peers(self, wait_for_peers: int = 0) -> bool: + next_peers = await self.discovery.discover_peers(wait_for_peers) + current_peer_ids = {peer.id() for peer in self.peers} + next_peer_ids = {peer.id() for peer in next_peers} + peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids] + peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids] + peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())] + peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())] + peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()] + peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()] + + def _pretty(peers: List[PeerHandle]) -> List[str]: + return [f"{peer.id()}@{peer.addr()}" for peer in peers] + + if DEBUG >= 2: + print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}") + + async def disconnect_with_timeout(peer, timeout=5): + try: + await asyncio.wait_for(peer.disconnect(), timeout) + return True + except Exception as e: + print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}") + traceback.print_exc() + return False + + async def connect_with_timeout(peer, timeout=5): + try: + await asyncio.wait_for(peer.connect(), timeout) + return True + except Exception as e: + print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}") + traceback.print_exc() + return False + + disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True) + connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True) + + successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True] + failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False] + successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True] + failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False] + if DEBUG >= 1: + if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}") + if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}") + if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}") + if failed_connects: print(f"Failed to connect peers: 
{_pretty(failed_connects)}") + + self.peers = next_peers + return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0 + + async def select_best_inference_engine(self): + if self.inference_engine.__class__.__name__ == 'DummyInferenceEngine': return + supported_engines = self.get_supported_inference_engines() + await self.broadcast_supported_engines(supported_engines) + if len(self.get_topology_inference_engines()): + self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader) + + async def periodic_topology_collection(self, interval: int): + while True: + await asyncio.sleep(interval) + try: + did_peers_change = await self.update_peers() + if DEBUG >= 2: print(f"{did_peers_change=}") + if did_peers_change: + await self.collect_topology() + await self.select_best_inference_engine() + except Exception as e: + print(f"Error collecting topology: {e}") + traceback.print_exc() + + async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]: + if request_id not in self.buffered_token_output: + return None, False + return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1] + + async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology: + next_topology = Topology() + next_topology.update_node(self.id, self.device_capabilities) + + if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}") + + prev_visited = visited.copy() + visited.add(self.id) + visited.update(p.id() for p in self.peers) + + for peer in self.peers: + next_topology.update_node(peer.id(), peer.device_capabilities()) + next_topology.add_edge(self.id, peer.id()) + + if peer.id() in prev_visited: + continue + + if max_depth <= 0: + if DEBUG >= 2: print("Max depth reached. Skipping...") + continue + + try: + other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0) + if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}") + self.topology.merge(other_topology) + except Exception as e: + print(f"Error collecting topology from {peer.id()}: {e}") + traceback.print_exc() + + next_topology.active_node_id = self.topology.active_node_id # this is not so clean. 
+ self.topology = next_topology + if self.topology_viz: + self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id) + return next_topology + + @property + def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]: + return self._on_token + + @property + def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]: + return self._on_opaque_status + + def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None: + if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}") + self.on_token.trigger_all(request_id, tokens, is_finished) + + async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None: + async def send_result_to_peer(peer): + try: + await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout broadcasting result to {peer.id()}") + except Exception as e: + print(f"Error broadcasting result to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True) + + async def broadcast_opaque_status(self, request_id: str, status: str) -> None: + if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}") + + async def send_status_to_peer(peer): + try: + await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0) + except asyncio.TimeoutError: + print(f"Timeout sending opaque status to {peer.id()}") + except Exception as e: + print(f"Error sending opaque status to {peer.id()}: {e}") + traceback.print_exc() + + await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True) + # in the case of opaque status, we also want to receive our own opaque statuses + self.on_opaque_status.trigger_all(request_id, status) + + @property + def current_topology(self) -> Topology: + return self.topology diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html index 65451068f..543a47abf 100644 --- a/exo/tinychat/index.html +++ b/exo/tinychat/index.html @@ -18,7 +18,7 @@ - + diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index ae28e757f..17372262b 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -119,6 +119,10 @@ def to_dict(self): "NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), "NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), "NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS), + "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS), + "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS), + "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS), + "NVIDIA A10": DeviceFlops(fp32=31.2 * TFLOPS, fp16=62.5 * TFLOPS, int8=2.5 * TFLOPS), # ... add more devices if needed ... ### AMD GPUs # RX 6000 series diff --git a/install.ps1 b/install.ps1 new file mode 100644 index 000000000..c766cdd5b --- /dev/null +++ b/install.ps1 @@ -0,0 +1,8 @@ +# Create a virtual environment +python3 -m venv .venv + +# Activate the virtual environment +& .\.venv\Scripts\Activate.ps1 + +# Install the package in the virtual environment +pip install . 
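Note: as a quick smoke test after installing with either script, the torch engine added in this patch can be selected explicitly at launch. The sketch below is illustrative only: the `--inference-engine`, `--run-model` and `--prompt` flags and the `llama-3.2-1b` model id are taken from the `exo/main.py` and `exo/models.py` changes above, while the bare `exo` command assumes the package's console entry point is on PATH (the entry point itself is not shown in this patch).

```bash
# Pick the new PyTorch engine instead of the platform default (mlx/tinygrad)
exo --inference-engine torch --run-model llama-3.2-1b --prompt "Who are you?"

# Or exercise the torchtune-based Qwen path directly via the test script added in this patch
python exo/inference/torch/tests/test_qwen_full.py
```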
diff --git a/setup.py b/setup.py index de242f544..530e85976 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,14 @@ "uuid==1.30", "uvloop==0.21.0", "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8", + "torch==2.6.0", + "accelerate==0.34.2", + "torchtune==0.5.0", + "torchao==0.8.0", + "pytest==8.3.3", + "pytest-asyncio==0.24.0", + "scapy==2.6.1", + ] extras_require = {