diff --git a/.gitignore b/.gitignore
index bc6a38d79..4a8414e5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,7 +169,12 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+# Xcode
**/*.xcodeproj/*
+
+# Aider
.aider*
exo/tinychat/images/*.png
+.vscode/
+build/
diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
index 1020fdbc3..1b07ad19e 100644
--- a/exo/api/chatgpt_api.py
+++ b/exo/api/chatgpt_api.py
@@ -81,6 +81,9 @@ def generate_completion(
}],
}
+ if DEBUG >= 3:
+ print(f"completion: {completion}")
+
if not stream:
completion["usage"] = {
"prompt_tokens": len(tokenizer.encode(prompt)),
diff --git a/exo/helpers.py b/exo/helpers.py
index ff0205f00..634ffed2c 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -42,7 +42,6 @@ def get_system_info():
return "Linux"
return "Non-Mac, non-Linux system"
-
def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")
diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py
index b62000371..3c86cc681 100644
--- a/exo/inference/inference_engine.py
+++ b/exo/inference/inference_engine.py
@@ -55,6 +55,7 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
+ "torch": "TorchDynamicShardInferenceEngine"
}
@@ -71,6 +72,10 @@ def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDown
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
return TinygradDynamicShardInferenceEngine(shard_downloader)
+ elif inference_engine_name == "torch":
+ from exo.inference.torch.sharded_inference_engine import TorchDynamicShardInferenceEngine
+
+ return TorchDynamicShardInferenceEngine(shard_downloader)
elif inference_engine_name == "dummy":
from exo.inference.dummy_inference_engine import DummyInferenceEngine
return DummyInferenceEngine()
diff --git a/exo/inference/torch/.gitignore b/exo/inference/torch/.gitignore
new file mode 100644
index 000000000..6d76c24de
--- /dev/null
+++ b/exo/inference/torch/.gitignore
@@ -0,0 +1,2 @@
+data/
+model/archive/
diff --git a/exo/inference/torch/README.md b/exo/inference/torch/README.md
new file mode 100644
index 000000000..9cc6af50d
--- /dev/null
+++ b/exo/inference/torch/README.md
@@ -0,0 +1,103 @@
+# PyTorch inference engine
+
+## Devs
+- [Vincent Castro](https://x.com/t0kenl1mit)
+
+## ENV Vars
+```bash
+# Use the original max position embeddings amount, if present in the rope_scaling - default is False
+TORCH_USE_ORG_SEQ=True  # or False
+
+# Use cache - default is True
+TORCH_USE_CACHE=True  # or False
+```
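+
+For example, a node could be started with these set (a minimal sketch, assuming the `exo` console script and the `--inference-engine torch` flag added in this PR):
+
+```bash
+TORCH_USE_CACHE=True TORCH_USE_ORG_SEQ=False exo --inference-engine torch
+```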
+
+## Notes/Issues
+### 10/10/2024
+- To select a PyTorch device via environment variables, set the variable TORCH_DEVICE (see the example after this list)
+  - XLA is currently not installed and will need to be added to inference.py; looking into doing this on a TPU VM
+  - With PyTorch, CUDA and ROCm are the same, so specifying CUDA also enables ROCm support. See this [post](https://github.com/pytorch/pytorch/issues/55223#issuecomment-812587373)
+  - Looking into adding mobile device support properly
+- If the device is not CPU, the data type defaults to float32; otherwise, float16.
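+
+A minimal sketch of selecting a device (assuming a CUDA-capable machine and the `exo` console script):
+
+```bash
+TORCH_DEVICE=cuda exo --inference-engine torch
+```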
+
+### 10/13/2024
+Still working on split model development (see test_split_model.py). It appears to split the model, but transformers still loads more into RAM and VRAM as larger models are loaded, causing OOMs. Will research and address this in the next update. Tests are added and under development.
+
+### 10/21/2024
+Working on removing transformers due to inference and VRAM usage [issues](https://github.com/exo-explore/exo/pull/139#issuecomment-2424953962). Creating a pure PyTorch implementation of llama3, as using transformers won't work for exo. Using some code from Meta but also making use of torchtune.
+
+### 10/27/2024
+Still working on the llama3 model, but noting that a better KVCache needs to be investigated.
+
+### 11/17/2024
+The sharded llama model is now working; the next step is the inference engine. Still testing on the small Llama 3.2 1B but will try larger models.
+
+### 01/16/2025
+Torchtune has replaced Hugging Face transformers, except for the tokenizer. Inference on Llama 3.2 1B seems okay, but my GPU runs into a wall quickly. Will be trying a bigger VM and a LAN server to split up the model.
+
+## Tech
+```bash
+# Laptop/PC
+Distributor ID: Ubuntu
+Description: Ubuntu 24.04.1 LTS
+Release: 24.04
+Codename: noble
+CUDA Version: 12.4
+Nvidia Driver Version: 550.107.02
+
+CPU: 11th Gen Intel® Core™ i7-11800H × 16
+RAM: 16GB
+GPU 1: Nvidia GeForce RTX 3060 6GB Laptop
+```
+```bash
+# Server
+Distributor ID: Pop
+Description: Pop!_OS 22.04 LTS
+Release: 22.04
+Codename: jammy
+CUDA Version: 12.4
+Nvidia Driver Version: 550.90.07
+
+GPU 1: NVIDIA T1000 8GB
+GPU 2: NVIDIA Quadro M2000 4GB
+GPU 3: NVIDIA Quadro M2000 4GB
+GPU 4: NVIDIA Quadro P400 2GB
+GPU 5: NVIDIA Quadro P400 2GB
+```
+
+## Current Model
+
+Work-in-progress PyTorch llama model
+
+```
+# Llama-3.2-1B-Instruct #
+
+ShardedLlamaModel(
+ (model): ShardTransformerDecoder(
+ (tok_embeddings): Embedding(128256, 2048)
+ (layers): ModuleList(
+ (0-15): 16 x TransformerSelfAttentionLayer(
+ (attn): MultiHeadAttention(
+ (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
+ (k_proj): Linear(in_features=2048, out_features=512, bias=False)
+ (v_proj): Linear(in_features=2048, out_features=512, bias=False)
+ (output_proj): Linear(in_features=2048, out_features=2048, bias=False)
+ (pos_embeddings): Llama3ScaledRoPE()
+ )
+ (mlp): FeedForward(
+ (w1): Linear(in_features=2048, out_features=8192, bias=False)
+ (w2): Linear(in_features=8192, out_features=2048, bias=False)
+ (w3): Linear(in_features=2048, out_features=8192, bias=False)
+ (activation): SiLU()
+ )
+ (sa_norm): RMSNorm()
+ (mlp_norm): RMSNorm()
+ (sa_scale): Identity()
+ (mlp_scale): Identity()
+ )
+ )
+ (norm): RMSNorm()
+ )
+)
+
+```
diff --git a/exo/inference/torch/__init__.py b/exo/inference/torch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/exo/inference/torch/models/__init__.py b/exo/inference/torch/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/exo/inference/torch/models/general_mha.py b/exo/inference/torch/models/general_mha.py
new file mode 100644
index 000000000..8a9be51a1
--- /dev/null
+++ b/exo/inference/torch/models/general_mha.py
@@ -0,0 +1,252 @@
+"""
+GeneralMHA
+Builds a sharded transformer decoder model with multi-head attention (MHA)
+"""
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torchtune.modules as ttm
+
+from torchtune.modules import RMSNorm
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+from torchtune.models.qwen2._positional_embeddings import Qwen2RotaryPositionalEmbeddings
+from torchtune.modules import RotaryPositionalEmbeddings
+from exo.inference.shard import Shard
+from exo.inference.torch.models.llm_utils import (
+ layer_mlp,
+ ShardTransformerDecoder
+)
+
+from exo.helpers import DEBUG
+
+def GeneralMHA(
+ config: dict,
+ shard: Shard
+):
+ use_tied = False
+ attn_bias = config.get("attn_bias", False)
+ output_bias = config.get("attn_bias", False)
+
+ if "llama" in shard.model_id or "Llama" in shard.model_id:
+ # rope scaling config
+ rope = Llama3ScaledRoPE(
+ dim=config["head_dim"],
+ max_seq_len=config["max_seq_len"],
+ base=config["rope_base"],
+ scale_factor=config["rope_scaling_factor"],
+ )
+
+ # tied needed for 3.2 llama models
+ if "3.2" in shard.model_id:
+ use_tied = True
+ elif "qwen" in shard.model_id or "Qwen" in shard.model_id:
+ # rope scaling config
+ rope = Qwen2RotaryPositionalEmbeddings(
+ dim=config["head_dim"],
+ max_seq_len=config["max_seq_len"],
+ base=config["rope_base"]
+ )
+ attn_bias = True
+ output_bias = False
+
+ # tied needed for 0.5B qwen models
+ if "0.5B" in shard.model_id or "0.5b" in shard.model_id:
+ use_tied = True
+ else:
+ rope = RotaryPositionalEmbeddings(
+ dim=config["head_dim"],
+ max_seq_len=config["max_seq_len"],
+ base=config["rope_base"]
+ )
+
+ if DEBUG >= 4:
+ print(f"model_id: {shard.model_id}")
+ print(f"rope: {rope}")
+ print(f"attn_bias: {attn_bias}")
+ print(f"output_bias: {output_bias}")
+ print(f"use_tied: {use_tied}")
+
+ # hack to align sharded weights with layers
+ # fill unused layer positions with None
+ layers = [None for _ in range(shard.n_layers)]
+
+ # build layers
+ for i in range(shard.start_layer, shard.end_layer + 1):
+ self_attn = ttm.MultiHeadAttention(
+ embed_dim=config["embed_dim"],
+ num_heads=config["num_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ head_dim=config["head_dim"],
+ q_proj=nn.Linear(
+ config["embed_dim"],
+ config["num_heads"]*config["head_dim"],
+ bias=attn_bias,
+ ),
+ k_proj=nn.Linear(
+ config["embed_dim"],
+ config["num_kv_heads"]*config["head_dim"],
+ bias=attn_bias,
+ ),
+ v_proj=nn.Linear(
+ config["embed_dim"],
+ config["num_kv_heads"]*config["head_dim"],
+ bias=attn_bias,
+ ),
+ output_proj=nn.Linear(
+ config["embed_dim"],
+ config["embed_dim"],
+ bias=output_bias,
+ ),
+ max_seq_len=config["max_seq_len"],
+ attn_dropout=config["attn_dropout"],
+ pos_embeddings=rope,
+ )
+
+ mlp = layer_mlp(
+ dim=config["embed_dim"],
+ hidden_dim=config["intermediate_dim"],
+ )
+
+ layer = ttm.TransformerSelfAttentionLayer(
+ attn=self_attn,
+ mlp=mlp,
+ sa_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]),
+ mlp_norm=RMSNorm(config["embed_dim"], eps=config["norm_eps"]),
+ )
+
+ layers[i] = layer
+
+ layers = nn.ModuleList(layers)
+
+ tok_embeddings = nn.Embedding(config["vocab_size"], config["embed_dim"])
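+  # tied models reuse the token embedding matrix as the output projection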
+ if use_tied:
+ output_proj = ttm.TiedLinear(tok_embeddings)
+ else:
+ output_proj = nn.Linear(config["embed_dim"], config["vocab_size"], bias=False)
+
+ norm = RMSNorm(config["embed_dim"], eps=config["norm_eps"])
+
+ return ShardTransformerDecoder(
+ tok_embeddings=tok_embeddings,
+ shard=shard,
+ layers=layers,
+ max_seq_len=config["max_seq_len"],
+ num_heads=config["num_heads"],
+ head_dim=config["head_dim"],
+ norm=norm,
+ output=output_proj,
+ num_layers=config["num_layers"],
+ )
+
+class ShardedGeneralModel(nn.Module):
+ def __init__(
+ self,
+ config: dict,
+ shard: Shard,
+ device: Optional[torch.device] = None,
+ dtype: torch.dtype = torch.float16,
+ use_cache: Optional[bool] = False,
+ max_generated_tokens: int = 1024,
+ ):
+ super(ShardedGeneralModel, self).__init__()
+
+ self.shard = shard
+ self.config = config
+ self.dtype = dtype
+ self.device = device if device is not None else torch.device("cpu")
+ self.max_seq_len = self.config["max_seq_len"]
+ self.use_cache = use_cache
+
+ self.model = GeneralMHA(
+ config,
+ self.shard
+ ).to(
+ dtype=self.dtype,
+ device=self.device
+ )
+
+ if DEBUG >= 4:
+ print("ShardedGeneralModel called")
+ print(f"self.model {self.model}")
+
+ # keep track of current position in generation
+ self.max_generated_tokens = max_generated_tokens
+
+ def generate(
+ self,
+ tokens: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
+ input_pos: Optional[torch.Tensor] = None,
+ hidden_state: Optional[torch.Tensor] = None,
+ curr_pos: Optional[int] = 0
+ ) -> Tuple[
+ Optional[torch.Tensor],
+ torch.Tensor,
+ ]:
+ """
+    Generate logits and/or hidden states from the sharded model
+
+    Args:
+      tokens (torch.Tensor, optional): tokens from prompt tokenization and generation
+      mask (torch.Tensor, optional): attention mask
+      input_pos (torch.Tensor, optional): position ids for the tokens
+      hidden_state (torch.Tensor, optional): hidden state from the last activated hidden layer, if any
+      curr_pos (int, optional): current position in generation
+ """
+ if DEBUG >= 4:
+ print("generate called")
+ print(f"tokens: {tokens}")
+ if mask is not None:
+ print(f"mask: {mask.size()}")
+ print(f"input_pos: {input_pos.size()}")
+ print(f"hidden_state: {hidden_state}")
+ print(f"curr_pos: {curr_pos}")
+ print(f"cached? {self.model.caches_are_enabled()}")
+
+ model_hs = None
+ model_logits = None
+
+ self.model.output_hidden_states = [self.shard.end_layer]
+
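+    # for follow-up steps (curr_pos > 0): with KV caches enabled only the current
+    # row of the mask/positions is needed; without caches the full prefix up to
+    # curr_pos is passed. The first call (curr_pos == 0) slices to the prompt length.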
+ if curr_pos > 0:
+ if self.model.caches_are_enabled():
+ input_pos = input_pos[:, curr_pos].contiguous()
+ mask = mask[:, curr_pos, None, :].contiguous()
+ else:
+ input_pos = input_pos[:, :curr_pos + 1]
+ mask = mask[:, :curr_pos + 1, :curr_pos + 1]
+ else:
+ _, tklng = tokens.size()
+
+ if self.model.caches_are_enabled():
+ mask = mask[:, :tklng]
+ else:
+ mask = mask[:, :tklng, :tklng]
+
+ input_pos = input_pos[:, :tklng].squeeze()
+
+ if DEBUG >= 4:
+ print("model_input")
+ if tokens is not None:
+ print(f"tokens: {tokens}")
+ if hidden_state is not None:
+ print(f"hidden_state: {hidden_state}")
+ print(f"mask: {mask}")
+ print(f"input_pos: {input_pos}")
+
+ model_output = self.model(
+ tokens=tokens,
+ mask=mask,
+ input_pos=input_pos,
+ hidden_state=hidden_state,
+ dtype=self.dtype
+ )
+
+ if self.shard.is_last_layer():
+ model_logits = model_output
+ else:
+ model_hs = model_output
+
+ if DEBUG >= 4:
+ print(f"model_hs\n{model_hs}\nmodel_logits\n{model_logits}")
+
+ return model_hs, model_logits
diff --git a/exo/inference/torch/models/llm_utils.py b/exo/inference/torch/models/llm_utils.py
new file mode 100644
index 000000000..0d3212c16
--- /dev/null
+++ b/exo/inference/torch/models/llm_utils.py
@@ -0,0 +1,520 @@
+"""
+Utility methods used by LLMs
+"""
+import os
+import re
+import json
+from pathlib import Path
+from typing import Any, Dict, Optional, Union, List, Callable
+
+import torch
+import torch.nn as nn
+from torchtune.modules.attention_utils import _MaskType
+from torchtune.modules import FeedForward, TransformerDecoder
+from torchtune.models.convert_weights import hf_to_tune
+
+from safetensors.torch import load_file as load_safetensors
+
+from exo.helpers import DEBUG
+from exo.inference.shard import Shard
+
+# dtype string to dtype from huggingface type config.json
+HF_PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "float32": torch.float32,
+ "float64": torch.float64,
+}
+
+
+def load_model_config(model_config_path: Path) -> dict:
+ """
+ Loads the config.json of the model
+
+ Args:
+    model_config_path (Path): local path to the model's config.json
+
+ Returns:
+ dict: The config as a dictionary
+ """
+ model_config = {}
+ with open(model_config_path, "r") as f:
+ base_config = json.load(f)
+
+ model_config = {
+ "rope_scaling": base_config.get("rope_scaling"),
+ "embed_dim": base_config["hidden_size"],
+ "num_heads": base_config["num_attention_heads"],
+ "head_dim": base_config.get(
+ "head_dim",
+ base_config["hidden_size"] // base_config["num_attention_heads"],
+ ), # Assuming embed_dim = hidden_size
+ "num_kv_heads": base_config["num_key_value_heads"],
+ "max_seq_len": base_config["max_position_embeddings"],
+ "intermediate_dim": base_config["intermediate_size"],
+ "attn_dropout": base_config.get("attention_dropout", 0.0),
+ "norm_eps": base_config["rms_norm_eps"],
+ "rope_base": base_config["rope_theta"],
+ "vocab_size": base_config["vocab_size"],
+ "num_layers": base_config["num_hidden_layers"],
+ "attn_bias": base_config.get("attention_bias", False),
+ "hidden_act": base_config.get("hidden_act", "silu"),
+ "torch_dtype": HF_PRECISION_STR_TO_DTYPE.get(
+ base_config.get("torch_dtype", "float16"),
+ torch.float16
+ )
+ }
+
+  if model_config.get("rope_scaling", None) is not None:
+    model_config["rope_scaling_factor"] = model_config["rope_scaling"].get("factor", 32)
+
+  use_org_seq = bool(os.getenv("TORCH_USE_ORG_SEQ", "False").lower() == "true")
+  if use_org_seq and model_config.get("rope_scaling", None) is not None:
+    model_config["max_seq_len"] = model_config["rope_scaling"]["original_max_position_embeddings"]
+
+  return model_config
+
+
+def check_weights(model, state_dict):
+ """
+ Verifies that the weights from the state dictionary are properly loaded into the model.
+ """
+ model_state_dict = model.state_dict()
+ for name, param in model_state_dict.items():
+ if name in state_dict:
+ loaded_param = state_dict[name]
+ if param.shape != loaded_param.shape:
+ print(f"Shape mismatch for {name}: expected {param.shape}, got {loaded_param.shape}")
+ else:
+ print(f"{name}: loaded correctly")
+
+ for name in state_dict:
+ if name not in model_state_dict:
+ print(f"Unexpected weight {name} found in state_dict")
+
+def load_weights_torch(cache_dir: Path, model: Any, config: Dict):
+ # Load weights from safetensors files in the cache directory
+ safetensors_files = list(cache_dir.glob("*.safetensors"))
+ if not safetensors_files:
+ raise FileNotFoundError("No safetensors files found in the cache directory.")
+
+  # Merge weights from each found safetensors file into one state dict
+  full_state_dict = {}
+  for safetensor_file in safetensors_files:
+    full_state_dict |= load_safetensors(safetensor_file)
+
+ converted_sd = hf_to_tune(
+ state_dict=full_state_dict,
+ num_heads=config["num_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ dim=config["embed_dim"],
+ head_dim=config["head_dim"]
+ )
+
+ model.load_state_dict(converted_sd, strict=True)
+
+ print("\n--- checking weights ----\n")
+ check_weights(model, converted_sd)
+
+def _permute(t, n_heads: int, head_dim: int, dim: int):
+ """
+ Reshape weight for torchtune
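+
+  Hugging Face llama checkpoints store the q/k projection weights in a different
+  rotary (RoPE) layout than torchtune expects; this reshape/transpose converts
+  between the two conventions when loading.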
+ """
+ return (
+ t.view(n_heads, 2, head_dim // 2, dim)
+ .transpose(1, 2)
+ .reshape((head_dim * n_heads), dim)
+ )
+
+def load_model_weights_torchtune(
+ cache_dir: Path,
+ shard: Shard,
+ model: Any,
+ num_heads: int = 32,
+ num_kv_heads: int = 32,
+ dim: int = 4096,
+  head_dim: Optional[int] = None
+):
+  """
+  Loads weights from Hugging Face and remaps them to the torchtune naming structure
+ """
+ if head_dim is None:
+ head_dim = dim // num_heads
+
+ model_state_dict = model.state_dict()
+ if DEBUG >= 8:
+ for name, _ in model_state_dict.items():
+ print(f"name {name}")
+ # Load weights from safetensors files in the cache directory
+ safetensors_files = list(cache_dir.glob("*.safetensors"))
+ if not safetensors_files:
+ raise FileNotFoundError("No safetensors files found in the cache directory.")
+
+  # Merge weights from each found safetensors file into one state dict
+  full_state_dict = {}
+  for safetensor_file in safetensors_files:
+    full_state_dict |= load_safetensors(safetensor_file)
+
+ # remap to work with our model
+ remapped_state_dict = {}
+
+  is_llama = "llama" in shard.model_id or "Llama" in shard.model_id
+
+  if DEBUG >= 8:
+    print("loading llama type weights" if is_llama else "loading weights")
+
+ for key, value in full_state_dict.items():
+ # load layer by shard
+ for layer_num in range(shard.start_layer, shard.end_layer + 1):
+ # change input layer norm to sa_norm for torchtune
+ re_iln = re.findall(rf"model.layers\.{layer_num}\.(input_layernorm)\.weight", key)
+ if len(re_iln) != 0:
+ new_key = f"model.layers.{layer_num}.sa_norm.scale"
+ remapped_state_dict[new_key] = value
+ if DEBUG >= 8:
+ print(f"{key} == {new_key}")
+
+ # change post attention layernorm to mlp_norm for torchtune
+ re_pal = re.findall(rf"model.layers\.{layer_num}\.(post_attention_layernorm)\.weight", key)
+ if len(re_pal) != 0:
+ new_key = f"model.layers.{layer_num}.mlp_norm.scale"
+ remapped_state_dict[new_key] = value
+ if DEBUG >= 8:
+ print(f"{key} == {new_key}")
+
+ # change self_attn to attn
+ # along with changing o_proj to output_proj
+ re_attn = re.findall(rf"model\.layers\.{layer_num}.(\w+)\.(\w+)\.(\w+)", key)
+ if len(re_attn) != 0 and re_attn[0][0] == "self_attn":
+ if re_attn[0][1] == "k_proj" and is_llama:
+ value = _permute(
+ t=value,
+ n_heads=num_kv_heads,
+ head_dim=head_dim,
+ dim=dim
+ )
+
+ new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}"
+ remapped_state_dict[new_key] = value
+ elif re_attn[0][1] == "q_proj" and is_llama:
+ value = _permute(
+ t=value,
+ n_heads=num_heads,
+ head_dim=head_dim,
+ dim=dim
+ )
+ new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}"
+ remapped_state_dict[new_key] = value
+
+ elif re_attn[0][1] == "o_proj":
+ new_key = f"model.layers.{layer_num}.attn.output_proj.weight"
+ remapped_state_dict[new_key] = value
+ else:
+ new_key = f"model.layers.{layer_num}.attn.{re_attn[0][1]}.{re_attn[0][2]}"
+ remapped_state_dict[new_key] = value
+ if DEBUG >= 8:
+ print(f"{key} == {new_key}")
+
+ # set mlp weights
+ re_mlp = re.findall(rf"model\.layers\.{layer_num}.mlp.(\w+)\.(\w+)", key)
+ if len(re_mlp) != 0:
+ proj_name = re_mlp[0][0]
+ if proj_name == "up_proj":
+ proj_name = "w3"
+ elif proj_name == "down_proj":
+ proj_name = "w2"
+ elif proj_name == "gate_proj":
+ proj_name = "w1"
+ new_key = f"model.layers.{layer_num}.mlp.{proj_name}.weight"
+ remapped_state_dict[new_key] = value
+ if DEBUG >= 8:
+ print(f"{key} == {new_key}")
+
+ # saving embed for paired weights
+ if key == "model.embed_tokens.weight":
+ remapped_state_dict["model.tok_embeddings.weight"] = value
+ if DEBUG >= 8:
+ print("model.embed_tokens.weight == model.tok_embeddings.weight")
+
+ if key == "model.norm.weight":
+ remapped_state_dict["model.norm.scale"] = value
+
+ if key == "lm_head.weight":
+ remapped_state_dict["model.output.weight"] = value
+
+ if not remapped_state_dict:
+ model.load_state_dict(full_state_dict, strict=True)
+ else:
+ if DEBUG >= 8:
+ print("\nRemapped state dict\n")
+ for rsdk in remapped_state_dict.keys():
+ print(f"-- {rsdk}")
+
+ # load new weight map
+ model.load_state_dict(remapped_state_dict, strict=False)
+
+ if DEBUG >= 8:
+ print("\n--- checking weights ----\n")
+ check_weights(model, remapped_state_dict)
+
+class ShardTransformerDecoder(TransformerDecoder):
+ """
+  ShardTransformerDecoder
+  Custom version of the torchtune TransformerDecoder that allows
+  sharding of models and passing of hidden states between shards
+ """
+ def __init__(
+ self,
+ *,
+ shard: Shard,
+ tok_embeddings: nn.Embedding,
+ layers: Union[nn.Module, List[nn.Module], nn.ModuleList],
+ max_seq_len: int,
+ num_heads: int,
+ head_dim: int,
+ norm: nn.Module,
+ output: Union[nn.Linear, Callable],
+ num_layers: Optional[int] = None,
+ output_hidden_states: Optional[List[int]] = None,
+ ):
+ super().__init__(
+ tok_embeddings=tok_embeddings,
+ layers=layers,
+ max_seq_len=max_seq_len,
+ num_heads=num_heads,
+ head_dim=head_dim,
+ norm=norm,
+ output=output,
+ num_layers=num_layers,
+ output_hidden_states=output_hidden_states,
+ )
+
+ self.shard = shard
+
+ def setup_caches(
+ self,
+ batch_size: int,
+ dtype: torch.dtype,
+ *,
+ encoder_max_seq_len: Optional[int] = None,
+ decoder_max_seq_len: Optional[int] = None,
+ ):
+ """
+    Modified version of setup_caches for sharded models.
+
+    Assumes decoder-only layers.
+ """
+ if decoder_max_seq_len is not None:
+ self.decoder_max_cache_seq_len = decoder_max_seq_len
+ else:
+ self.decoder_max_cache_seq_len = self.max_seq_len
+
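+    # only layers owned by this shard are real modules; the None placeholders
+    # for layers outside the shard are skipped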
+ for layer in self.layers:
+ if layer is not None:
+ layer.setup_caches(
+ batch_size,
+ dtype,
+ encoder_max_seq_len=self.encoder_max_cache_seq_len,
+ decoder_max_seq_len=self.decoder_max_cache_seq_len,
+ )
+
+ def caches_are_enabled(self) -> bool:
+ """
+ modified version for shard
+ """
+ if self.layers[0] is not None:
+ return self.layers[0].caches_are_enabled()
+ else:
+ for layer in self.layers:
+ if layer is not None:
+ return layer.caches_are_enabled()
+
+ return False
+
+ def reset_caches(self):
+ torch.cuda.empty_cache()
+
+ for layer in self.layers:
+ if layer is not None:
+ layer.reset_cache()
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ *,
+ mask: Optional[_MaskType] = None,
+ input_pos: Optional[torch.Tensor] = None,
+ hidden_state: Optional[torch.Tensor] = None,
+ dtype: torch.dtype = torch.float16
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
+ # Determine the type of input and shape
+ if DEBUG >= 4:
+ print("forward called")
+ if tokens is not None:
+ print(f"tokens [{tokens.shape}]: {tokens}")
+ print(f"mask: {mask}")
+ print(f"input_pos: {input_pos}")
+
+ if hidden_state is not None:
+ print(f"hidden_state [{hidden_state.shape}]: {hidden_state}")
+
+ if hidden_state is not None:
+ h = hidden_state # Use directly as hidden states
+ else:
+ seq_len = tokens.shape[1]
+
+ self._validate_inputs(
+ seq_len,
+ mask=mask,
+ input_pos=input_pos,
+ )
+
+ fl_tokens = tokens.clone()
+      h = self.tok_embeddings(fl_tokens).to(dtype=dtype) # Apply token embeddings
+
+    # list to capture hidden states if requested
+ hidden = []
+ curr_layers = [self.layers[i] for i in range(self.shard.start_layer, self.shard.end_layer + 1)]
+ for i, layer in enumerate(curr_layers):
+ if DEBUG >= 8:
+ print(f"\nhidden layer in H[{self.shard.start_layer+i}]\n{h}")
+ print(f"\nmask\n{mask}\ninput_pos\n{input_pos}")
+ print(f"\noutput_hidden_states\n{self.output_hidden_states}\n")
+
+ if i in self.output_hidden_states:
+ hidden.append(h)
+
+ # Process through each transformer layer
+ h = layer(
+ h,
+ mask=mask,
+ input_pos=input_pos,
+ )
+
+ if DEBUG >= 8:
+ print(f"\nhidden layer out H[{self.shard.start_layer+i}]->H[{self.shard.start_layer+i+1}]\n{h}\n")
+
+ if self.shard.is_last_layer():
+ # Apply normalization
+ h = self.norm(h)
+
+ # Handle chunked output if needed
+ output = self.output(h).float()
+
+ if DEBUG >= 4:
+ print(f"\n\noutput {output}\n\n")
+
+ return output
+ else:
+ if DEBUG >= 4:
+ print(f"\n\nhidden output {hidden[-1]}\n\n")
+
+ return hidden[-1]
+
+class MultiLayerPreceptron(nn.Module):
+ def __init__(self, input_dim, hidden_dim, activation="silu", use_bias=False):
+ """
+ General MLP (Multi-Layer Perceptron) module.
+
+ Args:
+ input_dim (int): Dimensionality of the input.
+      hidden_dim (int): Hidden layer/intermediate dimension.
+      activation (str): Activation function ('relu', 'gelu', 'tanh', 'sigmoid', 'leaky_relu', 'silu').
+      use_bias (bool): Whether to use bias in the linear layers.
+ """
+ super(MultiLayerPreceptron, self).__init__()
+
+ # Activation function mapping
+ activations = {"relu": nn.ReLU(), "gelu": nn.GELU(), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), "leaky_relu": nn.LeakyReLU(0.2), "silu": nn.SiLU()}
+
+ # Ensure valid activation
+ if activation not in activations:
+ raise ValueError(f"Invalid activation: {activation}. Choose from {list(activations.keys())}")
+
+ # Construct MLP layers
+ self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias)
+ self.up_proj = nn.Linear(input_dim, hidden_dim, bias=use_bias)
+ self.down_proj = nn.Linear(hidden_dim, input_dim, bias=use_bias)
+ self.act_fn = activations[activation]
+
+ def forward(self, x) -> torch.Tensor:
+ return self.down_proj(self.act_fn(self.gate_proj(x))*self.up_proj(x))
+
+class ShardInferenceState:
+ def __init__(
+ self,
+    tokens: Optional[torch.Tensor] = None,
+    input_pos: Optional[torch.Tensor] = None,
+    mask: Optional[torch.Tensor] = None,
+ curr_pos: int = 0,
+ device: torch.device = torch.device("cpu")
+ ):
+ self.tokens = tokens
+ self.input_pos = input_pos
+ self.mask = mask
+ self.curr_pos = curr_pos
+ self.device = device
+
+ def from_dict(self, state_dict):
+ """
+    Populate the state from a dict, converting values to torch tensors on the target device
+ """
+ self.tokens = torch.tensor(state_dict["tokens"]).to(self.device)
+ self.input_pos = torch.tensor(state_dict["input_pos"]).to(self.device)
+ self.mask = torch.tensor(state_dict["mask"]).to(self.device)
+ self.curr_pos = state_dict["curr_pos"]
+
+ def to_dict(self) -> dict:
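+    # tensors are converted to nested lists so the state can travel in the
+    # inference_state dict between nodes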
+ return {
+ "tokens": self.tokens.numpy(force=True).tolist(),
+ "input_pos": self.input_pos.numpy(force=True).tolist(),
+ "mask": self.mask.numpy(force=True).tolist(),
+ "curr_pos": self.curr_pos
+ }
+
+ def __str__(self) -> str:
+ return f"""
+ tokens: {self.tokens}
+ input_pos: {self.input_pos}
+ mask: {self.mask}
+ curr_pos: {self.curr_pos}
+ """
+
+def layer_mlp(dim: int, hidden_dim: int) -> FeedForward:
+ """
+ Generalized MLP layer
+ Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_1/_component_builders.py#L124
+ Ref: https://github.com/pytorch/torchtune/blob/main/torchtune/models/qwen2/_component_builders.py#L127C1-L134C82
+ """
+ gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+ down_proj = nn.Linear(hidden_dim, dim, bias=False)
+ up_proj = nn.Linear(dim, hidden_dim, bias=False)
+ return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)
\ No newline at end of file
diff --git a/exo/inference/torch/sharded_inference_engine.py b/exo/inference/torch/sharded_inference_engine.py
new file mode 100644
index 000000000..a21e2de77
--- /dev/null
+++ b/exo/inference/torch/sharded_inference_engine.py
@@ -0,0 +1,425 @@
+"""
+TorchDynamicShardInferenceEngine
+Sharded inference engine using PyTorch based torchtune models
+"""
+
+import os
+import functools
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+import uuid
+import re
+from typing import Optional
+
+import numpy as np
+import torch
+import torchtune.generation as ttg
+from transformers import AutoTokenizer
+
+from exo.inference.inference_engine import InferenceEngine
+from exo.download.shard_download import ShardDownloader
+from exo.inference.shard import Shard
+from exo.inference.tokenizers import _resolve_tokenizer
+from exo.helpers import DEBUG
+from exo.inference.torch.models.llm_utils import (
+ load_model_config,
+ load_model_weights_torchtune,
+ ShardInferenceState
+)
+
+from exo.inference.torch.models.general_mha import ShardedGeneralModel
+
+# from torchtune generate recipe
+# https://github.com/pytorch/torchtune/blob/main/recipes/configs/generation.yaml#L40
+TEMP = 0.6
+TOP_K = 35
+
+class TorchDynamicShardInferenceEngine(InferenceEngine):
+ """
+  PyTorch-based inference engine for sharded models
+ """
+ def __init__(self, shard_downloader: ShardDownloader):
+ self.shard = None
+ self.shard_downloader = shard_downloader
+ self.sharded_model = None
+ self.request_id = None
+ self.executor = ThreadPoolExecutor(max_workers=1)
+ self.uuid = str(uuid.uuid4())
+ self.model_path = None
+ self.model_config = None
+ self.state = None
+ self.oom_cnt = 0
+
+ # cache settings
+ self.use_cache = bool(os.getenv("TORCH_USE_CACHE", "True").lower() == "true")
+ self.cache_setup = False
+
+ # device settings
+ if os.environ.get("TORCH_DEVICE"):
+ self.device = torch.device(os.environ["TORCH_DEVICE"])
+ elif torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ self.device = torch.device("mps")
+ else:
+ self.device = torch.device("cpu")
+
+ # rng setup for sampling
+ self.rng = torch.Generator(device=self.device)
+ self.rng.manual_seed(1234)
+
+ def setup_cache(self, batch_size: int=1, total_response_length: int=1024):
+ # setup cache
+ # this is needed for a primary node that gets the initial encoding
+ if not self.sharded_model.model.caches_are_enabled() and self.use_cache:
+ with self.device:
+ self.sharded_model.model.setup_caches(
+ batch_size,
+ self.model_config["torch_dtype"],
+ decoder_max_seq_len=total_response_length
+ )
+
+ self.cache_setup = True
+
+
+ def clear_model(self):
+ """
+ Clear out model and shard
+ A way to avoid OOM issues
+
+ All prompts are stored in VRAM
+ while inference engine is up and using the same
+ model class instance, this will clear it for each prompt.
+
+ OOM issue might occur in longer chats/contexts depending on your machine.
+ """
+ if self.sharded_model.model.caches_are_enabled():
+ self.sharded_model.model.reset_caches()
+
+ del self.sharded_model
+ self.sharded_model = None
+
+ if self.device == torch.device("cuda"):
+ torch.cuda.empty_cache()
+
+ self.shard = None
+ self.state = None
+
+ async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
+ if DEBUG >= 4:
+ print("encode called")
+ print(f"shard: {shard}")
+ print(f"prompt: {prompt}")
+
+ if self.sharded_model is not None:
+ print("CLEARING SHARD AND MODEL - ENCODING")
+ self.clear_model()
+
+ await self.ensure_shard(shard)
+
+ def encode_wrapper() -> np.ndarray:
+ """
+ Encode the tensors from prompt along with the
+ initial input_pos and mask
+ """
+ tokens = self.tokenizer.encode(
+ prompt,
+ return_tensors="pt"
+ )
+
+ # move to proper device, default is CPU
+ if tokens.device != self.device:
+ tokens = tokens.to(device=self.device)
+
+ if DEBUG >= 4:
+ print("encoded_wrapper called")
+ print(f"tokens: {tokens}")
+
+      # if the prompt is longer than max_generated_tokens, keep only the tail tokens
+      if tokens.size(-1) > self.sharded_model.max_generated_tokens:
+        max_gen_tokens = self.sharded_model.max_generated_tokens
+        tokens = tokens[:, -max_gen_tokens:]
+
+ self.state.tokens = tokens
+
+ bsz, tklng = tokens.size()
+ total_response_length = tklng + self.sharded_model.max_generated_tokens
+
+ self.setup_cache(bsz, total_response_length)
+
+ # setup max sequence length
+ if not self.sharded_model.model.caches_are_enabled():
+ max_seq_len = total_response_length
+ else:
+ max_seq_len = self.sharded_model.model.decoder_max_cache_seq_len
+
+ # set pad_id
+ if hasattr(self.tokenizer, "pad_id"):
+ pad_id = self.tokenizer.pad_id
+ elif hasattr(self.tokenizer, "pad_token_id"):
+ print(f"pad_token_id: {self.tokenizer.pad_token_id}")
+ if self.tokenizer.pad_token_id is not None:
+ pad_id = self.tokenizer.pad_token_id
+ else:
+ pad_id = 0
+ else:
+ pad_id = 0
+
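+      # build the attention mask: if the prompt contains padding, derive a causal
+      # mask from the padding mask; otherwise use a plain lower-triangular mask
+      # over the full response length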
+ padding_masks = tokens != pad_id
+ if not padding_masks.all():
+ padding_masks = torch.nn.functional.pad(
+ padding_masks,
+ (0, self.sharded_model.max_generated_tokens),
+ value=True,
+ )
+
+ self.state.mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len)
+
+ self.state.input_pos = ttg.get_position_ids_from_padding_mask(padding_masks)
+ else:
+ self.state.mask = torch.tril(torch.ones(
+ total_response_length,
+ max_seq_len,
+ dtype=torch.bool,
+ device=self.device,
+ )).unsqueeze(0)
+
+ self.state.input_pos = torch.arange(0, total_response_length, device=self.device).unsqueeze(0)
+
+ return tokens
+
+ return await asyncio.get_running_loop().run_in_executor(
+ self.executor,
+ functools.partial(encode_wrapper),
+ )
+
+ async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
+ if DEBUG >= 4:
+ print("decode called")
+ print(f"shard: {shard}")
+ print(f"tokens: {tokens}")
+
+ await self.ensure_shard(shard)
+
+ return await asyncio.get_running_loop().run_in_executor(
+ self.executor,
+ functools.partial(self.tokenizer.decode, tokens.tolist()),
+ )
+
+ async def sample(self, x: np.ndarray, temp=TEMP, top_k=TOP_K) -> np.ndarray:
+ if DEBUG >= 4:
+ print("sample called")
+ print(f"x: {x}")
+ print(f"temp: {temp}")
+ print(f"top_k: {top_k}")
+ print(self.device)
+
+ logits = torch.tensor(x).to(self.device)
+
+ def sample_wrapper():
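+      # pre-draw exponential noise with this engine's seeded rng so that
+      # ttg.sample's top-k sampling is reproducible across runs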
+ q = torch.empty((logits.size(0), self.sharded_model.model.tok_embeddings.num_embeddings), device=logits.device).exponential_(1, generator=self.rng)
+
+ tokens = ttg.sample(logits.clone(), temperature=temp, top_k=top_k, q=q.to(self.device))
+
+ if DEBUG >= 4:
+ print(f"tokens: {tokens}")
+
+ return tokens.numpy(force=True)
+
+ return await asyncio.get_running_loop().run_in_executor(self.executor, functools.partial(sample_wrapper))
+
+ async def infer_tensor(
+ self,
+ request_id: str,
+ shard: Shard,
+ input_data: np.ndarray,
+ inference_state: Optional[dict] = None
+ ) -> tuple[np.ndarray, Optional[dict]]:
+
+ await self.ensure_shard(shard)
+
+ # ensure shard
+ if DEBUG >= 4:
+ print("infer_tensor called")
+ print(f"shard: {shard}")
+ print(f"input_data: {input_data}")
+
+    if inference_state is not None and inference_state.get("tokens") is not None:
+ self.state.from_dict(inference_state)
+
+ self.request_id = request_id if not self.request_id else self.request_id
+
+ hidden_state = None
+ input_tensor = None
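+    # 3-D input is a hidden-state tensor handed over from the previous shard;
+    # 2-D input is a batch of token ids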
+ if input_data.ndim == 3:
+ hidden_state = torch.tensor(input_data).to(
+ device=self.device,
+ dtype=self.model_config["torch_dtype"]
+ )
+ elif input_data.ndim == 2:
+ input_tensor = torch.tensor(input_data).to(
+ device=self.device
+ )
+
+ if self.use_cache and not self.cache_setup:
+ if input_tensor is not None:
+ bsz, tklng = input_tensor.size()
+ self.setup_cache(
+ bsz,
+ tklng + self.sharded_model.max_generated_tokens
+ )
+ else:
+ bsz, tklng = self.state.tokens.size()
+ self.setup_cache(
+ bsz,
+ tklng + self.sharded_model.max_generated_tokens
+ )
+
+ def infer_wrapper():
+ if DEBUG >= 4:
+ print(f"infer_wrapper called [{self.oom_cnt} OOM]")
+ print(f"self.state.tokens: {self.state.tokens}")
+ print(f"hidden_state: {hidden_state}")
+
+ model_cache = self.sharded_model.model.caches_are_enabled()
+
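+      # a single new token (last dim == 1) is appended to the running sequence;
+      # a full prompt tensor replaces it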
+ if self.state.tokens is not None:
+ if input_data.ndim == 2 and input_tensor.size(-1) == 1:
+ self.state.tokens = torch.cat([
+ self.state.tokens.to(self.device),
+ input_tensor.clone()
+ ], dim=-1).to(self.device)
+ else:
+ self.state.tokens = input_tensor.clone()
+
+ try:
+ in_tokens = self.state.tokens.clone().to(
+ device=self.device
+ )
+
+ in_input_pos = self.state.input_pos.clone().to(
+ device=self.device
+ )
+
+ in_mask = self.state.mask.clone().to(
+ device=self.device
+ )
+
+ if hidden_state is not None:
+ model_hs, model_logits = self.sharded_model.generate(
+ tokens=in_tokens,
+ hidden_state=hidden_state,
+ input_pos=in_input_pos,
+ mask=in_mask,
+ curr_pos=self.state.curr_pos
+ )
+ else:
+ if not model_cache:
+ model_hs, model_logits = self.sharded_model.generate(
+ tokens=in_tokens,
+ input_pos=in_input_pos,
+ mask=in_mask,
+ curr_pos=self.state.curr_pos
+ )
+ else:
+ model_hs, model_logits = self.sharded_model.generate(
+ tokens=input_tensor,
+ input_pos=in_input_pos,
+ mask=in_mask,
+ curr_pos=self.state.curr_pos
+ )
+ except torch.cuda.OutOfMemoryError:
+ print(f"OOM on cuda, clearing model and stopping")
+ self.oom_cnt += 1
+ self.clear_model()
+ return
+ except Exception as err:
+ print(f"infer_tensor err\n{err}")
+ raise
+
+ if model_hs is not None:
+ # numpy current no support for bf16
+ if model_hs.dtype == torch.bfloat16:
+ model_hs = model_hs.float()
+
+ if DEBUG >= 4:
+ print("sending hidden states")
+ print(f"model_hs: {model_hs.size()}")
+ print(f"state.tokens: {self.state.tokens}")
+ print(f"state.input_pos: {self.state.input_pos.size()}")
+ print(f"state.mask: {self.state.mask.size()}")
+
+ return (
+ model_hs.numpy(force=True),
+ self.state.to_dict(),
+ )
+
+ if self.state.curr_pos == 0:
+ self.state.curr_pos = self.state.tokens.size(-1)
+ else:
+ self.state.curr_pos += 1
+
+ # numpy current no support for bf16
+ if model_logits.dtype == torch.bfloat16:
+ model_logits = model_logits.float()
+
+ return (
+ model_logits[:, -1].numpy(force=True),
+ self.state.to_dict(),
+ )
+
+ return await asyncio.get_running_loop().run_in_executor(self.executor, infer_wrapper)
+
+ async def ensure_shard(self, shard: Shard):
+ if DEBUG >= 4:
+ print("shard ensured\n")
+ print(f"shard: {shard}")
+ print(f"class shard: {self.shard}")
+ print(f"uuid: {self.uuid}")
+
+    # nothing to do if this shard is already loaded (the model is reset in encode/clear_model to avoid OOM)
+ if self.shard == shard:
+ return
+
+ self.shard = shard
+
+ # Using CPU to store inference state
+ self.state = ShardInferenceState()
+
+ # download model safetensors and shard
+
+ self.model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
+ self.model_config = load_model_config(self.model_path/"config.json")
+
+ self.tokenizer = await _resolve_tokenizer(self.model_path)
+
+ def start_model():
+ if DEBUG >= 4:
+ print("start_model called")
+
+ self.sharded_model = ShardedGeneralModel(
+ config=self.model_config,
+ shard=shard,
+ device=self.device,
+ dtype=self.model_config["torch_dtype"],
+ use_cache=self.use_cache
+ )
+
+ load_model_weights_torchtune(
+ cache_dir=self.model_path,
+ shard=self.shard,
+ model=self.sharded_model,
+ num_heads=self.model_config["num_heads"],
+ num_kv_heads=self.model_config["num_kv_heads"],
+ dim=self.model_config["embed_dim"],
+ head_dim=self.model_config["head_dim"]
+ )
+
+ await asyncio.get_running_loop().run_in_executor(
+ self.executor,
+ functools.partial(start_model),
+ )
+
+ async def load_checkpoint(self, shard: Shard, path: str):
+ await self.ensure_shard(shard)
diff --git a/exo/inference/torch/tests/__init__.py b/exo/inference/torch/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/exo/inference/torch/tests/test_inference_engine.py b/exo/inference/torch/tests/test_inference_engine.py
new file mode 100644
index 000000000..cf1f06179
--- /dev/null
+++ b/exo/inference/torch/tests/test_inference_engine.py
@@ -0,0 +1,54 @@
+"""
+Test inference engine and model sharding
+"""
+import pytest
+import asyncio
+
+from exo.inference.shard import Shard
+from exo.inference.torch.sharded_inference_engine import TorchDynamicShardInferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+
+import numpy as np
+
+@pytest.mark.asyncio
+async def test_inference_engine():
+ prompt = "In a single word only, what is the last name of the current president of the USA?"
+
+ shard = Shard(
+ model_id="llama-3.2-1b",
+ start_layer=0,
+ end_layer=8,
+ n_layers=16
+ )
+
+ shard_2 = Shard(
+ model_id="llama-3.2-1b",
+ start_layer=9,
+ end_layer=15,
+    n_layers=16
+ )
+
+ inference_engine = TorchDynamicShardInferenceEngine(HFShardDownloader())
+
+ output_1 = await inference_engine.infer_prompt("test_id", shard, prompt)
+ print("\n------------inference_engine.infer_prompt output---------------\n")
+ print(output_1)
+ print("\n---------------------------\n")
+
+ assert isinstance(output_1, np.ndarray), "Output should be numpy array"
+
+ output_2 = await inference_engine.infer_tensor("test_id", shard, output_1)
+ print("\n------------inference_engine.infer_tensor output---------------\n")
+ print(output_2)
+ print("\n---------------------------\n")
+
+ assert isinstance(output_2, np.ndarray), "Output should be numpy array"
+
+if __name__ == '__main__':
+ try:
+ print("\n\n -------- TEST llama-3.2-1b -------- \n\n")
+ asyncio.run(test_inference_engine())
+ except Exception as err:
+ print(f"\n!!!! TEST FAILED \n{err}\n")
+
+
diff --git a/exo/inference/torch/tests/test_llama3_full.py b/exo/inference/torch/tests/test_llama3_full.py
new file mode 100644
index 000000000..ff1b62327
--- /dev/null
+++ b/exo/inference/torch/tests/test_llama3_full.py
@@ -0,0 +1,235 @@
+"""
+Test of pytorch based llama3 models
+full layer run
+"""
+
+from pathlib import Path
+import torch
+from huggingface_hub import snapshot_download
+
+import torchtune.generation as ttg
+from torchtune.models import llama3
+from torchtune.data import Message
+
+from transformers import AutoTokenizer
+
+from exo.inference.torch.models.general_mha import ShardedGeneralModel
+from exo.inference.shard import Shard
+
+from exo.inference.torch.models.llm_utils import (
+ load_model_config,
+ load_weights_torch,
+ load_model_weights_torchtune
+)
+
+MODEL_NAME = "unsloth/Llama-3.2-1B-Instruct"
+# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+TEMP = 0.85
+TOP_K = 35
+MAX_NEW_TOKENS = 200
+RAND_SEED = 42
+
+
+def main(model, prompt: str, device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.bfloat16):
+ messages = [{
+ "role": "assistant",
+ "content": "",
+ }, {
+ "role": "user",
+ "content": prompt,
+ }]
+
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ tok_out = tokenizer([text], return_tensors="pt")
+ print(f"tok_out: {tok_out}")
+ tokens = tok_out.input_ids.to(device=device, dtype=torch.int)
+
+ rng = torch.Generator(device=device)
+ rng.manual_seed(RAND_SEED)
+
+ generated_tokens = tokens.clone()
+
+ print(f"tokens: {tokens}")
+
+ bsz, tokens_length = tokens.size()
+
+  # using self.max_seq_len will take up a lot of VRAM
+ total_response_length = tokens_length + MAX_NEW_TOKENS
+
+ # setup cache
+ if not model.model.caches_are_enabled():
+ with device:
+ model.model.setup_caches(
+ bsz,
+ dtype,
+ decoder_max_seq_len=total_response_length
+ )
+
+ if not model.model.caches_are_enabled():
+ max_seq_len = total_response_length
+ else:
+ max_seq_len = model.model.decoder_max_cache_seq_len
+
+ # masking for proper attention
+
+ # select correct pad_id
+ if hasattr(tokenizer, "pad_id"):
+ pad_id = tokenizer.pad_id
+ elif hasattr(tokenizer, "pad_token_id"):
+ print(f"pad_token_id: {tokenizer.pad_token_id}")
+ if tokenizer.pad_token_id is not None:
+ pad_id = tokenizer.pad_token_id
+ else:
+ pad_id = 0
+ else:
+ pad_id = 0
+
+ print(f"pad_id: {pad_id}")
+
+ padding_masks = tokens != pad_id
+ if not padding_masks.all():
+ padding_masks = torch.nn.functional.pad(
+ padding_masks,
+ (0, MAX_NEW_TOKENS),
+ value=True,
+ )
+
+ mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len)
+
+ input_pos = ttg.get_position_ids_from_padding_mask(padding_masks)
+ else:
+ mask = torch.tril(torch.ones(
+ total_response_length,
+ max_seq_len,
+ dtype=torch.bool,
+ device=device,
+ )).unsqueeze(0)
+
+ input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0)
+
+ print(f"init mask: {mask}")
+ print(f"init input_pos: {input_pos}")
+
+ curr_pos = 0
+
+ _, logits = model.generate(
+ tokens=tokens,
+ mask=mask,
+ input_pos=input_pos,
+ curr_pos=curr_pos
+ )
+
+ curr_pos = tokens_length
+
+ q = torch.empty((
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q
+ )
+
+ print(f"tokens: {tokens}")
+
+ for i in range(MAX_NEW_TOKENS - 1):
+ print(f"gen #{i+1}")
+
+ if tokens.item() == tokenizer.eos_token_id:
+ print("stop token hit!")
+ break
+
+ tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens
+
+ _, logits = model.generate(
+ tokens=tokens,
+ input_pos=input_pos,
+ mask=mask,
+ curr_pos=curr_pos
+ )
+
+ curr_pos += 1
+
+ q = torch.empty(
+ (
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q,
+ )
+
+ print(f"tokens: {tokens}")
+
+ generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+ print(f"generated_tokens: {generated_tokens}")
+
+ if not model.model.caches_are_enabled():
+ tokens = generated_tokens.clone()
+
+ print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n")
+
+
+if __name__ == "__main__":
+ # prompt = "Hello, how are you?"
+ prompt = "Tell me a joke."
+ # prompt = "What is the meaning of exo?"
+ # prompt = "Tell me a short 4 line haiku"
+ # prompt = "In a single word only, what is the last name of the current president of the USA?"
+
+ # Get the path to the model files from the Hugging Face cache
+ cache_dir = Path(snapshot_download(MODEL_NAME))
+ print(f"Cache directory: {cache_dir}")
+
+ # Load model configuration
+ config = load_model_config(cache_dir/"config.json")
+
+ print(f"current config\n{config}")
+
+ # Setup shard
+ n_layers = int(config["num_layers"])
+ shard_1 = Shard(
+ model_id=MODEL_NAME,
+ start_layer=0,
+ end_layer=n_layers - 1,
+ n_layers=n_layers,
+ )
+
+ # Initialize tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+
+  # Initialize the sharded model with config and tokenizer
+ # device = torch.device("cuda")
+ dtype = torch.bfloat16
+ device = torch.device("cpu")
+ shard_model_1 = ShardedGeneralModel(
+ config=config,
+ shard=shard_1,
+ device=device,
+ dtype=config["torch_dtype"],
+ use_cache=True,
+ max_generated_tokens=MAX_NEW_TOKENS,
+ )
+
+ print(f"\nshard_model_1: {shard_model_1}")
+
+ # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1)
+ load_model_weights_torchtune(
+ cache_dir=cache_dir,
+ shard=shard_1,
+ model=shard_model_1,
+ num_heads=config["num_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ dim=config["embed_dim"],
+ head_dim=config["head_dim"]
+ )
+
+  main(shard_model_1, prompt, device, config["torch_dtype"])
diff --git a/exo/inference/torch/tests/test_mistral_full.py b/exo/inference/torch/tests/test_mistral_full.py
new file mode 100644
index 000000000..54c0c8dee
--- /dev/null
+++ b/exo/inference/torch/tests/test_mistral_full.py
@@ -0,0 +1,236 @@
+"""
+Test of pytorch based mistral models
+full layer run
+"""
+
+from pathlib import Path
+import torch
+from huggingface_hub import snapshot_download
+
+import torchtune.generation as ttg
+
+from transformers import AutoTokenizer
+
+from exo.inference.torch.models.general_mha import ShardedGeneralModel
+from exo.inference.shard import Shard
+
+from exo.inference.torch.models.llm_utils import (
+ load_model_config,
+ load_model_weights_torchtune
+)
+
+MODEL_NAME = "unsloth/mistral-7b-instruct-v0.3-bnb-4bit"
+TEMP = 0.85
+TOP_K = 35
+MAX_NEW_TOKENS = 200
+RAND_SEED = 42
+
+
+def main(
+ model,
+ prompt: str,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.bfloat16
+):
+ messages = [{
+ "role": "assistant",
+ "content": "",
+ }, {
+ "role": "user",
+ "content": prompt,
+ }]
+
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ tok_out = tokenizer([text], return_tensors="pt")
+ print(f"tok_out: {tok_out}")
+ tokens = tok_out.input_ids.to(device=device, dtype=torch.int)
+
+ rng = torch.Generator(device=device)
+ rng.manual_seed(RAND_SEED)
+
+ generated_tokens = tokens.clone()
+
+ print(f"tokens: {tokens}")
+
+ bsz, tokens_length = tokens.size()
+
+  # using self.max_seq_len will take up a lot of VRAM
+ total_response_length = tokens_length + MAX_NEW_TOKENS
+
+ # setup cache
+ if not model.model.caches_are_enabled():
+ with device:
+ model.model.setup_caches(
+ bsz,
+ dtype,
+ decoder_max_seq_len=total_response_length
+ )
+
+ if not model.model.caches_are_enabled():
+ max_seq_len = total_response_length
+ else:
+ max_seq_len = model.model.decoder_max_cache_seq_len
+
+ # masking for proper attention
+
+ # select correct pad_id
+ if hasattr(tokenizer, "pad_id"):
+ pad_id = tokenizer.pad_id
+ elif hasattr(tokenizer, "pad_token_id"):
+ print(f"pad_token_id: {tokenizer.pad_token_id}")
+ if tokenizer.pad_token_id is not None:
+ pad_id = tokenizer.pad_token_id
+ else:
+ pad_id = 0
+ else:
+ pad_id = 0
+
+ print(f"pad_id: {pad_id}")
+
+ padding_masks = tokens != pad_id
+ if not padding_masks.all():
+ padding_masks = torch.nn.functional.pad(
+ padding_masks,
+ (0, MAX_NEW_TOKENS),
+ value=True,
+ )
+
+ mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len)
+
+ input_pos = ttg.get_position_ids_from_padding_mask(padding_masks)
+ else:
+ mask = torch.tril(torch.ones(
+ total_response_length,
+ max_seq_len,
+ dtype=torch.bool,
+ device=device,
+ )).unsqueeze(0)
+
+ input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0)
+
+ print(f"init mask: {mask}")
+ print(f"init input_pos: {input_pos}")
+
+ curr_pos = 0
+
+ _, logits = model.generate(
+ tokens=tokens,
+ mask=mask,
+ input_pos=input_pos,
+ curr_pos=curr_pos
+ )
+
+ curr_pos = tokens_length
+
+ q = torch.empty((
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q
+ )
+
+ print(f"tokens: {tokens}")
+
+ for i in range(MAX_NEW_TOKENS - 1):
+ print(f"gen #{i+1}")
+
+ if tokens.item() == tokenizer.eos_token_id:
+ print("stop token hit!")
+ break
+
+ tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens
+
+ _, logits = model.generate(
+ tokens=tokens,
+ input_pos=input_pos,
+ mask=mask,
+ curr_pos=curr_pos
+ )
+
+ curr_pos += 1
+
+ q = torch.empty(
+ (
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q,
+ )
+
+ print(f"tokens: {tokens}")
+
+ generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+ print(f"generated_tokens: {generated_tokens}")
+
+ if not model.model.caches_are_enabled():
+ tokens = generated_tokens.clone()
+
+ print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n")
+
+
+if __name__ == "__main__":
+ # prompt = "Hello, how are you?"
+ prompt = "Tell me a joke."
+ # prompt = "What is the meaning of exo?"
+ # prompt = "Tell me a short 4 line haiku"
+ # prompt = "In a single word only, what is the last name of the current president of the USA?"
+
+ # Get the path to the model files from the Hugging Face cache
+ cache_dir = Path(snapshot_download(MODEL_NAME))
+ print(f"Cache directory: {cache_dir}")
+
+ # Load model configuration
+ config = load_model_config(cache_dir/"config.json")
+
+ print(f"current config\n{config}")
+
+ # Setup shard
+ n_layers = int(config["num_layers"])
+ shard_1 = Shard(
+ model_id=MODEL_NAME,
+ start_layer=0,
+ end_layer=n_layers - 1,
+ n_layers=n_layers,
+ )
+
+ # Initialize tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+
+  # Initialize the sharded model with config and tokenizer
+  # device = torch.device("cuda")
+ dtype = torch.bfloat16
+ device = torch.device("cpu")
+
+ shard_model_1 = ShardedGeneralModel(
+ config=config,
+ shard=shard_1,
+ device=device,
+ dtype=config["torch_dtype"],
+ use_cache=True,
+ max_generated_tokens=MAX_NEW_TOKENS,
+ )
+
+ print(f"\nshard_model_1: {shard_model_1}")
+
+ # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1)
+ load_model_weights_torchtune(
+ cache_dir=cache_dir,
+ shard=shard_1,
+ model=shard_model_1,
+ num_heads=config["num_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ dim=config["embed_dim"],
+ head_dim=config["head_dim"]
+ )
+
+ main(shard_model_1, prompt, device, config["torch_dtype"])
diff --git a/exo/inference/torch/tests/test_qwen_full.py b/exo/inference/torch/tests/test_qwen_full.py
new file mode 100644
index 000000000..2e4f55d12
--- /dev/null
+++ b/exo/inference/torch/tests/test_qwen_full.py
@@ -0,0 +1,236 @@
+"""
+Test of pytorch based qwen2 models
+full layer run
+"""
+
+from pathlib import Path
+import torch
+from huggingface_hub import snapshot_download
+
+import torchtune.generation as ttg
+
+from transformers import AutoTokenizer
+
+from exo.inference.torch.models.general_mha import ShardedGeneralModel
+from exo.inference.shard import Shard
+
+from exo.inference.torch.models.llm_utils import (
+ load_model_config,
+ load_model_weights_torchtune
+)
+
+MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
+TEMP = 0.85
+TOP_K = 35
+MAX_NEW_TOKENS = 200
+RAND_SEED = 42
+
+
+def main(
+ model,
+ prompt: str,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.bfloat16
+):
+ messages = [{
+ "role": "assistant",
+ "content": "",
+ }, {
+ "role": "user",
+ "content": prompt,
+ }]
+
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ tok_out = tokenizer([text], return_tensors="pt")
+ print(f"tok_out: {tok_out}")
+ tokens = tok_out.input_ids.to(device=device, dtype=torch.int)
+
+ rng = torch.Generator(device=device)
+ rng.manual_seed(RAND_SEED)
+
+ generated_tokens = tokens.clone()
+
+ print(f"tokens: {tokens}")
+
+ bsz, tokens_length = tokens.size()
+
+  # using self.max_seq_len will take up a lot of VRAM
+ total_response_length = tokens_length + MAX_NEW_TOKENS
+
+ # setup cache
+ if not model.model.caches_are_enabled():
+ with device:
+ model.model.setup_caches(
+ bsz,
+ dtype,
+ decoder_max_seq_len=total_response_length
+ )
+
+ if not model.model.caches_are_enabled():
+ max_seq_len = total_response_length
+ else:
+ max_seq_len = model.model.decoder_max_cache_seq_len
+
+ # masking for proper attention
+
+ # select correct pad_id
+ if hasattr(tokenizer, "pad_id"):
+ pad_id = tokenizer.pad_id
+ elif hasattr(tokenizer, "pad_token_id"):
+ print(f"pad_token_id: {tokenizer.pad_token_id}")
+ if tokenizer.pad_token_id is not None:
+ pad_id = tokenizer.pad_token_id
+ else:
+ pad_id = 0
+ else:
+ pad_id = 0
+
+ print(f"pad_id: {pad_id}")
+
+ padding_masks = tokens != pad_id
+ if not padding_masks.all():
+ padding_masks = torch.nn.functional.pad(
+ padding_masks,
+ (0, MAX_NEW_TOKENS),
+ value=True,
+ )
+
+ mask = ttg.get_causal_mask_from_padding_mask(padding_masks, target_seq_len=max_seq_len)
+
+ input_pos = ttg.get_position_ids_from_padding_mask(padding_masks)
+ else:
+ mask = torch.tril(torch.ones(
+ total_response_length,
+ max_seq_len,
+ dtype=torch.bool,
+ device=device,
+ )).unsqueeze(0)
+
+ input_pos = torch.arange(0, total_response_length, device=device).unsqueeze(0)
+
+ print(f"init mask: {mask}")
+ print(f"init input_pos: {input_pos}")
+
+ curr_pos = 0
+
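+ # prefill: run the whole prompt through the model once; the logits at the last position seed the first sampled token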
+ _, logits = model.generate(
+ tokens=tokens,
+ mask=mask,
+ input_pos=input_pos,
+ curr_pos=curr_pos
+ )
+
+ curr_pos = tokens_length
+
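+ # pre-drawn exponential noise lets ttg.sample pick a token reproducibly with the seeded generator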
+ q = torch.empty((
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q
+ )
+
+ print(f"tokens: {tokens}")
+
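+ # decode loop: sample one token at a time, feeding it back in until EOS or the MAX_NEW_TOKENS budget is hit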
+ for i in range(MAX_NEW_TOKENS - 1):
+ print(f"gen #{i+1}")
+
+ if tokens.item() == tokenizer.eos_token_id:
+ print("stop token hit!")
+ break
+
+ tokens = tokens.view(1, -1).to(device=device) if tokens.ndim == 1 else tokens
+
+ _, logits = model.generate(
+ tokens=tokens,
+ input_pos=input_pos,
+ mask=mask,
+ curr_pos=curr_pos
+ )
+
+ curr_pos += 1
+
+ q = torch.empty(
+ (
+ logits.size(0),
+ model.model.tok_embeddings.num_embeddings
+ ), device=logits.device).exponential_(1, generator=rng)
+
+ tokens = ttg.sample(
+ logits=logits[:, -1].clone(),
+ temperature=TEMP,
+ top_k=TOP_K,
+ q=q,
+ )
+
+ print(f"tokens: {tokens}")
+
+ generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
+ print(f"generated_tokens: {generated_tokens}")
+
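+ # without KV caches the model has no memory of earlier steps, so feed the full generated sequence back in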
+ if not model.model.caches_are_enabled():
+ tokens = generated_tokens.clone()
+
+ print(f"\n\n[resp from model]\n\n{tokenizer.decode(generated_tokens.tolist()[0])}\n\n\n")
+
+
+if __name__ == "__main__":
+ # prompt = "Hello, how are you?"
+ prompt = "Tell me a joke."
+ # prompt = "What is the meaning of exo?"
+ # prompt = "Tell me a short 4 line haiku"
+ # prompt = "In a single word only, what is the last name of the current president of the USA?"
+
+ # Get the path to the model files from the Hugging Face cache
+ cache_dir = Path(snapshot_download(MODEL_NAME))
+ print(f"Cache directory: {cache_dir}")
+
+ # Load model configuration
+ config = load_model_config(cache_dir/"config.json")
+
+ print(f"current config\n{config}")
+
+ # Setup shard
+ n_layers = int(config["num_layers"])
+ shard_1 = Shard(
+ model_id=MODEL_NAME,
+ start_layer=0,
+ end_layer=n_layers - 1,
+ n_layers=n_layers,
+ )
+
+ # Initialize tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+
+ # Initialize ShardedGeneralModel with config and tokenizer
+ # device = torch.device("cuda")
+ dtype = torch.bfloat16
+ device = torch.device("cpu")
+
+ shard_model_1 = ShardedGeneralModel(
+ config=config,
+ shard=shard_1,
+ device=device,
+ dtype=config["torch_dtype"],
+ use_cache=True,
+ max_generated_tokens=MAX_NEW_TOKENS,
+ )
+
+ print(f"\nshard_model_1: {shard_model_1}")
+
+ # load_model_weights_torchtune(cache_dir, shard_1, shard_model_1)
+ load_model_weights_torchtune(
+ cache_dir=cache_dir,
+ shard=shard_1,
+ model=shard_model_1,
+ num_heads=config["num_heads"],
+ num_kv_heads=config["num_kv_heads"],
+ dim=config["embed_dim"],
+ head_dim=config["head_dim"]
+ )
+
+ main(shard_model_1, prompt, device, config["torch_dtype"])
diff --git a/exo/main.py b/exo/main.py
index 51abe858e..38094366e 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -81,7 +81,7 @@ def configure_uvloop():
parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
-parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
+parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (torch, mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
@@ -103,6 +103,8 @@ def configure_uvloop():
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")
+os.environ["EXO_INFER_ENGINE"] = inference_engine_name
+
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
@@ -149,7 +151,9 @@ def configure_uvloop():
elif args.discovery_module == "manual":
if not args.discovery_config_path:
raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
- discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
+ discovery = ManualDiscovery(
+ args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)
+ )
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
node = Node(
args.node_id,
@@ -250,23 +254,26 @@ def on_token(_request_id, _tokens, _is_finished):
finally:
node.on_token.deregister(callback_id)
+
def clean_path(path):
- """Clean and resolve path"""
- if path.startswith("Optional("):
- path = path.strip('Optional("').rstrip('")')
- return os.path.expanduser(path)
+ """Clean and resolve path"""
+ if path.startswith("Optional("):
+ path = path.strip('Optional("').rstrip('")')
+ return os.path.expanduser(path)
+
async def hold_outstanding(node: Node):
while node.outstanding_requests:
await asyncio.sleep(.5)
return
+
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
losses = []
tokens = []
for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size):
_, _, lengths = batch
- losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train)))
+ losses.append(np.sum(lengths*await node.enqueue_example(shard, *batch, train=train)))
tokens.append(np.sum(lengths))
total_tokens = np.sum(tokens)
total_loss = np.sum(losses) / total_tokens
@@ -333,7 +340,7 @@ async def main():
def restore_cursor():
if platform.system() != "Windows":
- os.system("tput cnorm") # Show cursor
+ os.system("tput cnorm") # Show cursor
# Restore the cursor when the program exits
atexit.register(restore_cursor)
@@ -356,8 +363,7 @@ def handle_exit():
await run_model_cli(node, model_name, args.prompt)
elif args.command == "eval" or args.command == 'train':
model_name = args.model_name
- dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
- , loadline=lambda line: json.loads(line).get("text",""))
+ dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item), loadline=lambda line: json.loads(line).get("text", ""))
if args.command == 'eval':
if not model_name:
print("Error: Much like a human, I can't evaluate anything without a model")
diff --git a/exo/models.py b/exo/models.py
index b64bcca10..6b69cdab2 100644
--- a/exo/models.py
+++ b/exo/models.py
@@ -6,8 +6,8 @@
"llama-3.3-70b": {
"layers": 80,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct",
},
},
"llama-3.2-1b": {
@@ -15,6 +15,7 @@
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct"
},
},
"llama-3.2-1b-8bit": {
@@ -22,69 +23,93 @@
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct"
},
},
"llama-3.2-3b": {
"layers": 28,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct"
},
},
"llama-3.2-3b-8bit": {
"layers": 28,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit",
- "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
},
},
"llama-3.2-3b-bf16": {
"layers": 28,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
- "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
+ "TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
},
},
"llama-3.1-8b": {
"layers": 32,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
+ "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-8B-Instruct",
},
},
"llama-3.1-70b": {
"layers": 80,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-70B-Instruct",
},
},
"llama-3.1-70b-bf16": {
"layers": 80,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
- "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
+ "TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
+ "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-70B-Instruct",
},
},
"llama-3-8b": {
"layers": 32,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
+ "TorchDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
},
},
"llama-3-70b": {
"layers": 80,
"repo": {
- "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
- "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
+ "TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
+ "TorchDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
},
},
- "llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
- "llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
+ "llama-3.1-405b": {
+ "layers": 126,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit",
+ "TorchDynamicShardInferenceEngine": "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit",
+ },
+ },
+ "llama-3.1-405b-8bit": {
+ "layers": 126,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit",},
+ },
### mistral
- "mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
- "mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
+ "mistral-nemo": {
+ "layers": 40,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit",},
+ },
+ "mistral-large": {
+ "layers": 88,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit",},
+ },
### deepseek
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
@@ -123,35 +148,143 @@
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
### llava
- "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
+ "llava-1.5-7b-hf": {
+ "layers": 32,
+ "repo": {"MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf",},
+ },
### qwen
- "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
- "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
- "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
- "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
- "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
- "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
- "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
- "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
- "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
- "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
- "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
- "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
- "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
- "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
+ "qwen-2.5-0.5b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-0.5B-Instruct"
+ },
+ },
+ "qwen-2.5-1.5b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-1.5B-Instruct"
+ },
+ },
+ "qwen-2.5-coder-1.5b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-1.5B-Instruct"
+ },
+ },
+ "qwen-2.5-3b": {
+ "layers": 36,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-3B-Instruct"
+ },
+ },
+ "qwen-2.5-coder-3b": {
+ "layers": 36,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-3B-Instruct"
+ },
+ },
+ "qwen-2.5-7b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-7B-Instruct"
+ },
+ },
+ "qwen-2.5-coder-7b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-7B-Instruct"
+ },
+ },
+ "qwen-2.5-math-7b": {
+ "layers": 28,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Math-7B-Instruct"
+ },
+ },
+ "qwen-2.5-14b": {
+ "layers": 48,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-14B-Instruct"
+ },
+ },
+ "qwen-2.5-coder-14b": {
+ "layers": 48,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-14B-Instruct"
+ },
+ },
+ "qwen-2.5-32b": {
+ "layers": 64,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-32B-Instruct"
+ },
+ },
+ "qwen-2.5-coder-32b": {
+ "layers": 64,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Coder-32B-Instruct"
+ },
+ },
+ "qwen-2.5-72b": {
+ "layers": 80,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-72B-Instruct"
+ },
+ },
+ "qwen-2.5-math-72b": {
+ "layers": 80,
+ "repo": {
+ "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit",
+ "TorchDynamicShardInferenceEngine": "Qwen/Qwen2.5-Math-72B-Instruct"
+ },
+ },
### nemotron
- "nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
- "nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
+ "nemotron-70b": {
+ "layers": 80,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit",},
+ },
+ "nemotron-70b-bf16": {
+ "layers": 80,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16",},
+ },
# gemma
- "gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
- "gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
+ "gemma2-9b": {
+ "layers": 42,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit",},
+ },
+ "gemma2-27b": {
+ "layers": 46,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit",},
+ },
# stable diffusion
- "stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
+ "stable-diffusion-2-1-base": {"layers": 31, "repo": {"MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base"}},
# phi
- "phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
- "phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
+ "phi-3.5-mini": {
+ "layers": 32,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit",},
+ },
+ "phi-4": {
+ "layers": 40,
+ "repo": {"MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit",},
+ },
# dummy
- "dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
+ "dummy": {
+ "layers": 8,
+ "repo": {"DummyInferenceEngine": "dummy",},
+ },
}
pretty_name = {
@@ -255,19 +388,13 @@ def get_supported_models(supported_inference_engine_lists: Optional[List[List[st
return list(model_cards.keys())
from exo.inference.inference_engine import inference_engine_classes
- supported_inference_engine_lists = [
- [inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
- for engine_list in supported_inference_engine_lists
- ]
+ supported_inference_engine_lists = [[inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
+ for engine_list in supported_inference_engine_lists]
def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
return any(engine in model_info.get("repo", {}) for engine in engine_list)
def supports_all_engine_lists(model_info: dict) -> bool:
- return all(has_any_engine(model_info, engine_list)
- for engine_list in supported_inference_engine_lists)
+ return all(has_any_engine(model_info, engine_list) for engine_list in supported_inference_engine_lists)
- return [
- model_id for model_id, model_info in model_cards.items()
- if supports_all_engine_lists(model_info)
- ]
+ return [model_id for model_id, model_info in model_cards.items() if supports_all_engine_lists(model_info)]
diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py
index 58ef7e71b..e7999637e 100644
--- a/exo/networking/grpc/grpc_peer_handle.py
+++ b/exo/networking/grpc/grpc_peer_handle.py
@@ -16,8 +16,10 @@
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
+ IS_APPLE = True
else:
import numpy as mx
+ IS_APPLE = False
class GRPCPeerHandle(PeerHandle):
@@ -123,6 +125,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
+
response = await self.stub.SendTensor(request)
if not response.tensor_data or not response.shape or not response.dtype:
@@ -200,11 +203,12 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
- if isinstance(v, mx.array):
+ mx_array_type = mx.array if IS_APPLE else mx.ndarray
+ if isinstance(v, mx_array_type):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
- elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
+ elif isinstance(v, list) and all(isinstance(item, mx_array_type) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 1281aa8ae..e86563702 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -1,3 +1,4 @@
+import os
import numpy as np
import json
import asyncio
@@ -17,6 +18,7 @@
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.download.shard_download import ShardDownloader
+
class Node:
def __init__(
self,
@@ -44,7 +46,7 @@ def __init__(
self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
self.buffered_partials: Dict[str, List[np.ndarray]] = {}
self.checkpoints: Dict[str, Dict[str, int]] = {}
-
+
self.max_generate_tokens = max_generate_tokens
self.topology_viz = topology_viz
self.default_sample_temperature = default_sample_temperature
@@ -101,7 +103,8 @@ def get_supported_inference_engines(self):
supported_engine_names.append('mlx')
supported_engine_names.append('tinygrad')
else:
- supported_engine_names.append('tinygrad')
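+ # fall back to the engine chosen at startup (EXO_INFER_ENGINE) rather than always assuming tinygrad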
+ supported_engine_names.append(os.environ.get("EXO_INFER_ENGINE", 'tinygrad'))
+
return supported_engine_names
async def broadcast_supported_engines(self, supported_engines_names: List[str]):
@@ -149,10 +152,9 @@ async def process_inference_result(
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
- asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
-
- return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
+ asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset=1), inference_state))
+ return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
async def process_prompt(
self,
@@ -219,7 +221,7 @@ async def enqueue_example(
self,
base_shard: Shard,
example: np.ndarray,
- target: np.ndarray,
+ target: np.ndarray,
length: np.ndarray,
request_id: Optional[str] = None,
train: bool = False,
@@ -232,7 +234,7 @@ async def enqueue_example(
if request_id is None:
request_id = str(uuid.uuid4())
self.outstanding_requests[request_id] = "waiting"
- loss = await self.forward_example(shard, example, target, length, train, request_id, 0)
+ loss = await self.forward_example(shard, example, target, length, train, request_id, 0)
return loss
async def coordinate_save(
@@ -263,7 +265,7 @@ async def process_example(
self,
base_shard: Shard,
example: np.ndarray,
- target: np.ndarray,
+ target: np.ndarray,
length: np.ndarray,
train: bool = False,
request_id: Optional[str] = None,
@@ -308,7 +310,7 @@ async def _process_example(
self,
base_shard: Shard,
example: np.ndarray,
- target: np.ndarray,
+ target: np.ndarray,
length: np.ndarray,
train: bool = False,
request_id: Optional[str] = None,
@@ -327,7 +329,7 @@ async def _process_example(
self.outstanding_requests[request_id] = "preprocessing"
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
- loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+ loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1))
self.outstanding_requests[request_id] = "training"
partial_loss, grad = await self.inference_engine.train(request_id, shard, example, backgrad, length, loss="back_gradient")
self.outstanding_requests.pop(request_id)
@@ -343,7 +345,7 @@ async def _process_example(
self.outstanding_requests[request_id] = "preprocessing"
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
- loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
+ loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset=1))
self.outstanding_requests.pop(request_id)
return loss
except Exception as e:
@@ -351,7 +353,7 @@ async def _process_example(
print(f"Error processing example for shard {shard}: {e}")
traceback.print_exc()
return None
-
+
async def process_tensor(
self,
base_shard: Shard,
@@ -380,7 +382,7 @@ async def _process_tensor(
try:
self.outstanding_requests[request_id] = "processing"
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
- ret = await self.process_inference_result(shard, result, request_id, inference_state)
+ ret = await self.process_inference_result(shard, result, request_id, inference_state)
return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
@@ -428,7 +430,7 @@ async def forward_prompt(
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
-
+
async def forward_tensor(
self,
base_shard: Shard,
@@ -458,7 +460,7 @@ def get_partition_index(self, offset: int = 0):
current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
if current_partition_index is None:
raise ValueError(f"No current partition found for node: {self.id}")
- return (current_partition_index + offset) % len(partitions)
+ return (current_partition_index+offset) % len(partitions)
def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
if index is None:
@@ -584,7 +586,7 @@ def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
self.on_token.trigger_all(request_id, tokens, is_finished)
-
+
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}")
async def send_result_to_peer(peer):
@@ -620,8 +622,8 @@ def current_topology(self) -> Topology:
def handle_stable_diffusion(self, inference_state, result):
if inference_state['is_step_finished']:
- inference_state['step']+=1
- progress = [inference_state['step'],inference_state['total_steps']]
+ inference_state['step'] += 1
+ progress = [inference_state['step'], inference_state['total_steps']]
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py
new file mode 100644
index 000000000..3a612a365
--- /dev/null
+++ b/exo/orchestration/standard_node.py
@@ -0,0 +1,469 @@
+import numpy as np
+import json
+import asyncio
+import uuid
+import time
+import traceback
+from typing import List, Dict, Optional, Tuple, Union, Set
+from exo.networking import Discovery, PeerHandle, Server
+from exo.inference.inference_engine import InferenceEngine, Shard
+from .node import Node
+from exo.topology.topology import Topology
+from exo.topology.device_capabilities import device_capabilities
+from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
+from exo import DEBUG
+from exo.helpers import AsyncCallbackSystem
+from exo.viz.topology_viz import TopologyViz
+from exo.download.hf.hf_helpers import RepoProgressEvent
+from exo.inference.inference_engine import get_inference_engine, InferenceEngine
+from exo.download.hf.hf_shard_download import HFShardDownloader
+
+class StandardNode(Node):
+ def __init__(
+ self,
+ _id: str,
+ server: Server,
+ inference_engine: InferenceEngine,
+ discovery: Discovery,
+ partitioning_strategy: PartitioningStrategy = None,
+ max_generate_tokens: int = 1024,
+ default_sample_temperature: float = 0.0,
+ topology_viz: Optional[TopologyViz] = None,
+ shard_downloader: Optional[HFShardDownloader] = None,
+ ):
+ self.id = _id
+ self.inference_engine = inference_engine
+ self.server = server
+ self.discovery = discovery
+ self.partitioning_strategy = partitioning_strategy
+ self.peers: List[PeerHandle] = []
+ self.topology: Topology = Topology()
+ self.device_capabilities = device_capabilities()
+ self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
+ self.buffered_logits: Dict[str, List[np.ndarray]] = {}
+ self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
+ self.max_generate_tokens = max_generate_tokens
+ self.topology_viz = topology_viz
+ self.default_sample_temperature = default_sample_temperature
+ self._on_token = AsyncCallbackSystem[str, Tuple[str, List[int], bool]]()
+ self._on_opaque_status = AsyncCallbackSystem[str, Tuple[str, str]]()
+ self._on_opaque_status.register("node_status").on_next(self.on_node_status)
+ self.node_download_progress: Dict[str, RepoProgressEvent] = {}
+ self.topology_inference_engines_pool: List[List[str]] = []
+ self.shard_downloader = shard_downloader
+
+ async def start(self, wait_for_peers: int = 0) -> None:
+ await self.server.start()
+ await self.discovery.start()
+ await self.update_peers(wait_for_peers)
+ await self.collect_topology()
+ if DEBUG >= 2: print(f"Collected topology: {self.topology}")
+ asyncio.create_task(self.periodic_topology_collection(1.0))
+
+ async def stop(self) -> None:
+ await self.discovery.stop()
+ await self.server.stop()
+
+ def on_node_status(self, request_id, opaque_status):
+ try:
+ status_data = json.loads(opaque_status)
+ if status_data.get("type", "") == "supported_inference_engines":
+ node_id = status_data.get("node_id")
+ engines = status_data.get("engines", [])
+ self.topology_inference_engines_pool.append(engines)
+ if status_data.get("type", "") == "node_status":
+ if status_data.get("status", "").startswith("start_"):
+ self.current_topology.active_node_id = status_data.get("node_id")
+ elif status_data.get("status", "").startswith("end_"):
+ if status_data.get("node_id") == self.current_topology.active_node_id:
+ self.current_topology.active_node_id = None
+ download_progress = None
+ if status_data.get("type", "") == "download_progress":
+ if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
+ download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
+ self.node_download_progress[status_data.get('node_id')] = download_progress
+ if self.topology_viz:
+ self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id, self.node_download_progress)
+ except Exception as e:
+ if DEBUG >= 1: print(f"Error updating visualization: {e}")
+ if DEBUG >= 1: traceback.print_exc()
+
+ def get_supported_inference_engines(self):
+ supported_engine_names = []
+ if self.inference_engine.__class__.__name__ == 'MLXDynamicShardInferenceEngine':
+ supported_engine_names.append('mlx')
+ supported_engine_names.append('tinygrad')
+ elif self.inference_engine.__class__.__name__ == 'TorchDynamicShardInferenceEngine':
+ supported_engine_names.append('torch')
+ supported_engine_names.append('tinygrad')
+ else:
+ supported_engine_names.append('tinygrad')
+ return supported_engine_names
+
+ async def broadcast_supported_engines(self, supported_engines_names: List[str]):
+ status_message = json.dumps({"type": "supported_inference_engines", "node_id": self.id, "engines": supported_engines_names})
+ await self.broadcast_opaque_status("", status_message)
+
+ def get_topology_inference_engines(self) -> List[List[str]]:
+ return self.topology_inference_engines_pool
+
+ async def process_inference_result(
+ self,
+ shard,
+ result: np.ndarray,
+ request_id: Optional[str] = None,
+ ):
+ if request_id not in self.buffered_token_output:
+ self.buffered_token_output[request_id] = ([], False)
+ is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
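+ # only the shard holding the last layer samples a token; earlier shards just forward raw activations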
+ if shard.is_last_layer() and not is_finished:
+ token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
+ await self.inference_engine.ensure_shard(shard)
+ self.buffered_token_output[request_id][0].append(token.item())
+ if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
+ is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id
+ forward = token.reshape(1, -1)
+ self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
+ asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
+ else:
+ forward = result
+
+ if is_finished:
+ self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
+ else:
+ asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
+
+ return np.array(self.buffered_token_output[request_id][0])
+
+ async def process_prompt(
+ self,
+ base_shard: Shard,
+ prompt: str,
+ request_id: Optional[str] = None,
+ ) -> Optional[np.ndarray]:
+ shard = self.get_current_shard(base_shard)
+ asyncio.create_task(
+ self.broadcast_opaque_status(
+ request_id,
+ json.dumps({
+ "type": "node_status",
+ "node_id": self.id,
+ "status": "start_process_prompt",
+ "base_shard": base_shard.to_dict(),
+ "shard": shard.to_dict(),
+ "prompt": prompt,
+ "request_id": request_id,
+ }),
+ )
+ )
+ start_time = time.perf_counter_ns()
+ resp = await self._process_prompt(base_shard, prompt, request_id)
+ end_time = time.perf_counter_ns()
+ elapsed_time_ns = end_time - start_time
+ asyncio.create_task(
+ self.broadcast_opaque_status(
+ request_id,
+ json.dumps({
+ "type": "node_status",
+ "node_id": self.id,
+ "status": "end_process_prompt",
+ "base_shard": base_shard.to_dict(),
+ "shard": shard.to_dict(),
+ "prompt": prompt,
+ "request_id": request_id,
+ "elapsed_time_ns": elapsed_time_ns,
+ "result_size": resp.size if resp is not None else 0,
+ }),
+ )
+ )
+ return resp
+
+ async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
+ if request_id is None:
+ request_id = str(uuid.uuid4())
+ shard = self.get_current_shard(base_shard)
+
+ if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
+ if not shard.is_first_layer():
+ if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
+ resp = await self.forward_prompt(shard, prompt, request_id, 0)
+ return None
+ else:
+ result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
+ ret = await self.process_inference_result(shard, result, request_id)
+ return result
+
+ async def process_tensor(
+ self,
+ base_shard: Shard,
+ tensor: np.ndarray,
+ request_id: Optional[str] = None,
+ ) -> Optional[np.ndarray]:
+ shard = self.get_current_shard(base_shard)
+ asyncio.create_task(
+ self.broadcast_opaque_status(
+ request_id,
+ json.dumps({
+ "type": "node_status",
+ "node_id": self.id,
+ "status": "start_process_tensor",
+ "base_shard": base_shard.to_dict(),
+ "shard": shard.to_dict(),
+ "tensor_size": tensor.size,
+ "tensor_shape": tensor.shape,
+ "request_id": request_id,
+ }),
+ )
+ )
+ start_time = time.perf_counter_ns()
+ resp = await self._process_tensor(shard, tensor, request_id)
+ end_time = time.perf_counter_ns()
+ elapsed_time_ns = end_time - start_time
+ asyncio.create_task(
+ self.broadcast_opaque_status(
+ request_id,
+ json.dumps({
+ "type": "node_status",
+ "node_id": self.id,
+ "status": "end_process_tensor",
+ "base_shard": base_shard.to_dict(),
+ "shard": shard.to_dict(),
+ "request_id": request_id,
+ "elapsed_time_ns": elapsed_time_ns,
+ "result_size": resp.size if resp is not None else 0,
+ }),
+ )
+ )
+ return resp
+
+ async def _process_tensor(
+ self,
+ base_shard: Shard,
+ tensor: np.ndarray,
+ request_id: Optional[str] = None,
+ ) -> Optional[np.ndarray]:
+ if request_id is None:
+ request_id = str(uuid.uuid4())
+ shard = self.get_current_shard(base_shard)
+
+ if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
+ try:
+ result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
+ ret = await self.process_inference_result(shard, result, request_id)
+ return ret
+ except Exception as e:
+ print(f"Error processing tensor for shard {shard}: {e}")
+ traceback.print_exc()
+ return None
+
+ async def forward_prompt(
+ self,
+ base_shard: Shard,
+ prompt: str,
+ request_id: str,
+ target_index: int,
+ ) -> None:
+ if DEBUG >= 1: print(f"target partition index: {target_index}")
+ target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+ next_shard = self.get_current_shard(base_shard, target_index)
+ if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
+ if target_id == self.id:
+ await self.process_prompt(next_shard, prompt, request_id)
+ else:
+ target_peer = next((p for p in self.peers if p.id() == target_id), None)
+ if not target_peer:
+ raise ValueError(f"Peer for {target_index} not found")
+ if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
+ await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
+
+ async def forward_tensor(
+ self,
+ base_shard: Shard,
+ tensor: np.ndarray,
+ request_id: str,
+ target_index: int,
+ ) -> None:
+ if DEBUG >= 1: print(f"target partition index: {target_index}")
+ target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
+ next_shard = self.get_current_shard(base_shard, target_index)
+ if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
+ if target_id == self.id:
+ await self.process_tensor(next_shard, tensor, request_id)
+ else:
+ target_peer = next((p for p in self.peers if p.id() == target_id), None)
+ if not target_peer:
+ raise ValueError(f"Peer for {target_index} not found")
+ if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
+ await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
+
+ def get_partition_index(self, offset: int = 0):
+ if not self.partitioning_strategy:
+ if DEBUG >= 1: print("No partitioning strategy found. Skipping forward.")
+ return None
+ partitions = self.partitioning_strategy.partition(self.topology)
+ current_partition_index = next((i for i, p in enumerate(partitions) if p.node_id == self.id), None)
+ if current_partition_index is None:
+ raise ValueError(f"No current partition found for node: {self.id}")
+ return (current_partition_index + offset) % len(partitions)
+
+ def get_current_shard(self, base_shard: Shard, index: Optional[int] = None) -> Shard:
+ if index is None:
+ index = self.get_partition_index()
+ partitions = self.partitioning_strategy.partition(self.topology)
+ shards = map_partitions_to_shards(partitions, base_shard.n_layers, base_shard.model_id)
+ return shards[index]
+
+ async def update_peers(self, wait_for_peers: int = 0) -> bool:
+ next_peers = await self.discovery.discover_peers(wait_for_peers)
+ current_peer_ids = {peer.id() for peer in self.peers}
+ next_peer_ids = {peer.id() for peer in next_peers}
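+ # classify discovered peers as added, removed, updated, or unchanged before reconciling connections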
+ peers_added = [peer for peer in next_peers if peer.id() not in current_peer_ids]
+ peers_removed = [peer for peer in self.peers if peer.id() not in next_peer_ids]
+ peers_updated = [peer for peer in next_peers if peer.id() in current_peer_ids and any(p.addr() != peer.addr() for p in self.peers if p.id() == peer.id())]
+ peers_unchanged = [peer for peer in next_peers if peer.id() in current_peer_ids and all(p.addr() == peer.addr() for p in self.peers if p.id() == peer.id())]
+ peers_to_disconnect = [peer for peer in peers_removed if await peer.is_connected()]
+ peers_to_connect = [peer for peer in peers_added + peers_updated + peers_unchanged if not await peer.is_connected()]
+
+ def _pretty(peers: List[PeerHandle]) -> List[str]:
+ return [f"{peer.id()}@{peer.addr()}" for peer in peers]
+
+ if DEBUG >= 2:
+ print(f"update_peers: added={peers_added} removed={peers_removed} updated={peers_updated} unchanged={peers_unchanged} to_disconnect={peers_to_disconnect} to_connect={peers_to_connect}")
+
+ async def disconnect_with_timeout(peer, timeout=5):
+ try:
+ await asyncio.wait_for(peer.disconnect(), timeout)
+ return True
+ except Exception as e:
+ print(f"Error disconnecting peer {peer.id()}@{peer.addr()}: {e}")
+ traceback.print_exc()
+ return False
+
+ async def connect_with_timeout(peer, timeout=5):
+ try:
+ await asyncio.wait_for(peer.connect(), timeout)
+ return True
+ except Exception as e:
+ print(f"Error connecting peer {peer.id()}@{peer.addr()}: {e}")
+ traceback.print_exc()
+ return False
+
+ disconnect_results = await asyncio.gather(*(disconnect_with_timeout(peer) for peer in peers_to_disconnect), return_exceptions=True)
+ connect_results = await asyncio.gather(*(connect_with_timeout(peer) for peer in peers_to_connect), return_exceptions=True)
+
+ successful_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is True]
+ failed_disconnects = [peer for peer, result in zip(peers_to_disconnect, disconnect_results) if result is False]
+ successful_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is True]
+ failed_connects = [peer for peer, result in zip(peers_to_connect, connect_results) if result is False]
+ if DEBUG >= 1:
+ if successful_disconnects: print(f"Successfully disconnected peers: {_pretty(successful_disconnects)}")
+ if failed_disconnects: print(f"Failed to disconnect peers: {_pretty(failed_disconnects)}")
+ if successful_connects: print(f"Successfully connected peers: {_pretty(successful_connects)}")
+ if failed_connects: print(f"Failed to connect peers: {_pretty(failed_connects)}")
+
+ self.peers = next_peers
+ return len(peers_added) > 0 or len(peers_removed) > 0 or len(peers_updated) > 0
+
+ async def select_best_inference_engine(self):
+ if self.inference_engine.__class__.__name__ == 'DummyInferenceEngine': return
+ supported_engines = self.get_supported_inference_engines()
+ await self.broadcast_supported_engines(supported_engines)
+ if len(self.get_topology_inference_engines()):
+ self.inference_engine = get_inference_engine(supported_engines[0], self.shard_downloader)
+
+ async def periodic_topology_collection(self, interval: int):
+ while True:
+ await asyncio.sleep(interval)
+ try:
+ did_peers_change = await self.update_peers()
+ if DEBUG >= 2: print(f"{did_peers_change=}")
+ if did_peers_change:
+ await self.collect_topology()
+ await self.select_best_inference_engine()
+ except Exception as e:
+ print(f"Error collecting topology: {e}")
+ traceback.print_exc()
+
+ async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
+ if request_id not in self.buffered_token_output:
+ return None, False
+ return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
+
+ async def collect_topology(self, visited: set[str] = set(), max_depth: int = 4) -> Topology:
+ next_topology = Topology()
+ next_topology.update_node(self.id, self.device_capabilities)
+
+ if DEBUG >= 2: print(f"Collecting topology {max_depth=} {visited=}")
+
+ prev_visited = visited.copy()
+ visited.add(self.id)
+ visited.update(p.id() for p in self.peers)
+
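+ # add every peer to the graph, then recursively collect topology from peers we have not visited yet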
+ for peer in self.peers:
+ next_topology.update_node(peer.id(), peer.device_capabilities())
+ next_topology.add_edge(self.id, peer.id())
+
+ if peer.id() in prev_visited:
+ continue
+
+ if max_depth <= 0:
+ if DEBUG >= 2: print("Max depth reached. Skipping...")
+ continue
+
+ try:
+ other_topology = await asyncio.wait_for(peer.collect_topology(visited, max_depth=max_depth - 1), timeout=5.0)
+ if DEBUG >= 2: print(f"Collected topology from: {peer.id()}: {other_topology}")
+ self.topology.merge(other_topology)
+ except Exception as e:
+ print(f"Error collecting topology from {peer.id()}: {e}")
+ traceback.print_exc()
+
+ next_topology.active_node_id = self.topology.active_node_id # this is not so clean.
+ self.topology = next_topology
+ if self.topology_viz:
+ self.topology_viz.update_visualization(self.current_topology, self.partitioning_strategy.partition(self.current_topology), self.id)
+ return next_topology
+
+ @property
+ def on_token(self) -> AsyncCallbackSystem[str, Tuple[str, List[int], bool]]:
+ return self._on_token
+
+ @property
+ def on_opaque_status(self) -> AsyncCallbackSystem[str, Tuple[str, str]]:
+ return self._on_opaque_status
+
+ def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
+ if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
+ self.on_token.trigger_all(request_id, tokens, is_finished)
+
+ async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
+ async def send_result_to_peer(peer):
+ try:
+ await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
+ except asyncio.TimeoutError:
+ print(f"Timeout broadcasting result to {peer.id()}")
+ except Exception as e:
+ print(f"Error broadcasting result to {peer.id()}: {e}")
+ traceback.print_exc()
+
+ await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True)
+
+ async def broadcast_opaque_status(self, request_id: str, status: str) -> None:
+ if DEBUG >= 8: print(f"Broadcasting opaque status: {request_id=} {status=}")
+
+ async def send_status_to_peer(peer):
+ try:
+ await asyncio.wait_for(peer.send_opaque_status(request_id, status), timeout=15.0)
+ except asyncio.TimeoutError:
+ print(f"Timeout sending opaque status to {peer.id()}")
+ except Exception as e:
+ print(f"Error sending opaque status to {peer.id()}: {e}")
+ traceback.print_exc()
+
+ await asyncio.gather(*[send_status_to_peer(peer) for peer in self.peers], return_exceptions=True)
+ # in the case of opaque status, we also want to receive our own opaque statuses
+ self.on_opaque_status.trigger_all(request_id, status)
+
+ @property
+ def current_topology(self) -> Topology:
+ return self.topology
diff --git a/exo/tinychat/index.html b/exo/tinychat/index.html
index 65451068f..543a47abf 100644
--- a/exo/tinychat/index.html
+++ b/exo/tinychat/index.html
@@ -18,7 +18,7 @@
-
+
diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py
index ae28e757f..17372262b 100644
--- a/exo/topology/device_capabilities.py
+++ b/exo/topology/device_capabilities.py
@@ -119,6 +119,10 @@ def to_dict(self):
"NVIDIA A800 80GB PCIE": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
"NVIDIA A100 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
"NVIDIA A800 80GB SXM": DeviceFlops(fp32=19.5*TFLOPS, fp16=312.0*TFLOPS, int8=624.0*TFLOPS),
+ "NVIDIA T1000 8GB": DeviceFlops(fp32=2.5 * TFLOPS, fp16=5.0 * TFLOPS, int8=10.0 * TFLOPS),
+ "Quadro M2000": DeviceFlops(fp32=0.5 * TFLOPS, fp16=1.0 * TFLOPS, int8=2.0 * TFLOPS),
+ "Quadro P400": DeviceFlops(fp32=0.641 * TFLOPS, fp16=1.282 * TFLOPS, int8=2.564 * TFLOPS),
+ "NVIDIA A10": DeviceFlops(fp32=31.2 * TFLOPS, fp16=62.5 * TFLOPS, int8=2.5 * TFLOPS),
# ... add more devices if needed ...
### AMD GPUs
# RX 6000 series
diff --git a/install.ps1 b/install.ps1
new file mode 100644
index 000000000..c766cdd5b
--- /dev/null
+++ b/install.ps1
@@ -0,0 +1,8 @@
+# Create a virtual environment
+python3 -m venv .venv
+
+# Activate the virtual environment
+& .\.venv\Scripts\Activate.ps1
+
+# Install the package in the virtual environment
+pip install .
diff --git a/setup.py b/setup.py
index de242f544..530e85976 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,15 @@
"uuid==1.30",
"uvloop==0.21.0",
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
+ "torch==2.6.0",
+ "accelerate==0.34.2",
+ "torchtune==0.5.0",
+ "torchao==0.8.0",
+ "pytest==8.3.3",
+ "pytest-asyncio==0.24.0",
+ "scapy==2.6.1",
+ "uvloop==0.21.0"
+
]
extras_require = {