From b89495f4441527476673d32dd2f179ab805250df Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 00:37:57 +0000 Subject: [PATCH 01/16] rewrite ShardDownloader, simplify significantly --- exo/api/chatgpt_api.py | 51 ++---- exo/download/download_progress.py | 3 +- exo/download/hf/hf_helpers.py | 1 - exo/download/hf/new_shard_download.py | 180 +++++++++++++++++++++ exo/download/hf/test_new_shard_download.py | 20 +++ exo/download/shard_download.py | 4 +- exo/main.py | 7 +- exo/models.py | 11 +- exo/orchestration/node.py | 6 +- exo/orchestration/test_node.py | 4 +- exo/tinychat/index.js | 6 +- 11 files changed, 236 insertions(+), 57 deletions(-) create mode 100644 exo/download/hf/new_shard_download.py create mode 100644 exo/download/hf/test_new_shard_download.py diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 30cec8c40..3d36717c7 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -11,17 +11,17 @@ import traceback import signal from exo import DEBUG, VERSION -from exo.download.download_progress import RepoProgressEvent from exo.helpers import PrefixDict, shutdown, get_exo_images_dir from exo.inference.tokenizers import resolve_tokenizer from exo.orchestration import Node -from exo.models import build_base_shard, model_cards, get_repo, pretty_name +from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name from typing import Callable, Optional from PIL import Image import numpy as np import base64 from io import BytesIO import platform +from exo.download.shard_download import RepoProgressEvent if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": import mlx.core as mx @@ -29,7 +29,6 @@ import numpy as mx import tempfile -from exo.download.hf.hf_shard_download import HFShardDownloader import shutil from exo.download.hf.hf_helpers import get_hf_home, get_repo_root from exo.apputil import create_animation_mp4 @@ -277,41 +276,12 @@ async def handle_healthcheck(self, request): async def handle_model_support(self, request): try: - response = web.StreamResponse(status=200, reason='OK', headers={ - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - 'Connection': 'keep-alive', - }) + response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' }) await response.prepare(request) - - async def process_model(model_name, pretty): - if model_name in model_cards: - model_info = model_cards[model_name] - - if self.inference_engine_classname in model_info.get("repo", {}): - shard = build_base_shard(model_name, self.inference_engine_classname) - if shard: - downloader = HFShardDownloader(quick_check=True) - downloader.current_shard = shard - downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname) - status = await downloader.get_shard_download_status() - - download_percentage = status.get("overall") if status else None - total_size = status.get("total_size") if status else None - total_downloaded = status.get("total_downloaded") if status else False - - model_data = { - model_name: { - "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size, - "total_downloaded": total_downloaded - } - } - - await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) - - # Process all models in parallel - await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()]) - + downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname) + for (path, d) in downloads: + model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } } + await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) await response.write(b"data: [DONE]\n\n") return response @@ -348,6 +318,7 @@ async def handle_get_download_progress(self, request): progress_data = {} for node_id, progress_event in self.node.node_download_progress.items(): if isinstance(progress_event, RepoProgressEvent): + if progress_event.status != "in_progress": continue progress_data[node_id] = progress_event.to_dict() else: print(f"Unknown progress event type: {type(progress_event)}. {progress_event}") @@ -611,9 +582,9 @@ async def handle_delete_model(self, request): async def handle_get_initial_models(self, request): model_data = {} - for model_name, pretty in pretty_name.items(): - model_data[model_name] = { - "name": pretty, + for model_id in get_supported_models([[self.inference_engine_classname]]): + model_data[model_id] = { + "name": get_pretty_name(model_id), "downloaded": None, # Initially unknown "download_percentage": None, # Change from 0 to null "total_size": None, diff --git a/exo/download/download_progress.py b/exo/download/download_progress.py index 779e53287..340cc3217 100644 --- a/exo/download/download_progress.py +++ b/exo/download/download_progress.py @@ -14,11 +14,12 @@ class RepoFileProgressEvent: speed: int eta: timedelta status: Literal["not_started", "in_progress", "complete"] + start_time: float def to_dict(self): return { "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session, - "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status + "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time } @classmethod diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index 8be12b957..c0fa44e22 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -209,7 +209,6 @@ async def download_file( raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}") if downloaded_size == total_size: - print(f"File already downloaded: {file_path}") if progress_callback: await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) return diff --git a/exo/download/hf/new_shard_download.py b/exo/download/hf/new_shard_download.py new file mode 100644 index 000000000..28ca8cefa --- /dev/null +++ b/exo/download/hf/new_shard_download.py @@ -0,0 +1,180 @@ +from exo.inference.shard import Shard +from exo.models import get_repo +from pathlib import Path +from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns +from exo.download.shard_download import ShardDownloader +from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent +from exo.helpers import AsyncCallbackSystem, DEBUG +from exo.models import get_supported_models, build_base_shard +import os +import aiofiles.os as aios +import aiohttp +import aiofiles +from urllib.parse import urljoin +from typing import Optional, Callable, Union, Tuple, Dict +import time +from datetime import timedelta +import asyncio +import json + +def exo_home() -> Path: + return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) + +async def ensure_downloads_dir() -> Path: + downloads_dir = exo_home()/"downloads" + await aios.makedirs(downloads_dir, exist_ok=True) + return downloads_dir + +async def fetch_file_list(session, repo_id, revision, path=""): + api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" + url = f"{api_url}/{path}" if path else api_url + + headers = await get_auth_headers() + async with session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + files = [] + for item in data: + if item["type"] == "file": + files.append({"path": item["path"], "size": item["size"]}) + elif item["type"] == "directory": + subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) + files.extend(subfiles) + return files + else: + raise Exception(f"Failed to fetch file list: {response.status}") + +async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path: + if (target_dir/path).exists(): return target_dir/path + base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" + url = urljoin(base_url, path) + headers = await get_auth_headers() + async with session.get(url, headers=headers) as r: + assert r.status == 200, r.status + length = int(r.headers.get('content-length', 0)) + n_read = 0 + async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: + while chunk := await r.content.read(16384): on_progress(n_read := n_read + await temp_file.write(chunk), length) + await aios.rename(temp_file.name, target_dir/path) + return target_dir/path + +def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: + all_total_bytes = sum([p.total for p in file_progress.values()]) + all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()]) + elapsed_time = time.time() - all_start_time + all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0 + all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0) + status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress" + return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status) + +async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: + target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + async with aiohttp.ClientSession() as session: + index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) + async with aiofiles.open(index_file, 'r') as f: + index_data = json.loads(await f.read()) + return index_data.get("weight_map") + +async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]: + weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname)) + return get_allow_patterns(weight_map, shard) + +async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]: + if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") + repo_id = get_repo(shard.model_id, inference_engine_classname) + revision = "main" + target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + if not skip_download: await aios.makedirs(target_dir, exist_ok=True) + + if repo_id is None: + raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}") + + allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None + if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}") + + all_start_time = time.time() + async with aiohttp.ClientSession() as session: + file_list = await fetch_file_list(session, repo_id, revision) + filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"])) + file_progress: Dict[str, RepoFileProgressEvent] = {} + def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): + start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time() + speed = curr_bytes / (time.time() - start_time) + eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time) + on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time)) + if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") + for file in filtered_file_list: + downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0 + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time()) + + semaphore = asyncio.Semaphore(max_parallel_downloads) + async def download_with_semaphore(file): + async with semaphore: + await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) + if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list]) + final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time) + on_progress.trigger_all(shard, final_repo_progress) + return target_dir, final_repo_progress + +def new_shard_downloader() -> ShardDownloader: + return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader())) + +class SingletonShardDownloader(ShardDownloader): + def __init__(self, shard_downloader: ShardDownloader): + self.shard_downloader = shard_downloader + self.active_downloads: Dict[Shard, asyncio.Task] = {} + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self.shard_downloader.on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name)) + try: return await self.active_downloads[shard] + finally: + if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard] + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + return await self.shard_downloader.get_shard_download_status(inference_engine_name) + +class CachedShardDownloader(ShardDownloader): + def __init__(self, shard_downloader: ShardDownloader): + self.shard_downloader = shard_downloader + self.cache: Dict[tuple[str, Shard], Path] = {} + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self.shard_downloader.on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + if (inference_engine_name, shard) in self.cache: + if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}") + return self.cache[(inference_engine_name, shard)] + if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}") + target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name) + self.cache[(inference_engine_name, shard)] = target_dir + return target_dir + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + return await self.shard_downloader.get_shard_download_status(inference_engine_name) + +class NewShardDownloader(ShardDownloader): + def __init__(self): + self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self._on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress) + return target_dir + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + if DEBUG >= 2: print("Getting shard download status for", inference_engine_name) + downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True) + if DEBUG >= 6: print("Downloaded shards:", downloads) + if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)]) + return [d for d in downloads if not isinstance(d, Exception)] + diff --git a/exo/download/hf/test_new_shard_download.py b/exo/download/hf/test_new_shard_download.py new file mode 100644 index 000000000..17a999f15 --- /dev/null +++ b/exo/download/hf/test_new_shard_download.py @@ -0,0 +1,20 @@ +from exo.download.hf.new_shard_download import download_shard, NewShardDownloader +from exo.inference.shard import Shard +from exo.models import get_model_id +from pathlib import Path +import asyncio + +async def test_new_shard_download(): + shard_downloader = NewShardDownloader() + shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event)) + await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine") + download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine") + print({k: v for k, v in download_statuses if v.downloaded_bytes > 0}) + +async def test_helpers(): + print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit")) + print(get_model_id("fjsekljfd")) + +if __name__ == "__main__": + asyncio.run(test_helpers()) + diff --git a/exo/download/shard_download.py b/exo/download/shard_download.py index 89d3582e4..21ea412c2 100644 --- a/exo/download/shard_download.py +++ b/exo/download/shard_download.py @@ -27,7 +27,7 @@ def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent pass @abstractmethod - async def get_shard_download_status(self) -> Optional[Dict[str, float]]: + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: """Get the download status of shards. Returns: @@ -45,5 +45,5 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: return AsyncCallbackSystem() - async def get_shard_download_status(self) -> Optional[Dict[str, float]]: + async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]: return None diff --git a/exo/main.py b/exo/main.py index d58224ca7..1ea527704 100644 --- a/exo/main.py +++ b/exo/main.py @@ -24,7 +24,7 @@ from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy from exo.api import ChatGPTAPI from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader -from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.hf.new_shard_download import new_shard_downloader from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown from exo.inference.shard import Shard from exo.inference.inference_engine import get_inference_engine, InferenceEngine @@ -117,8 +117,7 @@ def configure_uvloop(): system_info = get_system_info() print(f"Detected system: {system_info}") -shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check, - max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader() +shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader() inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad") print(f"Inference engine name after selection: {inference_engine_name}") @@ -175,10 +174,10 @@ def configure_uvloop(): None, inference_engine, discovery, + shard_downloader, partitioning_strategy=RingMemoryWeightedPartitioningStrategy(), max_generate_tokens=args.max_generate_tokens, topology_viz=topology_viz, - shard_downloader=shard_downloader, default_sample_temperature=args.default_temp ) server = GRPCServer(node, args.node_host, args.node_port) diff --git a/exo/models.py b/exo/models.py index 29019e91b..9880c64a6 100644 --- a/exo/models.py +++ b/exo/models.py @@ -175,8 +175,11 @@ "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite", "deepseek-coder-v2.5": "Deepseek Coder V2.5", "deepseek-v3": "Deepseek V3", + "deepseek-v3-3bit": "Deepseek V3 (3-bit)", "deepseek-r1": "Deepseek R1", + "deepseek-r1-3bit": "Deepseek R1 (3-bit)", "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)", + "qwen-2.5-0.5b": "Qwen 2.5 0.5B", "qwen-2.5-1.5b": "Qwen 2.5 1.5B", "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B", "qwen-2.5-3b": "Qwen 2.5 3B", @@ -232,6 +235,12 @@ def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]: return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None) +def get_model_id(repo_id: str) -> Optional[str]: + return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None) + +def get_pretty_name(model_id: str) -> Optional[str]: + return pretty_name.get(model_id, None) + def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]: repo = get_repo(model_id, inference_engine_classname) n_layers = model_cards.get(model_id, {}).get("layers", 0) @@ -239,7 +248,7 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional return None return Shard(model_id, 0, 0, n_layers) -def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]: +def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]: if not supported_inference_engine_lists: return list(model_cards.keys()) diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 36c771b0b..d90754e2d 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -15,7 +15,7 @@ from exo.viz.topology_viz import TopologyViz from exo.download.hf.hf_helpers import RepoProgressEvent from exo.inference.inference_engine import get_inference_engine, InferenceEngine -from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.shard_download import ShardDownloader class Node: def __init__( @@ -24,16 +24,17 @@ def __init__( server: Server, inference_engine: InferenceEngine, discovery: Discovery, + shard_downloader: ShardDownloader, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 1024, default_sample_temperature: float = 0.0, topology_viz: Optional[TopologyViz] = None, - shard_downloader: Optional[HFShardDownloader] = None, ): self.id = _id self.inference_engine = inference_engine self.server = server self.discovery = discovery + self.shard_downloader = shard_downloader self.partitioning_strategy = partitioning_strategy self.peers: List[PeerHandle] = {} self.topology: Topology = Topology() @@ -52,7 +53,6 @@ def __init__( self._on_opaque_status.register("node_status").on_next(self.on_node_status) self.node_download_progress: Dict[str, RepoProgressEvent] = {} self.topology_inference_engines_pool: List[List[str]] = [] - self.shard_downloader = shard_downloader self.outstanding_requests = {} async def start(self, wait_for_peers: int = 0) -> None: diff --git a/exo/orchestration/test_node.py b/exo/orchestration/test_node.py index 8bee4b2ba..6f992bbeb 100644 --- a/exo/orchestration/test_node.py +++ b/exo/orchestration/test_node.py @@ -5,7 +5,7 @@ from .node import Node from exo.networking.peer_handle import PeerHandle - +from exo.download.shard_download import NoopShardDownloader class TestNode(unittest.IsolatedAsyncioTestCase): def setUp(self): @@ -22,7 +22,7 @@ def setUp(self): mock_peer2.id.return_value = "peer2" self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2]) - self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery) + self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader()) async def asyncSetUp(self): await self.node.start() diff --git a/exo/tinychat/index.js b/exo/tinychat/index.js index dff5fc97a..b0eea9175 100644 --- a/exo/tinychat/index.js +++ b/exo/tinychat/index.js @@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => { while (true) { try { await this.populateSelector(); - // Wait 5 seconds before next poll - await new Promise(resolve => setTimeout(resolve, 5000)); + // Wait 15 seconds before next poll + await new Promise(resolve => setTimeout(resolve, 15000)); } catch (error) { console.error('Model polling error:', error); // If there's an error, wait before retrying - await new Promise(resolve => setTimeout(resolve, 5000)); + await new Promise(resolve => setTimeout(resolve, 15000)); } } }, From 1df023023eeeeff8ed0f7ad0319622199293efd0 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 01:06:47 +0000 Subject: [PATCH 02/16] remove a lot of hf bloat --- exo/api/chatgpt_api.py | 46 +- exo/download/hf/hf_helpers.py | 414 +----------------- exo/download/hf/hf_shard_download.py | 172 -------- exo/download/{hf => }/new_shard_download.py | 37 +- .../{hf => }/test_new_shard_download.py | 0 exo/inference/tokenizers.py | 33 +- exo/main.py | 19 +- exo/orchestration/node.py | 2 +- exo/viz/test_topology_viz.py | 2 +- exo/viz/topology_viz.py | 2 +- 10 files changed, 76 insertions(+), 651 deletions(-) delete mode 100644 exo/download/hf/hf_shard_download.py rename exo/download/{hf => }/new_shard_download.py (87%) rename exo/download/{hf => }/test_new_shard_download.py (100%) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 3d36717c7..905338406 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -21,19 +21,17 @@ import base64 from io import BytesIO import platform -from exo.download.shard_download import RepoProgressEvent +from exo.download.download_progress import RepoProgressEvent +from exo.download.new_shard_download import ensure_downloads_dir, delete_model +import tempfile +from exo.apputil import create_animation_mp4 +from collections import defaultdict if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": import mlx.core as mx else: import numpy as mx -import tempfile -import shutil -from exo.download.hf.hf_helpers import get_hf_home, get_repo_root -from exo.apputil import create_animation_mp4 -from collections import defaultdict - class Message: def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None): @@ -545,35 +543,13 @@ def on_result(_request_id: str, result, is_finished: bool): return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) async def handle_delete_model(self, request): + model_id = request.match_info.get('model_name') try: - model_name = request.match_info.get('model_name') - if DEBUG >= 2: print(f"Attempting to delete model: {model_name}") - - if not model_name or model_name not in model_cards: - return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400) - - shard = build_base_shard(model_name, self.inference_engine_classname) - if not shard: - return web.json_response({"detail": "Could not build shard for model"}, status=400) - - repo_id = get_repo(shard.model_id, self.inference_engine_classname) - if DEBUG >= 2: print(f"Repo ID for model: {repo_id}") - - # Get the HF cache directory using the helper function - hf_home = get_hf_home() - cache_dir = get_repo_root(repo_id) - - if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}") - - if os.path.exists(cache_dir): - if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...") - try: - shutil.rmtree(cache_dir) - return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)}) - except Exception as e: - return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500) - else: - return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404) + if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"}) + else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404) + except Exception as e: + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500) except Exception as e: print(f"Error in handle_delete_model: {str(e)}") diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py index c0fa44e22..01cd86508 100644 --- a/exo/download/hf/hf_helpers.py +++ b/exo/download/hf/hf_helpers.py @@ -1,36 +1,16 @@ import aiofiles.os as aios from typing import Union -import asyncio -import aiohttp -import json import os -import sys -import shutil -from urllib.parse import urljoin -from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal -from datetime import datetime, timedelta +from typing import Callable, Optional, Dict, List, Union from fnmatch import fnmatch from pathlib import Path -from typing import Generator, Iterable, TypeVar, TypedDict -from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type -from exo.helpers import DEBUG, is_frozen -from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback +from typing import Generator, Iterable, TypeVar +from exo.helpers import DEBUG from exo.inference.shard import Shard import aiofiles T = TypeVar("T") -async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]: - refs_dir = get_repo_root(repo_id)/"refs" - refs_file = refs_dir/revision - if await aios.path.exists(refs_file): - async with aiofiles.open(refs_file, 'r') as f: - commit_hash = (await f.read()).strip() - snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash - return snapshot_dir - return None - - def filter_repo_objects( items: Iterable[T], *, @@ -48,14 +28,12 @@ def filter_repo_objects( ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] if key is None: - def _identity(item: T) -> str: if isinstance(item, str): return item if isinstance(item, Path): return str(item) raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") - key = _identity for item in items: @@ -66,22 +44,18 @@ def _identity(item: T) -> str: continue yield item - def _add_wildcard_to_directories(pattern: str) -> str: if pattern[-1] == "/": return pattern + "*" return pattern - def get_hf_endpoint() -> str: return os.environ.get('HF_ENDPOINT', "https://huggingface.co") - def get_hf_home() -> Path: """Get the Hugging Face home directory.""" return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface")) - async def get_hf_token(): """Retrieve the Hugging Face token from the user's HF_HOME directory.""" token_path = get_hf_home()/"token" @@ -90,7 +64,6 @@ async def get_hf_token(): return (await f.read()).strip() return None - async def get_auth_headers(): """Get authentication headers if a token is available.""" token = await get_hf_token() @@ -98,324 +71,6 @@ async def get_auth_headers(): return {"Authorization": f"Bearer {token}"} return {} - -def get_repo_root(repo_id: str) -> Path: - """Get the root directory for a given repo ID in the Hugging Face cache.""" - sanitized_repo_id = str(repo_id).replace("/", "--") - return get_hf_home()/"hub"/f"models--{sanitized_repo_id}" - -async def move_models_to_hf(seed_dir: Union[str, Path]): - """Move model in resources folder of app to .cache/huggingface/hub""" - source_dir = Path(seed_dir) - dest_dir = get_hf_home()/"hub" - await aios.makedirs(dest_dir, exist_ok=True) - for path in source_dir.iterdir(): - if path.is_dir() and path.name.startswith("models--"): - dest_path = dest_dir / path.name - if await aios.path.exists(dest_path): - print('Skipping moving model to .cache directory') - else: - try: - await aios.rename(str(path), str(dest_path)) - except Exception as e: - print(f'Error moving model to .cache: {e}') - - - -async def fetch_file_list(session, repo_id, revision, path=""): - api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" - url = f"{api_url}/{path}" if path else api_url - - headers = await get_auth_headers() - async with session.get(url, headers=headers) as response: - if response.status == 200: - data = await response.json() - files = [] - for item in data: - if item["type"] == "file": - files.append({"path": item["path"], "size": item["size"]}) - elif item["type"] == "directory": - subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) - files.extend(subfiles) - return files - else: - raise Exception(f"Failed to fetch file list: {response.status}") - - -@retry( - stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True -) -async def download_file( - session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True -): - base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" - url = urljoin(base_url, file_path) - local_path = os.path.join(save_directory, file_path) - - await aios.makedirs(os.path.dirname(local_path), exist_ok=True) - - # Check if file already exists and get its size - local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0 - - headers = await get_auth_headers() - if use_range_request: - headers["Range"] = f"bytes={local_file_size}-" - - async with session.get(url, headers=headers) as response: - total_size = int(response.headers.get('Content-Length', 0)) - downloaded_size = local_file_size - downloaded_this_session = 0 - mode = 'ab' if use_range_request else 'wb' - percentage = await get_file_download_percentage( - session, - repo_id, - revision, - file_path, - Path(save_directory) - ) - - if percentage == 100: - if DEBUG >= 2: print(f"File already downloaded: {file_path}") - if progress_callback: - await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete")) - return - - if response.status == 200: - # File doesn't support range requests or we're not using them, start from beginning - mode = 'wb' - downloaded_size = 0 - elif response.status == 206: - # Partial content, resume download - content_range = response.headers.get('Content-Range', '') - try: - total_size = int(content_range.split('/')[-1]) - except ValueError: - if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") - return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) - elif response.status == 416: - # Range not satisfiable, get the actual file size - content_range = response.headers.get('Content-Range', '') - try: - total_size = int(content_range.split('/')[-1]) - if downloaded_size == total_size: - if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}") - if progress_callback: - await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) - return - except ValueError: - if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...") - return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False) - else: - raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}") - - if downloaded_size == total_size: - if progress_callback: - await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete")) - return - - DOWNLOAD_CHUNK_SIZE = 32768 - start_time = datetime.now() - async with aiofiles.open(local_path, mode) as f: - async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE): - await f.write(chunk) - downloaded_size += len(chunk) - downloaded_this_session += len(chunk) - if progress_callback and total_size: - elapsed_time = (datetime.now() - start_time).total_seconds() - speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0 - remaining_size = total_size - downloaded_size - eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0) - status = "in_progress" if downloaded_size < total_size else "complete" - if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}") - await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status)) - if DEBUG >= 2: print(f"Downloaded: {file_path}") - - -async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str: - repo_root = get_repo_root(repo_id) - refs_dir = repo_root/"refs" - refs_file = refs_dir/revision - - # Check if we have a cached commit hash - if await aios.path.exists(refs_file): - async with aiofiles.open(refs_file, 'r') as f: - commit_hash = (await f.read()).strip() - if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}") - return commit_hash - - # Fetch the commit hash for the given revision - async with aiohttp.ClientSession() as session: - api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}" - headers = await get_auth_headers() - async with session.get(api_url, headers=headers) as response: - if response.status != 200: - raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}") - revision_info = await response.json() - commit_hash = revision_info['sha'] - - # Cache the commit hash - await aios.makedirs(refs_dir, exist_ok=True) - async with aiofiles.open(refs_file, 'w') as f: - await f.write(commit_hash) - - return commit_hash - - -async def download_repo_files( - repo_id: str, - revision: str = "main", - progress_callback: Optional[RepoProgressCallback] = None, - allow_patterns: Optional[Union[List[str], str]] = None, - ignore_patterns: Optional[Union[List[str], str]] = None, - max_parallel_downloads: int = 4 -) -> Path: - repo_root = get_repo_root(repo_id) - snapshots_dir = repo_root/"snapshots" - cachedreqs_dir = repo_root/"cachedreqs" - - # Ensure directories exist - await aios.makedirs(snapshots_dir, exist_ok=True) - await aios.makedirs(cachedreqs_dir, exist_ok=True) - - # Resolve revision to commit hash - commit_hash = await resolve_revision_to_commit_hash(repo_id, revision) - - # Set up the snapshot directory - snapshot_dir = snapshots_dir/commit_hash - await aios.makedirs(snapshot_dir, exist_ok=True) - - # Set up the cached file list directory - cached_file_list_dir = cachedreqs_dir/commit_hash - await aios.makedirs(cached_file_list_dir, exist_ok=True) - cached_file_list_path = cached_file_list_dir/"fetch_file_list.json" - - async with aiohttp.ClientSession() as session: - # Check if we have a cached file list - if await aios.path.exists(cached_file_list_path): - async with aiofiles.open(cached_file_list_path, 'r') as f: - file_list = json.loads(await f.read()) - if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}") - else: - file_list = await fetch_file_list(session, repo_id, revision) - # Cache the file list - async with aiofiles.open(cached_file_list_path, 'w') as f: - await f.write(json.dumps(file_list)) - if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}") - - model_index_exists = any(file["path"] == "model_index.json" for file in file_list) - if model_index_exists: - allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"] - - filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"])) - total_files = len(filtered_file_list) - total_bytes = sum(file["size"] for file in filtered_file_list) - file_progress: Dict[str, RepoFileProgressEvent] = { - file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started") - for file in filtered_file_list - } - start_time = datetime.now() - - async def download_with_progress(file_info, progress_state): - local_path = snapshot_dir/file_info["path"] - if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]: - if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}") - progress_state['completed_files'] += 1 - progress_state['downloaded_bytes'] += file_info["size"] - file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete") - if progress_callback: - elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 - remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) - status = "in_progress" if progress_state['completed_files'] < total_files else "complete" - await progress_callback( - RepoProgressEvent( - repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, - overall_eta, file_progress, status - ) - ) - return - - async def file_progress_callback(event: RepoFileProgressEvent): - progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded - progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session - file_progress[event.file_path] = event - if progress_callback: - elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 - remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) - status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete" - await progress_callback( - RepoProgressEvent( - repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, - overall_eta, file_progress, status - ) - ) - - await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback) - progress_state['completed_files'] += 1 - file_progress[ - file_info["path"] - ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete") - if progress_callback: - elapsed_time = (datetime.now() - start_time).total_seconds() - overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0 - remaining_bytes = total_bytes - progress_state['downloaded_bytes'] - overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0) - status = "in_progress" if progress_state['completed_files'] < total_files else "complete" - await progress_callback( - RepoProgressEvent( - repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed, - overall_eta, file_progress, status - ) - ) - - progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0} - - semaphore = asyncio.Semaphore(max_parallel_downloads) - - async def download_with_semaphore(file_info): - async with semaphore: - await download_with_progress(file_info, progress_state) - - tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list] - await asyncio.gather(*tasks) - - return snapshot_dir - - -async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]: - """ - Retrieve the weight map from the model.safetensors.index.json file. - - Args: - repo_id (str): The Hugging Face repository ID. - revision (str): The revision of the repository to use. - - Returns: - Optional[Dict[str, str]]: The weight map if it exists, otherwise None. - """ - - # Download the index file - await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json") - - # Check if the file exists - repo_root = get_repo_root(repo_id) - commit_hash = await resolve_revision_to_commit_hash(repo_id, revision) - snapshot_dir = repo_root/"snapshots"/commit_hash - index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) - - if index_file: - index_file_path = snapshot_dir/index_file - if await aios.path.exists(index_file_path): - async with aiofiles.open(index_file_path, 'r') as f: - index_data = json.loads(await f.read()) - return index_data.get("weight_map") - - return None - - def extract_layer_num(tensor_name: str) -> Optional[int]: # This is a simple example and might need to be adjusted based on the actual naming convention parts = tensor_name.split('.') @@ -424,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]: return int(part) return None - def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"]) shard_specific_patterns = set() @@ -442,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: shard_specific_patterns = set(["*.safetensors"]) if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") return list(default_patterns | shard_specific_patterns) - -async def get_file_download_percentage( - session: aiohttp.ClientSession, - repo_id: str, - revision: str, - file_path: str, - snapshot_dir: Path, -) -> float: - """ - Calculate the download percentage for a file by comparing local and remote sizes. - """ - try: - local_path = snapshot_dir / file_path - if not await aios.path.exists(local_path): - return 0 - - # Get local file size first - local_size = await aios.path.getsize(local_path) - if local_size == 0: - return 0 - - # Check remote size - base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" - url = urljoin(base_url, file_path) - headers = await get_auth_headers() - - # Use HEAD request with redirect following for all files - async with session.head(url, headers=headers, allow_redirects=True) as response: - if response.status != 200: - if DEBUG >= 2: - print(f"Failed to get remote file info for {file_path}: {response.status}") - return 0 - - remote_size = int(response.headers.get('Content-Length', 0)) - - if remote_size == 0: - if DEBUG >= 2: - print(f"Remote size is 0 for {file_path}") - return 0 - - # Only return 100% if sizes match exactly - if local_size == remote_size: - return 100.0 - - # Calculate percentage based on sizes - return (local_size / remote_size) * 100 if remote_size > 0 else 0 - - except Exception as e: - if DEBUG >= 2: - print(f"Error checking file download status for {file_path}: {e}") - return 0 - -async def has_hf_home_read_access() -> bool: - hf_home = get_hf_home() - try: return await aios.access(hf_home, os.R_OK) - except OSError: return False - -async def has_hf_home_write_access() -> bool: - hf_home = get_hf_home() - try: return await aios.access(hf_home, os.W_OK) - except OSError: return False - diff --git a/exo/download/hf/hf_shard_download.py b/exo/download/hf/hf_shard_download.py deleted file mode 100644 index 705bab843..000000000 --- a/exo/download/hf/hf_shard_download.py +++ /dev/null @@ -1,172 +0,0 @@ -import asyncio -import traceback -from pathlib import Path -from typing import Dict, List, Tuple, Optional, Union -from exo.inference.shard import Shard -from exo.download.shard_download import ShardDownloader -from exo.download.download_progress import RepoProgressEvent -from exo.download.hf.hf_helpers import ( - download_repo_files, RepoProgressEvent, get_weight_map, - get_allow_patterns, get_repo_root, fetch_file_list, - get_local_snapshot_dir, get_file_download_percentage, - filter_repo_objects -) -from exo.helpers import AsyncCallbackSystem, DEBUG -from exo.models import model_cards, get_repo -import aiohttp -from aiofiles import os as aios - - -class HFShardDownloader(ShardDownloader): - def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4): - self.quick_check = quick_check - self.max_parallel_downloads = max_parallel_downloads - self.active_downloads: Dict[Shard, asyncio.Task] = {} - self.completed_downloads: Dict[Shard, Path] = {} - self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() - self.current_shard: Optional[Shard] = None - self.current_repo_id: Optional[str] = None - self.revision: str = "main" - - async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: - self.current_shard = shard - self.current_repo_id = get_repo(shard.model_id, inference_engine_name) - repo_name = get_repo(shard.model_id, inference_engine_name) - if shard in self.completed_downloads: - return self.completed_downloads[shard] - if self.quick_check: - repo_root = get_repo_root(repo_name) - snapshots_dir = repo_root/"snapshots" - if snapshots_dir.exists(): - visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')] - if visible_dirs: - most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime) - return most_recent_dir - - # If a download on this shard is already in progress, keep that one - for active_shard in self.active_downloads: - if active_shard == shard: - if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.") - return await self.active_downloads[shard] - - # Cancel any downloads for this model_id on a different shard - existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id] - for active_shard in existing_active_shards: - if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})") - task = self.active_downloads[active_shard] - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when cancelling a task - except Exception as e: - if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}") - traceback.print_exc() - self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id} - - # Start new download - download_task = asyncio.create_task(self._download_shard(shard, repo_name)) - self.active_downloads[shard] = download_task - try: - path = await download_task - self.completed_downloads[shard] = path - return path - finally: - # Ensure the task is removed even if an exception occurs - print(f"Removing download task for {shard}: {shard in self.active_downloads}") - if shard in self.active_downloads: - self.active_downloads.pop(shard) - - async def _download_shard(self, shard: Shard, repo_name: str) -> Path: - async def wrapped_progress_callback(event: RepoProgressEvent): - self._on_progress.trigger_all(shard, event) - - weight_map = await get_weight_map(repo_name) - allow_patterns = get_allow_patterns(weight_map, shard) - - return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads) - - @property - def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: - return self._on_progress - - async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]: - if not self.current_shard or not self.current_repo_id: - if DEBUG >= 2: - print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}") - return None - - try: - # If no snapshot directory exists, return None - no need to check remote files - snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision) - if not snapshot_dir: - if DEBUG >= 2: - print(f"No snapshot directory found for {self.current_repo_id}") - return None - - if not await aios.path.exists(snapshot_dir/"model_index.json"): - # Get the weight map to know what files we need - weight_map = await get_weight_map(self.current_repo_id, self.revision) - if not weight_map: - if DEBUG >= 2: - print(f"No weight map found for {self.current_repo_id}") - return None - - # Get all files needed for this shard - patterns = get_allow_patterns(weight_map, self.current_shard) - else: - patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"] - - - # Check download status for all relevant files - status = {} - total_bytes = 0 - downloaded_bytes = 0 - - async with aiohttp.ClientSession() as session: - file_list = await fetch_file_list(session, self.current_repo_id, self.revision) - relevant_files = list( - filter_repo_objects( - file_list, allow_patterns=patterns, key=lambda x: x["path"])) - - for file in relevant_files: - file_size = file["size"] - total_bytes += file_size - - percentage = await get_file_download_percentage( - session, - self.current_repo_id, - self.revision, - file["path"], - snapshot_dir, - ) - status[file["path"]] = percentage - downloaded_bytes += (file_size * (percentage / 100)) - - # Add overall progress weighted by file size - if total_bytes > 0: - status["overall"] = (downloaded_bytes / total_bytes) * 100 - else: - status["overall"] = 0 - - # Add total size in bytes - status["total_size"] = total_bytes - if status["overall"] != 100: - status["total_downloaded"] = downloaded_bytes - - - if DEBUG >= 2: - print(f"Download calculation for {self.current_repo_id}:") - print(f"Total bytes: {total_bytes}") - print(f"Downloaded bytes: {downloaded_bytes}") - if DEBUG >= 3: - for file in relevant_files: - print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}") - - return status - - except Exception as e: - if DEBUG >= 3: - print(f"Error getting shard download status: {e}") - traceback.print_exc() - return None diff --git a/exo/download/hf/new_shard_download.py b/exo/download/new_shard_download.py similarity index 87% rename from exo/download/hf/new_shard_download.py rename to exo/download/new_shard_download.py index 28ca8cefa..dccd3b725 100644 --- a/exo/download/hf/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -16,15 +16,46 @@ from datetime import timedelta import asyncio import json +import traceback +import shutil def exo_home() -> Path: return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) +async def has_exo_home_read_access() -> bool: + try: return await aios.access(exo_home(), os.R_OK) + except OSError: return False + +async def has_exo_home_write_access() -> bool: + try: return await aios.access(exo_home(), os.W_OK) + except OSError: return False + async def ensure_downloads_dir() -> Path: downloads_dir = exo_home()/"downloads" await aios.makedirs(downloads_dir, exist_ok=True) return downloads_dir +async def delete_model(model_id: str, inference_engine_name: str) -> bool: + repo_id = get_repo(model_id, inference_engine_name) + model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + if not await aios.path.exists(model_dir): return False + await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False) + return True + +async def seed_models(seed_dir: Union[str, Path]): + """Move model in resources folder of app to .cache/huggingface/hub""" + source_dir = Path(seed_dir) + dest_dir = await ensure_downloads_dir() + for path in source_dir.iterdir(): + if path.is_dir() and path.name.startswith("models--"): + dest_path = dest_dir/path.name + if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory') + else: + try: await aios.rename(str(path), str(dest_path)) + except: + print(f"Error seeding model {path} to {dest_path}") + traceback.print_exc() + async def fetch_file_list(session, repo_id, revision, path=""): api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" url = f"{api_url}/{path}" if path else api_url @@ -44,8 +75,9 @@ async def fetch_file_list(session, repo_id, revision, path=""): else: raise Exception(f"Failed to fetch file list: {response.status}") -async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path: +async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: if (target_dir/path).exists(): return target_dir/path + await aios.makedirs((target_dir/path).parent, exist_ok=True) base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" url = urljoin(base_url, path) headers = await get_auth_headers() @@ -71,8 +103,7 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str] target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") async with aiohttp.ClientSession() as session: index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) - async with aiofiles.open(index_file, 'r') as f: - index_data = json.loads(await f.read()) + async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) return index_data.get("weight_map") async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]: diff --git a/exo/download/hf/test_new_shard_download.py b/exo/download/test_new_shard_download.py similarity index 100% rename from exo/download/hf/test_new_shard_download.py rename to exo/download/test_new_shard_download.py diff --git a/exo/inference/tokenizers.py b/exo/inference/tokenizers.py index 4dccaf664..865765d5f 100644 --- a/exo/inference/tokenizers.py +++ b/exo/inference/tokenizers.py @@ -1,12 +1,11 @@ import traceback -from aiofiles import os as aios from os import PathLike -from pathlib import Path +from aiofiles import os as aios from typing import Union from transformers import AutoTokenizer, AutoProcessor import numpy as np -from exo.download.hf.hf_helpers import get_local_snapshot_dir from exo.helpers import DEBUG +from exo.download.new_shard_download import ensure_downloads_dir class DummyTokenizer: @@ -24,25 +23,25 @@ def decode(self, tokens): return "dummy" * len(tokens) -async def resolve_tokenizer(model_id: str): - if model_id == "dummy": +async def resolve_tokenizer(repo_id: Union[str, PathLike]): + if repo_id == "dummy": return DummyTokenizer() - local_path = await get_local_snapshot_dir(model_id) + local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--") if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}") try: if local_path and await aios.path.exists(local_path): - if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}") + if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}") return await _resolve_tokenizer(local_path) except: - if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...") + if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...") if DEBUG >= 5: traceback.print_exc() - return await _resolve_tokenizer(model_id) + return await _resolve_tokenizer(repo_id) -async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]): +async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]): try: - if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}") - processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True) + if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}") + processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True) if not hasattr(processor, 'eos_token_id'): processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id if not hasattr(processor, 'encode'): @@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]): processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode return processor except Exception as e: - if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}") + if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}") if DEBUG >= 4: print(traceback.format_exc()) try: - if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}") - return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True) + if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}") + return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True) except Exception as e: - if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}") + if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}") if DEBUG >= 4: print(traceback.format_exc()) - raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}") + raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}") diff --git a/exo/main.py b/exo/main.py index 1ea527704..8f6c5e570 100644 --- a/exo/main.py +++ b/exo/main.py @@ -23,19 +23,18 @@ from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy from exo.api import ChatGPTAPI -from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader -from exo.download.hf.new_shard_download import new_shard_downloader +from exo.download.shard_download import ShardDownloader, NoopShardDownloader +from exo.download.download_progress import RepoProgressEvent +from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown from exo.inference.shard import Shard from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.inference.tokenizers import resolve_tokenizer from exo.models import build_base_shard, get_repo from exo.viz.topology_viz import TopologyViz -from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf import uvloop from contextlib import asynccontextmanager import concurrent.futures -import socket import resource import psutil @@ -321,13 +320,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n async def main(): loop = asyncio.get_running_loop() - # Check HuggingFace directory permissions - hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access() - if DEBUG >= 1: print(f"Model storage directory: {hf_home}") + # Check exo directory permissions + home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access() + if DEBUG >= 1: print(f"exo home directory: {home}") print(f"{has_read=}, {has_write=}") if not has_read or not has_write: print(f""" - WARNING: Limited permissions for model storage directory: {hf_home}. + WARNING: Limited permissions for exo home directory: {home}. This may prevent model downloads from working correctly. {"❌ No read access" if not has_read else ""} {"❌ No write access" if not has_write else ""} @@ -336,9 +335,9 @@ async def main(): if not args.models_seed_dir is None: try: models_seed_dir = clean_path(args.models_seed_dir) - await move_models_to_hf(models_seed_dir) + await seed_models(models_seed_dir) except Exception as e: - print(f"Error moving models to .cache/huggingface: {e}") + print(f"Error seeding models: {e}") def restore_cursor(): if platform.system() != "Windows": diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index d90754e2d..cd60dd0fa 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -13,7 +13,7 @@ from exo import DEBUG from exo.helpers import AsyncCallbackSystem from exo.viz.topology_viz import TopologyViz -from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.download.download_progress import RepoProgressEvent from exo.inference.inference_engine import get_inference_engine, InferenceEngine from exo.download.shard_download import ShardDownloader diff --git a/exo/viz/test_topology_viz.py b/exo/viz/test_topology_viz.py index e57de1ae3..a10ac43b7 100644 --- a/exo/viz/test_topology_viz.py +++ b/exo/viz/test_topology_viz.py @@ -5,7 +5,7 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from exo.topology.partitioning_strategy import Partition -from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent +from exo.download.download_progress import RepoProgressEvent def create_hf_repo_progress_event( diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 734fe69dd..0a9f15971 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -4,7 +4,7 @@ from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second from exo.topology.topology import Topology from exo.topology.partitioning_strategy import Partition -from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.download.download_progress import RepoProgressEvent from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES from rich.console import Console, Group from rich.text import Text From 3c7bd48aa33ef69d4ffe32051ee065eaeb0b175f Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 01:08:46 +0000 Subject: [PATCH 03/16] get rid of some more hf bloat --- exo/download/test_new_shard_download.py | 2 +- exo/inference/mlx/test_non_blocking.py | 4 +- exo/inference/test_inference_engine.py | 6 +-- extra/download_hf.py | 50 ------------------------- test/test_hf.py | 26 ------------- 5 files changed, 6 insertions(+), 82 deletions(-) delete mode 100644 extra/download_hf.py delete mode 100644 test/test_hf.py diff --git a/exo/download/test_new_shard_download.py b/exo/download/test_new_shard_download.py index 17a999f15..d3303f87c 100644 --- a/exo/download/test_new_shard_download.py +++ b/exo/download/test_new_shard_download.py @@ -1,4 +1,4 @@ -from exo.download.hf.new_shard_download import download_shard, NewShardDownloader +from exo.download.new_shard_download import download_shard, NewShardDownloader from exo.inference.shard import Shard from exo.models import get_model_id from pathlib import Path diff --git a/exo/inference/mlx/test_non_blocking.py b/exo/inference/mlx/test_non_blocking.py index 64eedfdde..4eb2e9fb3 100644 --- a/exo/inference/mlx/test_non_blocking.py +++ b/exo/inference/mlx/test_non_blocking.py @@ -2,7 +2,7 @@ import time import numpy as np from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine -from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.new_shard_download import NewShardDownloader from exo.inference.shard import Shard from exo.models import build_base_shard from collections import deque @@ -10,7 +10,7 @@ async def test_non_blocking(): # Setup - shard_downloader = HFShardDownloader() + shard_downloader = NewShardDownloader() engine = MLXDynamicShardInferenceEngine(shard_downloader) _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine") shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers) diff --git a/exo/inference/test_inference_engine.py b/exo/inference/test_inference_engine.py index 7a69eafce..75021cce9 100644 --- a/exo/inference/test_inference_engine.py +++ b/exo/inference/test_inference_engine.py @@ -1,6 +1,6 @@ from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine -from exo.download.hf.hf_shard_download import HFShardDownloader from exo.inference.inference_engine import InferenceEngine +from exo.download.new_shard_download import NewShardDownloader from exo.inference.shard import Shard from exo.helpers import DEBUG import os @@ -44,7 +44,7 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e assert np.array_equal(next_resp_full, resp4) -asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16)) +asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16)) if os.getenv("RUN_TINYGRAD", default="0") == "1": import tinygrad @@ -52,5 +52,5 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) asyncio.run( - test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32) + test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3-8b", 32) ) diff --git a/extra/download_hf.py b/extra/download_hf.py deleted file mode 100644 index 8cf712e55..000000000 --- a/extra/download_hf.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import asyncio -from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent - -DEFAULT_ALLOW_PATTERNS = [ - "*.json", - "*.py", - "tokenizer.model", - "*.tiktoken", - "*.txt", - "*.safetensors", -] -# Always ignore `.git` and `.cache/huggingface` folders in commits -DEFAULT_IGNORE_PATTERNS = [ - ".git", - ".git/*", - "*/.git", - "**/.git/**", - ".cache/huggingface", - ".cache/huggingface/*", - "*/.cache/huggingface", - "**/.cache/huggingface/**", -] - - -async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None): - async def progress_callback(event: RepoProgressEvent): - print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes") - print(f"Estimated time remaining: {event.overall_eta}") - print("File Progress:") - for file_path, progress in event.file_progress.items(): - status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status] - eta_str = str(progress.eta) - print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, " - f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}") - print("\n") - - await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.") - parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')") - parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)") - parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')") - parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')") - - args = parser.parse_args() - - asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns)) diff --git a/test/test_hf.py b/test/test_hf.py deleted file mode 100644 index 0477d132b..000000000 --- a/test/test_hf.py +++ /dev/null @@ -1,26 +0,0 @@ -import os -import sys - -# Add the project root to the Python path -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, project_root) - -import asyncio -from exo.download.hf.hf_helpers import get_weight_map - -async def test_get_weight_map(): - repo_ids = [ - "mlx-community/quantized-gemma-2b", - "mlx-community/Meta-Llama-3.1-8B-4bit", - "mlx-community/Meta-Llama-3.1-70B-4bit", - "mlx-community/Meta-Llama-3.1-405B-4bit", - ] - for repo_id in repo_ids: - weight_map = await get_weight_map(repo_id) - assert weight_map is not None, "Weight map should not be None" - assert isinstance(weight_map, dict), "Weight map should be a dictionary" - assert len(weight_map) > 0, "Weight map should not be empty" - print(f"OK: {repo_id}") - -if __name__ == "__main__": - asyncio.run(test_get_weight_map()) From 74379ef6717d3b3a4f7924c4cef180e9f5d49f5f Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 01:11:54 +0000 Subject: [PATCH 04/16] log download logs with DEBUG>=6 very verbose --- exo/download/new_shard_download.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index dccd3b725..0dfff2f2d 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -11,7 +11,7 @@ import aiohttp import aiofiles from urllib.parse import urljoin -from typing import Optional, Callable, Union, Tuple, Dict +from typing import Callable, Union, Tuple, Dict import time from datetime import timedelta import asyncio @@ -111,7 +111,7 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) return get_allow_patterns(weight_map, shard) async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]: - if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") + if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") repo_id = get_repo(shard.model_id, inference_engine_classname) revision = "main" target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") From 277d63d86083076c9a7a64479779133817bba071 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 01:26:15 +0000 Subject: [PATCH 05/16] special case when a model doesnt have a model index file, then use wildcard for allow_patterns --- exo/download/new_shard_download.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 0dfff2f2d..0c66c2902 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -11,7 +11,7 @@ import aiohttp import aiofiles from urllib.parse import urljoin -from typing import Callable, Union, Tuple, Dict +from typing import Callable, Union, Tuple, Dict, List import time from datetime import timedelta import asyncio @@ -82,7 +82,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: url = urljoin(base_url, path) headers = await get_auth_headers() async with session.get(url, headers=headers) as r: - assert r.status == 200, r.status + assert r.status == 200, f"Failed to download {path} from {url}: {r.status}" length = int(r.headers.get('content-length', 0)) n_read = 0 async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: @@ -106,9 +106,13 @@ async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str] async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) return index_data.get("weight_map") -async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]: - weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname)) - return get_allow_patterns(weight_map, shard) +async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]: + try: + weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname)) + return get_allow_patterns(weight_map, shard) + except Exception as e: + if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}") + return ["*"] async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]: if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") From 21586063f69d3c75dcf3fac9dfa43a966dd1d0d0 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 01:35:33 +0000 Subject: [PATCH 06/16] use llama-3.2-1b in tinygrad test --- exo/inference/test_inference_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/exo/inference/test_inference_engine.py b/exo/inference/test_inference_engine.py index 75021cce9..956a11629 100644 --- a/exo/inference/test_inference_engine.py +++ b/exo/inference/test_inference_engine.py @@ -51,6 +51,4 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e import os from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0")) - asyncio.run( - test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3-8b", 32) - ) + asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32)) From b349e48b0db48be4b0247d25deec76f086bccde8 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 02:13:05 +0000 Subject: [PATCH 07/16] fix visual bug where frontend would show the full hf repo size, but in some cases that includes redundant files so we should use the model index in those cases too --- exo/api/chatgpt_api.py | 13 +++------ exo/download/download_progress.py | 5 +++- exo/download/new_shard_download.py | 35 ++++++++++++++++--------- exo/download/test_new_shard_download.py | 7 +---- exo/models.py | 8 +++--- 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index 905338406..af4459095 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -14,7 +14,7 @@ from exo.helpers import PrefixDict, shutdown, get_exo_images_dir from exo.inference.tokenizers import resolve_tokenizer from exo.orchestration import Node -from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name +from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name from typing import Callable, Optional from PIL import Image import numpy as np @@ -22,7 +22,7 @@ from io import BytesIO import platform from exo.download.download_progress import RepoProgressEvent -from exo.download.new_shard_download import ensure_downloads_dir, delete_model +from exo.download.new_shard_download import delete_model import tempfile from exo.apputil import create_animation_mp4 from collections import defaultdict @@ -278,7 +278,7 @@ async def handle_model_support(self, request): await response.prepare(request) downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname) for (path, d) in downloads: - model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } } + model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } } await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) await response.write(b"data: [DONE]\n\n") return response @@ -551,11 +551,6 @@ async def handle_delete_model(self, request): if DEBUG >= 2: traceback.print_exc() return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500) - except Exception as e: - print(f"Error in handle_delete_model: {str(e)}") - traceback.print_exc() - return web.json_response({"detail": f"Server error: {str(e)}"}, status=500) - async def handle_get_initial_models(self, request): model_data = {} for model_id in get_supported_models([[self.inference_engine_classname]]): @@ -606,7 +601,7 @@ async def handle_post_download(self, request): model_name = data.get("model") if not model_name: return web.json_response({"error": "model parameter is required"}, status=400) if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400) - shard = build_base_shard(model_name, self.inference_engine_classname) + shard = build_full_shard(model_name, self.inference_engine_classname) if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400) asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname)) diff --git a/exo/download/download_progress.py b/exo/download/download_progress.py index 340cc3217..c02733a3d 100644 --- a/exo/download/download_progress.py +++ b/exo/download/download_progress.py @@ -1,4 +1,5 @@ from typing import Dict, Callable, Coroutine, Any, Literal +from exo.inference.shard import Shard from dataclasses import dataclass from datetime import timedelta @@ -30,6 +31,7 @@ def from_dict(cls, data): @dataclass class RepoProgressEvent: + shard: Shard repo_id: str repo_revision: str completed_files: int @@ -44,7 +46,7 @@ class RepoProgressEvent: def to_dict(self): return { - "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes, + "shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes, "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(), "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()}, "status": self.status @@ -54,6 +56,7 @@ def to_dict(self): def from_dict(cls, data): if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta']) if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()} + if 'shard' in data: data['shard'] = Shard.from_dict(data['shard']) return cls(**data) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 0c66c2902..2fd6629b1 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -5,7 +5,7 @@ from exo.download.shard_download import ShardDownloader from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent from exo.helpers import AsyncCallbackSystem, DEBUG -from exo.models import get_supported_models, build_base_shard +from exo.models import get_supported_models, build_full_shard import os import aiofiles.os as aios import aiohttp @@ -18,10 +18,18 @@ import json import traceback import shutil +import tempfile def exo_home() -> Path: return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) +def exo_tmp() -> Path: + return Path(tempfile.gettempdir())/"exo" + +async def ensure_exo_tmp() -> Path: + await aios.makedirs(exo_tmp(), exist_ok=True) + return exo_tmp() + async def has_exo_home_read_access() -> bool: try: return await aios.access(exo_home(), os.R_OK) except OSError: return False @@ -90,17 +98,17 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: await aios.rename(temp_file.name, target_dir/path) return target_dir/path -def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: +def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: all_total_bytes = sum([p.total for p in file_progress.values()]) all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()]) elapsed_time = time.time() - all_start_time all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0 all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0) status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress" - return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status) + return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status) async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: - target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--") async with aiohttp.ClientSession() as session: index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) @@ -110,12 +118,13 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) try: weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname)) return get_allow_patterns(weight_map, shard) - except Exception as e: - if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}") + except: + if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}") + if DEBUG >= 1: traceback.print_exc() return ["*"] async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]: - if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") + if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") repo_id = get_repo(shard.model_id, inference_engine_classname) revision = "main" target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") @@ -124,8 +133,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr if repo_id is None: raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}") - allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None - if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}") + allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) + if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}") all_start_time = time.time() async with aiohttp.ClientSession() as session: @@ -137,8 +146,8 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): speed = curr_bytes / (time.time() - start_time) eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time) - on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time)) - if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") + on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)) + if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") for file in filtered_file_list: downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0 file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time()) @@ -148,7 +157,7 @@ async def download_with_semaphore(file): async with semaphore: await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list]) - final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time) + final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time) on_progress.trigger_all(shard, final_repo_progress) return target_dir, final_repo_progress @@ -208,7 +217,7 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: if DEBUG >= 2: print("Getting shard download status for", inference_engine_name) - downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True) + downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True) if DEBUG >= 6: print("Downloaded shards:", downloads) if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)]) return [d for d in downloads if not isinstance(d, Exception)] diff --git a/exo/download/test_new_shard_download.py b/exo/download/test_new_shard_download.py index d3303f87c..15357a0a2 100644 --- a/exo/download/test_new_shard_download.py +++ b/exo/download/test_new_shard_download.py @@ -1,6 +1,5 @@ from exo.download.new_shard_download import download_shard, NewShardDownloader from exo.inference.shard import Shard -from exo.models import get_model_id from pathlib import Path import asyncio @@ -11,10 +10,6 @@ async def test_new_shard_download(): download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine") print({k: v for k, v in download_statuses if v.downloaded_bytes > 0}) -async def test_helpers(): - print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit")) - print(get_model_id("fjsekljfd")) - if __name__ == "__main__": - asyncio.run(test_helpers()) + asyncio.run(test_new_shard_download()) diff --git a/exo/models.py b/exo/models.py index 9880c64a6..74865aca3 100644 --- a/exo/models.py +++ b/exo/models.py @@ -235,9 +235,6 @@ def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]: return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None) -def get_model_id(repo_id: str) -> Optional[str]: - return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None) - def get_pretty_name(model_id: str) -> Optional[str]: return pretty_name.get(model_id, None) @@ -248,6 +245,11 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional return None return Shard(model_id, 0, 0, n_layers) +def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]: + base_shard = build_base_shard(model_id, inference_engine_classname) + if base_shard is None: return None + return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers) + def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]: if not supported_inference_engine_lists: return list(model_cards.keys()) From d7ca9b7732e61a703edee4642487ed940c0b0350 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 02:20:22 +0000 Subject: [PATCH 08/16] show each node id in the tinychat topology viz --- exo/tinychat/index.css | 2 +- exo/tinychat/index.js | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/exo/tinychat/index.css b/exo/tinychat/index.css index d94a5a236..1971795af 100644 --- a/exo/tinychat/index.css +++ b/exo/tinychat/index.css @@ -843,4 +843,4 @@ main { font-size: 0.8em; color: var(--secondary-color-transparent); font-family: monospace; -} \ No newline at end of file +} diff --git a/exo/tinychat/index.js b/exo/tinychat/index.js index b0eea9175..c38876673 100644 --- a/exo/tinychat/index.js +++ b/exo/tinychat/index.js @@ -637,6 +637,9 @@ document.addEventListener("alpine:init", () => { const vizElement = this.$refs.topologyViz; vizElement.innerHTML = ''; // Clear existing visualization + // Helper function to truncate node ID + const truncateNodeId = (id) => id.substring(0, 8); + // Create nodes from object Object.entries(topologyData.nodes).forEach(([nodeId, node]) => { const nodeElement = document.createElement('div'); @@ -647,14 +650,14 @@ document.addEventListener("alpine:init", () => { const peerConnectionsHtml = peerConnections.map(peer => `
- To ${peer.to_id}: ${peer.description} + To ${truncateNodeId(peer.to_id)}: ${peer.description}
`).join(''); nodeElement.innerHTML = `
- ${node.model} + ${node.model} [${truncateNodeId(nodeId)}]
${node.chip} From 19a27c5bfd6083b3a9916cbb1c3eb200b5ce3f6b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 02:59:23 +0000 Subject: [PATCH 09/16] HF_HOME -> EXO_HOME --- .circleci/config.yml | 4 ++-- README.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index ba9c6983b..9f2b46040 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -27,7 +27,7 @@ commands: fi # Start first instance - HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <> \ + EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <> \ --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \ --chatgpt-api-response-timeout 900 --disable-tui > output1.log & PID1=$! @@ -35,7 +35,7 @@ commands: TAIL1=$! # Start second instance - HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <> \ + EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <> \ --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \ --chatgpt-api-response-timeout 900 --disable-tui > output2.log & PID2=$! diff --git a/README.md b/README.md index 25224f742..5e645abba 100644 --- a/README.md +++ b/README.md @@ -212,9 +212,9 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?" ### Model Storage -Models by default are stored in `~/.cache/huggingface/hub`. +Models by default are stored in `~/.cache/exo/downloads`. -You can set a different model storage location by setting the `HF_HOME` env var. +You can set a different model storage location by setting the `EXO_HOME` env var. ## Debugging From 295f41c5cc84cefc1af5ddd20d2b45ad09b363b0 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 05:03:35 +0000 Subject: [PATCH 10/16] increase bench job timeout to give enough time to download --- .github/workflows/bench_job.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/bench_job.yml b/.github/workflows/bench_job.yml index 842f2527e..5e089283f 100644 --- a/.github/workflows/bench_job.yml +++ b/.github/workflows/bench_job.yml @@ -128,6 +128,7 @@ jobs: --interface-type-filter="${{ inputs.network_interface }}" \ --disable-tui \ --max-generate-tokens 250 \ + --chatgpt-api-response-timeout 900 \ --chatgpt-api-port 52415 > output1.log 2>&1 & PID1=$! From 82f75d0ccf95cbccda89e72c66893dbcd49d407e Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 05:20:30 +0000 Subject: [PATCH 11/16] increase hf download http timeout 15 mins for large downloads --- exo/download/new_shard_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 2fd6629b1..b4dbea416 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -109,7 +109,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--") - async with aiohttp.ClientSession() as session: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=900, connect=10, sock_read=900, sock_connect=10)) as session: index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) return index_data.get("weight_map") From ae770db4f3184ed551b81138c77d590ac6ebeffe Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 05:37:50 +0000 Subject: [PATCH 12/16] increase download chunks to 1MB --- exo/download/new_shard_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index b4dbea416..cfc22d382 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -94,7 +94,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: length = int(r.headers.get('content-length', 0)) n_read = 0 async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: - while chunk := await r.content.read(16384): on_progress(n_read := n_read + await temp_file.write(chunk), length) + while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length) await aios.rename(temp_file.name, target_dir/path) return target_dir/path From 4748bb7dc74657a4a71d6c5f4edd51f3c538e63a Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 05:49:17 +0000 Subject: [PATCH 13/16] increase file download timeout to 30min --- exo/download/new_shard_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index cfc22d382..2093cbcc9 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -109,7 +109,7 @@ def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_prog async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--") - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=900, connect=10, sock_read=900, sock_connect=10)) as session: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session: index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) return index_data.get("weight_map") From 265586f7b45b46b3d5a45d12401930923ac6f517 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 06:05:40 +0000 Subject: [PATCH 14/16] set timeout on get too --- exo/download/new_shard_download.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 2093cbcc9..114353a16 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -89,7 +89,7 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" url = urljoin(base_url, path) headers = await get_auth_headers() - async with session.get(url, headers=headers) as r: + async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r: assert r.status == 200, f"Failed to download {path} from {url}: {r.status}" length = int(r.headers.get('content-length', 0)) n_read = 0 From 90e0e2761f4f202bf63541e9606a4fea3da97fbd Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 06:05:59 +0000 Subject: [PATCH 15/16] ignore not_started progress updates --- exo/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exo/main.py b/exo/main.py index 8f6c5e570..2dd75bf2e 100644 --- a/exo/main.py +++ b/exo/main.py @@ -221,6 +221,7 @@ def preemptively_start_download(request_id: str, opaque_status: str): def throttled_broadcast(shard: Shard, event: RepoProgressEvent): global last_broadcast_time current_time = time.time() + if event.status == "not_started": return if event.status == "complete" or current_time - last_broadcast_time >= 0.1: last_broadcast_time = current_time asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))) From 7c649085a1e244a1ed94d307e4d4f2e33241c6d3 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Mon, 27 Jan 2025 19:23:18 +0000 Subject: [PATCH 16/16] fix eta/speed for resuming an existing download, using the session downloaded bytes --- exo/download/new_shard_download.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py index 114353a16..be5a2d22a 100644 --- a/exo/download/new_shard_download.py +++ b/exo/download/new_shard_download.py @@ -101,11 +101,12 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: all_total_bytes = sum([p.total for p in file_progress.values()]) all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()]) + all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()]) elapsed_time = time.time() - all_start_time - all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0 + all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0 all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0) status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress" - return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status) + return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status) async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--") @@ -143,9 +144,10 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr file_progress: Dict[str, RepoFileProgressEvent] = {} def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time() - speed = curr_bytes / (time.time() - start_time) + downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes + speed = downloaded_this_session / (time.time() - start_time) eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) - file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time) + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time) on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)) if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") for file in filtered_file_list: