diff --git a/.circleci/config.yml b/.circleci/config.yml
index ba9c6983b..9f2b46040 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -27,7 +27,7 @@ commands:
           fi

           # Start first instance
-          HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
+          EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
             --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
             --chatgpt-api-response-timeout 900 --disable-tui > output1.log &
           PID1=$!
@@ -35,7 +35,7 @@ commands:
           TAIL1=$!

           # Start second instance
-          HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
+          EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
             --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
             --chatgpt-api-response-timeout 900 --disable-tui > output2.log &
           PID2=$!
diff --git a/.github/workflows/bench_job.yml b/.github/workflows/bench_job.yml
index 842f2527e..5e089283f 100644
--- a/.github/workflows/bench_job.yml
+++ b/.github/workflows/bench_job.yml
@@ -128,6 +128,7 @@ jobs:
             --interface-type-filter="${{ inputs.network_interface }}" \
             --disable-tui \
             --max-generate-tokens 250 \
+            --chatgpt-api-response-timeout 900 \
             --chatgpt-api-port 52415 > output1.log 2>&1 &
           PID1=$!
diff --git a/README.md b/README.md
index 25224f742..5e645abba 100644
--- a/README.md
+++ b/README.md
@@ -212,9 +212,9 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?"

 ### Model Storage

-Models by default are stored in `~/.cache/huggingface/hub`.
+Models by default are stored in `~/.cache/exo/downloads`.

-You can set a different model storage location by setting the `HF_HOME` env var.
+You can set a different model storage location by setting the `EXO_HOME` env var.

 ## Debugging
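The new default location documented above comes from `exo_home()` in `exo/download/new_shard_download.py` (full diff below). A minimal sketch of the resolution order; the `downloads_dir` helper name here is made up for illustration:

```python
import os
from pathlib import Path

# EXO_HOME, if set, overrides ~/.cache/exo; downloads always land in a
# "downloads" subdirectory (see ensure_downloads_dir in the new module).
def exo_home() -> Path:
  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))

def downloads_dir() -> Path:  # illustrative only, not part of the PR
  return exo_home()/"downloads"

if __name__ == "__main__":
  print(downloads_dir())  # e.g. /home/user/.cache/exo/downloads
```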
diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
index 30cec8c40..af4459095 100644
--- a/exo/api/chatgpt_api.py
+++ b/exo/api/chatgpt_api.py
@@ -11,30 +11,27 @@ import traceback
 import signal
 from exo import DEBUG, VERSION
-from exo.download.download_progress import RepoProgressEvent
 from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.orchestration import Node
-from exo.models import build_base_shard, model_cards, get_repo, pretty_name
+from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
 from typing import Callable, Optional
 from PIL import Image
 import numpy as np
 import base64
 from io import BytesIO
 import platform
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import delete_model
+import tempfile
+from exo.apputil import create_animation_mp4
+from collections import defaultdict

 if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
   import mlx.core as mx
 else:
   import numpy as mx

-import tempfile
-from exo.download.hf.hf_shard_download import HFShardDownloader
-import shutil
-from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
-from exo.apputil import create_animation_mp4
-from collections import defaultdict
-

 class Message:
   def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
@@ -277,41 +274,12 @@ async def handle_healthcheck(self, request):

   async def handle_model_support(self, request):
     try:
-      response = web.StreamResponse(status=200, reason='OK', headers={
-        'Content-Type': 'text/event-stream',
-        'Cache-Control': 'no-cache',
-        'Connection': 'keep-alive',
-      })
+      response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
       await response.prepare(request)
-
-      async def process_model(model_name, pretty):
-        if model_name in model_cards:
-          model_info = model_cards[model_name]
-
-          if self.inference_engine_classname in model_info.get("repo", {}):
-            shard = build_base_shard(model_name, self.inference_engine_classname)
-            if shard:
-              downloader = HFShardDownloader(quick_check=True)
-              downloader.current_shard = shard
-              downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-              status = await downloader.get_shard_download_status()
-
-              download_percentage = status.get("overall") if status else None
-              total_size = status.get("total_size") if status else None
-              total_downloaded = status.get("total_downloaded") if status else False
-
-              model_data = {
-                model_name: {
-                  "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
-                  "total_downloaded": total_downloaded
-                }
-              }
-
-              await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
-
-      # Process all models in parallel
-      await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
-
+      downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
+      for (path, d) in downloads:
+        model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
+        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

       await response.write(b"data: [DONE]\n\n")
       return response
@@ -348,6 +316,7 @@ async def handle_get_download_progress(self, request):
     progress_data = {}
     for node_id, progress_event in self.node.node_download_progress.items():
       if isinstance(progress_event, RepoProgressEvent):
+        if progress_event.status != "in_progress": continue
         progress_data[node_id] = progress_event.to_dict()
       else:
         print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
@@ -574,46 +543,19 @@ def on_result(_request_id: str, result, is_finished: bool):
       return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)

   async def handle_delete_model(self, request):
+    model_id = request.match_info.get('model_name')
     try:
-      model_name = request.match_info.get('model_name')
-      if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
-
-      if not model_name or model_name not in model_cards:
-        return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
-
-      shard = build_base_shard(model_name, self.inference_engine_classname)
-      if not shard:
-        return web.json_response({"detail": "Could not build shard for model"}, status=400)
-
-      repo_id = get_repo(shard.model_id, self.inference_engine_classname)
-      if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
-
-      # Get the HF cache directory using the helper function
-      hf_home = get_hf_home()
-      cache_dir = get_repo_root(repo_id)
-
-      if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
-
-      if os.path.exists(cache_dir):
-        if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
-        try:
-          shutil.rmtree(cache_dir)
-          return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
-        except Exception as e:
-          return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
-      else:
-        return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
-
+      if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
+      else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
     except Exception as e:
-      print(f"Error in handle_delete_model: {str(e)}")
-      traceback.print_exc()
-      return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
+      if DEBUG >= 2: traceback.print_exc()
+      return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)

   async def handle_get_initial_models(self, request):
     model_data = {}
-    for model_name, pretty in pretty_name.items():
-      model_data[model_name] = {
-        "name": pretty,
+    for model_id in get_supported_models([[self.inference_engine_classname]]):
+      model_data[model_id] = {
+        "name": get_pretty_name(model_id),
         "downloaded": None,  # Initially unknown
         "download_percentage": None,  # Change from 0 to null
         "total_size": None,
@@ -659,7 +601,7 @@ async def handle_post_download(self, request):
       model_name = data.get("model")
       if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
       if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
-      shard = build_base_shard(model_name, self.inference_engine_classname)
+      shard = build_full_shard(model_name, self.inference_engine_classname)
       if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
       asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
diff --git a/exo/download/download_progress.py b/exo/download/download_progress.py
index 779e53287..c02733a3d 100644
--- a/exo/download/download_progress.py
+++ b/exo/download/download_progress.py
@@ -1,4 +1,5 @@
 from typing import Dict, Callable, Coroutine, Any, Literal
+from exo.inference.shard import Shard
 from dataclasses import dataclass
 from datetime import timedelta

@@ -14,11 +15,12 @@ class RepoFileProgressEvent:
   speed: int
   eta: timedelta
   status: Literal["not_started", "in_progress", "complete"]
+  start_time: float

   def to_dict(self):
     return {
       "repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
-      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
+      "total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
     }

   @classmethod
@@ -29,6 +31,7 @@ def from_dict(cls, data):

 @dataclass
 class RepoProgressEvent:
+  shard: Shard
   repo_id: str
   repo_revision: str
   completed_files: int
@@ -43,7 +46,7 @@ class RepoProgressEvent:

   def to_dict(self):
     return {
-      "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
+      "shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
       "downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
       "file_progress": {k: v.to_dict() for k, v in self.file_progress.items()}, "status": self.status
     }
@@ -53,6 +56,7 @@ def to_dict(self):
   def from_dict(cls, data):
     if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
     if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
+    if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])

     return cls(**data)
diff --git a/exo/download/hf/hf_helpers.py b/exo/download/hf/hf_helpers.py
index 8be12b957..01cd86508 100644
--- a/exo/download/hf/hf_helpers.py
+++ b/exo/download/hf/hf_helpers.py
@@ -1,36 +1,16 @@
 import aiofiles.os as aios
 from typing import Union
-import asyncio
-import aiohttp
-import json
 import os
-import sys
-import shutil
-from urllib.parse import urljoin
-from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
-from datetime import datetime, timedelta
+from typing import Callable, Optional, Dict, List, Union
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Generator, Iterable, TypeVar, TypedDict
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
-from exo.helpers import DEBUG, is_frozen
-from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
+from typing import Generator, Iterable, TypeVar
+from exo.helpers import DEBUG
 from exo.inference.shard import Shard
 import aiofiles

 T = TypeVar("T")

-async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
-  refs_dir = get_repo_root(repo_id)/"refs"
-  refs_file = refs_dir/revision
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
-      return snapshot_dir
-  return None
-
-
 def filter_repo_objects(
   items: Iterable[T],
   *,
@@ -48,14 +28,12 @@ def filter_repo_objects(
     ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]

   if key is None:
-
     def _identity(item: T) -> str:
       if isinstance(item, str):
         return item
       if isinstance(item, Path):
         return str(item)
       raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
-
     key = _identity

   for item in items:
@@ -66,22 +44,18 @@ def _identity(item: T) -> str:
       continue
     yield item

-
 def _add_wildcard_to_directories(pattern: str) -> str:
   if pattern[-1] == "/":
     return pattern + "*"
   return pattern

-
 def get_hf_endpoint() -> str:
   return os.environ.get('HF_ENDPOINT', "https://huggingface.co")

-
 def get_hf_home() -> Path:
   """Get the Hugging Face home directory."""
   return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))

-
 async def get_hf_token():
   """Retrieve the Hugging Face token from the user's HF_HOME directory."""
   token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
     return (await f.read()).strip()
   return None

-
 async def get_auth_headers():
   """Get authentication headers if a token is available."""
   token = await get_hf_token()
@@ -98,325 +71,6 @@ async def get_auth_headers():
     return {"Authorization": f"Bearer {token}"}
   return {}

-
-def get_repo_root(repo_id: str) -> Path:
-  """Get the root directory for a given repo ID in the Hugging Face cache."""
-  sanitized_repo_id = str(repo_id).replace("/", "--")
-  return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
-
-async def move_models_to_hf(seed_dir: Union[str, Path]):
-  """Move model in resources folder of app to .cache/huggingface/hub"""
-  source_dir = Path(seed_dir)
-  dest_dir = get_hf_home()/"hub"
-  await aios.makedirs(dest_dir, exist_ok=True)
-  for path in source_dir.iterdir():
-    if path.is_dir() and path.name.startswith("models--"):
-      dest_path = dest_dir / path.name
-      if await aios.path.exists(dest_path):
-        print('Skipping moving model to .cache directory')
-      else:
-        try:
-          await aios.rename(str(path), str(dest_path))
-        except Exception as e:
-          print(f'Error moving model to .cache: {e}')
-
-
-async def fetch_file_list(session, repo_id, revision, path=""):
-  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
-  url = f"{api_url}/{path}" if path else api_url
-
-  headers = await get_auth_headers()
-  async with session.get(url, headers=headers) as response:
-    if response.status == 200:
-      data = await response.json()
-      files = []
-      for item in data:
-        if item["type"] == "file":
-          files.append({"path": item["path"], "size": item["size"]})
-        elif item["type"] == "directory":
-          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
-          files.extend(subfiles)
-      return files
-    else:
-      raise Exception(f"Failed to fetch file list: {response.status}")
-
-
-@retry(
-  stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
-)
-async def download_file(
-  session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
-):
-  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-  url = urljoin(base_url, file_path)
-  local_path = os.path.join(save_directory, file_path)
-
-  await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
-
-  # Check if file already exists and get its size
-  local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
-
-  headers = await get_auth_headers()
-  if use_range_request:
-    headers["Range"] = f"bytes={local_file_size}-"
-
-  async with session.get(url, headers=headers) as response:
-    total_size = int(response.headers.get('Content-Length', 0))
-    downloaded_size = local_file_size
-    downloaded_this_session = 0
-    mode = 'ab' if use_range_request else 'wb'
-    percentage = await get_file_download_percentage(
-      session,
-      repo_id,
-      revision,
-      file_path,
-      Path(save_directory)
-    )
-
-    if percentage == 100:
-      if DEBUG >= 2: print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
-      return
-
-    if response.status == 200:
-      # File doesn't support range requests or we're not using them, start from beginning
-      mode = 'wb'
-      downloaded_size = 0
-    elif response.status == 206:
-      # Partial content, resume download
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    elif response.status == 416:
-      # Range not satisfiable, get the actual file size
-      content_range = response.headers.get('Content-Range', '')
-      try:
-        total_size = int(content_range.split('/')[-1])
-        if downloaded_size == total_size:
-          if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
-          if progress_callback:
-            await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-          return
-      except ValueError:
-        if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
-        return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
-    else:
-      raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
-
-    if downloaded_size == total_size:
-      print(f"File already downloaded: {file_path}")
-      if progress_callback:
-        await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
-      return
-
-    DOWNLOAD_CHUNK_SIZE = 32768
-    start_time = datetime.now()
-    async with aiofiles.open(local_path, mode) as f:
-      async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
-        await f.write(chunk)
-        downloaded_size += len(chunk)
-        downloaded_this_session += len(chunk)
-        if progress_callback and total_size:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
-          remaining_size = total_size - downloaded_size
-          eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
-          status = "in_progress" if downloaded_size < total_size else "complete"
-          if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
-          await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
-    if DEBUG >= 2: print(f"Downloaded: {file_path}")
-
-
-async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
-  repo_root = get_repo_root(repo_id)
-  refs_dir = repo_root/"refs"
-  refs_file = refs_dir/revision
-
-  # Check if we have a cached commit hash
-  if await aios.path.exists(refs_file):
-    async with aiofiles.open(refs_file, 'r') as f:
-      commit_hash = (await f.read()).strip()
-      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
-      return commit_hash
-
-  # Fetch the commit hash for the given revision
-  async with aiohttp.ClientSession() as session:
-    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
-    headers = await get_auth_headers()
-    async with session.get(api_url, headers=headers) as response:
-      if response.status != 200:
-        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
-      revision_info = await response.json()
-      commit_hash = revision_info['sha']
-
-  # Cache the commit hash
-  await aios.makedirs(refs_dir, exist_ok=True)
-  async with aiofiles.open(refs_file, 'w') as f:
-    await f.write(commit_hash)
-
-  return commit_hash
-
-
-async def download_repo_files(
-  repo_id: str,
-  revision: str = "main",
-  progress_callback: Optional[RepoProgressCallback] = None,
-  allow_patterns: Optional[Union[List[str], str]] = None,
-  ignore_patterns: Optional[Union[List[str], str]] = None,
-  max_parallel_downloads: int = 4
-) -> Path:
-  repo_root = get_repo_root(repo_id)
-  snapshots_dir = repo_root/"snapshots"
-  cachedreqs_dir = repo_root/"cachedreqs"
-
-  # Ensure directories exist
-  await aios.makedirs(snapshots_dir, exist_ok=True)
-  await aios.makedirs(cachedreqs_dir, exist_ok=True)
-
-  # Resolve revision to commit hash
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-
-  # Set up the snapshot directory
-  snapshot_dir = snapshots_dir/commit_hash
-  await aios.makedirs(snapshot_dir, exist_ok=True)
-
-  # Set up the cached file list directory
-  cached_file_list_dir = cachedreqs_dir/commit_hash
-  await aios.makedirs(cached_file_list_dir, exist_ok=True)
-  cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
-
-  async with aiohttp.ClientSession() as session:
-    # Check if we have a cached file list
-    if await aios.path.exists(cached_file_list_path):
-      async with aiofiles.open(cached_file_list_path, 'r') as f:
-        file_list = json.loads(await f.read())
-        if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
-    else:
-      file_list = await fetch_file_list(session, repo_id, revision)
-      # Cache the file list
-      async with aiofiles.open(cached_file_list_path, 'w') as f:
-        await f.write(json.dumps(file_list))
-        if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
-
-    model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
-    if model_index_exists:
-      allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
-
-    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
-    total_files = len(filtered_file_list)
-    total_bytes = sum(file["size"] for file in filtered_file_list)
-    file_progress: Dict[str, RepoFileProgressEvent] = {
-      file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
-      for file in filtered_file_list
-    }
-    start_time = datetime.now()
-
-    async def download_with_progress(file_info, progress_state):
-      local_path = snapshot_dir/file_info["path"]
-      if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
-        if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
-        progress_state['completed_files'] += 1
-        progress_state['downloaded_bytes'] += file_info["size"]
-        file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-        return
-
-      async def file_progress_callback(event: RepoFileProgressEvent):
-        progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
-        progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
-        file_progress[event.file_path] = event
-        if progress_callback:
-          elapsed_time = (datetime.now() - start_time).total_seconds()
-          overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-          remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-          overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-          status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
-          await progress_callback(
-            RepoProgressEvent(
-              repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-              overall_eta, file_progress, status
-            )
-          )
-
-      await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
-      progress_state['completed_files'] += 1
-      file_progress[
-        file_info["path"]
-      ] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
-      if progress_callback:
-        elapsed_time = (datetime.now() - start_time).total_seconds()
-        overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
-        remaining_bytes = total_bytes - progress_state['downloaded_bytes']
-        overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
-        status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
-        await progress_callback(
-          RepoProgressEvent(
-            repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
-            overall_eta, file_progress, status
-          )
-        )
-
-    progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
-
-    semaphore = asyncio.Semaphore(max_parallel_downloads)
-
-    async def download_with_semaphore(file_info):
-      async with semaphore:
-        await download_with_progress(file_info, progress_state)
-
-    tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
-    await asyncio.gather(*tasks)
-
-  return snapshot_dir
-
-
-async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
-  """
-  Retrieve the weight map from the model.safetensors.index.json file.
-
-  Args:
-    repo_id (str): The Hugging Face repository ID.
-    revision (str): The revision of the repository to use.
-
-  Returns:
-    Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
-  """
-
-  # Download the index file
-  await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
-
-  # Check if the file exists
-  repo_root = get_repo_root(repo_id)
-  commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
-  snapshot_dir = repo_root/"snapshots"/commit_hash
-  index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
-
-  if index_file:
-    index_file_path = snapshot_dir/index_file
-    if await aios.path.exists(index_file_path):
-      async with aiofiles.open(index_file_path, 'r') as f:
-        index_data = json.loads(await f.read())
-        return index_data.get("weight_map")
-
-  return None
-
-
 def extract_layer_num(tensor_name: str) -> Optional[int]:
   # This is a simple example and might need to be adjusted based on the actual naming convention
   parts = tensor_name.split('.')
@@ -425,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
       return int(part)
   return None

-
 def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
   default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
   shard_specific_patterns = set()
@@ -443,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
     shard_specific_patterns = set(["*.safetensors"])
   if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
   return list(default_patterns | shard_specific_patterns)
-
-async def get_file_download_percentage(
-  session: aiohttp.ClientSession,
-  repo_id: str,
-  revision: str,
-  file_path: str,
-  snapshot_dir: Path,
-) -> float:
-  """
-  Calculate the download percentage for a file by comparing local and remote sizes.
-  """
-  try:
-    local_path = snapshot_dir / file_path
-    if not await aios.path.exists(local_path):
-      return 0
-
-    # Get local file size first
-    local_size = await aios.path.getsize(local_path)
-    if local_size == 0:
-      return 0
-
-    # Check remote size
-    base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
-    url = urljoin(base_url, file_path)
-    headers = await get_auth_headers()
-
-    # Use HEAD request with redirect following for all files
-    async with session.head(url, headers=headers, allow_redirects=True) as response:
-      if response.status != 200:
-        if DEBUG >= 2:
-          print(f"Failed to get remote file info for {file_path}: {response.status}")
-        return 0
-
-      remote_size = int(response.headers.get('Content-Length', 0))
-
-      if remote_size == 0:
-        if DEBUG >= 2:
-          print(f"Remote size is 0 for {file_path}")
-        return 0
-
-      # Only return 100% if sizes match exactly
-      if local_size == remote_size:
-        return 100.0
-
-      # Calculate percentage based on sizes
-      return (local_size / remote_size) * 100 if remote_size > 0 else 0
-
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Error checking file download status for {file_path}: {e}")
-    return 0
-
-async def has_hf_home_read_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.R_OK)
-  except OSError: return False
-
-async def has_hf_home_write_access() -> bool:
-  hf_home = get_hf_home()
-  try: return await aios.access(hf_home, os.W_OK)
-  except OSError: return False
-
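Note that `filter_repo_objects`, `extract_layer_num`, and `get_allow_patterns` survive the cleanup above; the new downloader reuses them to fetch only shard-relevant files. A rough usage sketch, where the weight map, file list, and layer split are invented for illustration:

```python
from exo.download.hf.hf_helpers import filter_repo_objects, get_allow_patterns
from exo.inference.shard import Shard

# Invented two-file weight map: layers 0-7 in the first safetensors file,
# layers 8-15 in the second.
weight_map = {}
for i in range(16):
  weight_map[f"model.layers.{i}.input_layernorm.weight"] = (
    "model-00001-of-00002.safetensors" if i < 8 else "model-00002-of-00002.safetensors"
  )
file_list = [
  {"path": "config.json", "size": 700},
  {"path": "model-00001-of-00002.safetensors", "size": 5_000_000_000},
  {"path": "model-00002-of-00002.safetensors", "size": 5_000_000_000},
]

# A shard covering only the first half of the layers should match the first
# weight file plus the default *.json/tokenizer patterns, skipping the second.
shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=7, n_layers=16)
patterns = get_allow_patterns(weight_map, shard)
print([f["path"] for f in filter_repo_objects(file_list, allow_patterns=patterns, key=lambda x: x["path"])])
```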
- """ - - # Download the index file - await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json") - - # Check if the file exists - repo_root = get_repo_root(repo_id) - commit_hash = await resolve_revision_to_commit_hash(repo_id, revision) - snapshot_dir = repo_root/"snapshots"/commit_hash - index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None) - - if index_file: - index_file_path = snapshot_dir/index_file - if await aios.path.exists(index_file_path): - async with aiofiles.open(index_file_path, 'r') as f: - index_data = json.loads(await f.read()) - return index_data.get("weight_map") - - return None - - def extract_layer_num(tensor_name: str) -> Optional[int]: # This is a simple example and might need to be adjusted based on the actual naming convention parts = tensor_name.split('.') @@ -425,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]: return int(part) return None - def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"]) shard_specific_patterns = set() @@ -443,65 +96,3 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]: shard_specific_patterns = set(["*.safetensors"]) if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}") return list(default_patterns | shard_specific_patterns) - -async def get_file_download_percentage( - session: aiohttp.ClientSession, - repo_id: str, - revision: str, - file_path: str, - snapshot_dir: Path, -) -> float: - """ - Calculate the download percentage for a file by comparing local and remote sizes. - """ - try: - local_path = snapshot_dir / file_path - if not await aios.path.exists(local_path): - return 0 - - # Get local file size first - local_size = await aios.path.getsize(local_path) - if local_size == 0: - return 0 - - # Check remote size - base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" - url = urljoin(base_url, file_path) - headers = await get_auth_headers() - - # Use HEAD request with redirect following for all files - async with session.head(url, headers=headers, allow_redirects=True) as response: - if response.status != 200: - if DEBUG >= 2: - print(f"Failed to get remote file info for {file_path}: {response.status}") - return 0 - - remote_size = int(response.headers.get('Content-Length', 0)) - - if remote_size == 0: - if DEBUG >= 2: - print(f"Remote size is 0 for {file_path}") - return 0 - - # Only return 100% if sizes match exactly - if local_size == remote_size: - return 100.0 - - # Calculate percentage based on sizes - return (local_size / remote_size) * 100 if remote_size > 0 else 0 - - except Exception as e: - if DEBUG >= 2: - print(f"Error checking file download status for {file_path}: {e}") - return 0 - -async def has_hf_home_read_access() -> bool: - hf_home = get_hf_home() - try: return await aios.access(hf_home, os.R_OK) - except OSError: return False - -async def has_hf_home_write_access() -> bool: - hf_home = get_hf_home() - try: return await aios.access(hf_home, os.W_OK) - except OSError: return False - diff --git a/exo/download/hf/hf_shard_download.py b/exo/download/hf/hf_shard_download.py deleted file mode 100644 index 705bab843..000000000 --- a/exo/download/hf/hf_shard_download.py +++ /dev/null @@ -1,172 +0,0 @@ -import asyncio -import traceback -from pathlib import Path -from typing import Dict, List, Tuple, 
Optional, Union -from exo.inference.shard import Shard -from exo.download.shard_download import ShardDownloader -from exo.download.download_progress import RepoProgressEvent -from exo.download.hf.hf_helpers import ( - download_repo_files, RepoProgressEvent, get_weight_map, - get_allow_patterns, get_repo_root, fetch_file_list, - get_local_snapshot_dir, get_file_download_percentage, - filter_repo_objects -) -from exo.helpers import AsyncCallbackSystem, DEBUG -from exo.models import model_cards, get_repo -import aiohttp -from aiofiles import os as aios - - -class HFShardDownloader(ShardDownloader): - def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4): - self.quick_check = quick_check - self.max_parallel_downloads = max_parallel_downloads - self.active_downloads: Dict[Shard, asyncio.Task] = {} - self.completed_downloads: Dict[Shard, Path] = {} - self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() - self.current_shard: Optional[Shard] = None - self.current_repo_id: Optional[str] = None - self.revision: str = "main" - - async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: - self.current_shard = shard - self.current_repo_id = get_repo(shard.model_id, inference_engine_name) - repo_name = get_repo(shard.model_id, inference_engine_name) - if shard in self.completed_downloads: - return self.completed_downloads[shard] - if self.quick_check: - repo_root = get_repo_root(repo_name) - snapshots_dir = repo_root/"snapshots" - if snapshots_dir.exists(): - visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')] - if visible_dirs: - most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime) - return most_recent_dir - - # If a download on this shard is already in progress, keep that one - for active_shard in self.active_downloads: - if active_shard == shard: - if DEBUG >= 2: print(f"Download already in progress for {shard}. 
Keeping that one.") - return await self.active_downloads[shard] - - # Cancel any downloads for this model_id on a different shard - existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id] - for active_shard in existing_active_shards: - if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})") - task = self.active_downloads[active_shard] - task.cancel() - try: - await task - except asyncio.CancelledError: - pass # This is expected when cancelling a task - except Exception as e: - if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}") - traceback.print_exc() - self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id} - - # Start new download - download_task = asyncio.create_task(self._download_shard(shard, repo_name)) - self.active_downloads[shard] = download_task - try: - path = await download_task - self.completed_downloads[shard] = path - return path - finally: - # Ensure the task is removed even if an exception occurs - print(f"Removing download task for {shard}: {shard in self.active_downloads}") - if shard in self.active_downloads: - self.active_downloads.pop(shard) - - async def _download_shard(self, shard: Shard, repo_name: str) -> Path: - async def wrapped_progress_callback(event: RepoProgressEvent): - self._on_progress.trigger_all(shard, event) - - weight_map = await get_weight_map(repo_name) - allow_patterns = get_allow_patterns(weight_map, shard) - - return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads) - - @property - def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: - return self._on_progress - - async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]: - if not self.current_shard or not self.current_repo_id: - if DEBUG >= 2: - print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}") - return None - - try: - # If no snapshot directory exists, return None - no need to check remote files - snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision) - if not snapshot_dir: - if DEBUG >= 2: - print(f"No snapshot directory found for {self.current_repo_id}") - return None - - if not await aios.path.exists(snapshot_dir/"model_index.json"): - # Get the weight map to know what files we need - weight_map = await get_weight_map(self.current_repo_id, self.revision) - if not weight_map: - if DEBUG >= 2: - print(f"No weight map found for {self.current_repo_id}") - return None - - # Get all files needed for this shard - patterns = get_allow_patterns(weight_map, self.current_shard) - else: - patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"] - - - # Check download status for all relevant files - status = {} - total_bytes = 0 - downloaded_bytes = 0 - - async with aiohttp.ClientSession() as session: - file_list = await fetch_file_list(session, self.current_repo_id, self.revision) - relevant_files = list( - filter_repo_objects( - file_list, allow_patterns=patterns, key=lambda x: x["path"])) - - for file in relevant_files: - file_size = file["size"] - total_bytes += file_size - - percentage = await get_file_download_percentage( - session, - self.current_repo_id, - self.revision, - file["path"], - snapshot_dir, - ) - status[file["path"]] = 
percentage - downloaded_bytes += (file_size * (percentage / 100)) - - # Add overall progress weighted by file size - if total_bytes > 0: - status["overall"] = (downloaded_bytes / total_bytes) * 100 - else: - status["overall"] = 0 - - # Add total size in bytes - status["total_size"] = total_bytes - if status["overall"] != 100: - status["total_downloaded"] = downloaded_bytes - - - if DEBUG >= 2: - print(f"Download calculation for {self.current_repo_id}:") - print(f"Total bytes: {total_bytes}") - print(f"Downloaded bytes: {downloaded_bytes}") - if DEBUG >= 3: - for file in relevant_files: - print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}") - - return status - - except Exception as e: - if DEBUG >= 3: - print(f"Error getting shard download status: {e}") - traceback.print_exc() - return None diff --git a/exo/download/new_shard_download.py b/exo/download/new_shard_download.py new file mode 100644 index 000000000..be5a2d22a --- /dev/null +++ b/exo/download/new_shard_download.py @@ -0,0 +1,226 @@ +from exo.inference.shard import Shard +from exo.models import get_repo +from pathlib import Path +from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns +from exo.download.shard_download import ShardDownloader +from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent +from exo.helpers import AsyncCallbackSystem, DEBUG +from exo.models import get_supported_models, build_full_shard +import os +import aiofiles.os as aios +import aiohttp +import aiofiles +from urllib.parse import urljoin +from typing import Callable, Union, Tuple, Dict, List +import time +from datetime import timedelta +import asyncio +import json +import traceback +import shutil +import tempfile + +def exo_home() -> Path: + return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo")) + +def exo_tmp() -> Path: + return Path(tempfile.gettempdir())/"exo" + +async def ensure_exo_tmp() -> Path: + await aios.makedirs(exo_tmp(), exist_ok=True) + return exo_tmp() + +async def has_exo_home_read_access() -> bool: + try: return await aios.access(exo_home(), os.R_OK) + except OSError: return False + +async def has_exo_home_write_access() -> bool: + try: return await aios.access(exo_home(), os.W_OK) + except OSError: return False + +async def ensure_downloads_dir() -> Path: + downloads_dir = exo_home()/"downloads" + await aios.makedirs(downloads_dir, exist_ok=True) + return downloads_dir + +async def delete_model(model_id: str, inference_engine_name: str) -> bool: + repo_id = get_repo(model_id, inference_engine_name) + model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + if not await aios.path.exists(model_dir): return False + await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False) + return True + +async def seed_models(seed_dir: Union[str, Path]): + """Move model in resources folder of app to .cache/huggingface/hub""" + source_dir = Path(seed_dir) + dest_dir = await ensure_downloads_dir() + for path in source_dir.iterdir(): + if path.is_dir() and path.name.startswith("models--"): + dest_path = dest_dir/path.name + if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory') + else: + try: await aios.rename(str(path), str(dest_path)) + except: + print(f"Error seeding model {path} to {dest_path}") + traceback.print_exc() + +async def fetch_file_list(session, repo_id, revision, path=""): + api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}" + url = 
f"{api_url}/{path}" if path else api_url + + headers = await get_auth_headers() + async with session.get(url, headers=headers) as response: + if response.status == 200: + data = await response.json() + files = [] + for item in data: + if item["type"] == "file": + files.append({"path": item["path"], "size": item["size"]}) + elif item["type"] == "directory": + subfiles = await fetch_file_list(session, repo_id, revision, item["path"]) + files.extend(subfiles) + return files + else: + raise Exception(f"Failed to fetch file list: {response.status}") + +async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path: + if (target_dir/path).exists(): return target_dir/path + await aios.makedirs((target_dir/path).parent, exist_ok=True) + base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/" + url = urljoin(base_url, path) + headers = await get_auth_headers() + async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r: + assert r.status == 200, f"Failed to download {path} from {url}: {r.status}" + length = int(r.headers.get('content-length', 0)) + n_read = 0 + async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file: + while chunk := await r.content.read(1024 * 1024): on_progress(n_read := n_read + await temp_file.write(chunk), length) + await aios.rename(temp_file.name, target_dir/path) + return target_dir/path + +def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent: + all_total_bytes = sum([p.total for p in file_progress.values()]) + all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()]) + all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()]) + elapsed_time = time.time() - all_start_time + all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0 + all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0) + status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress" + return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status) + +async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]: + target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--") + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=10, sock_read=1800, sock_connect=10)) as session: + index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir) + async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read()) + return index_data.get("weight_map") + +async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]: + try: + weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname)) + return get_allow_patterns(weight_map, shard) + except: + if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}") + if DEBUG >= 1: 
traceback.print_exc() + return ["*"] + +async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]: + if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}") + repo_id = get_repo(shard.model_id, inference_engine_classname) + revision = "main" + target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--") + if not skip_download: await aios.makedirs(target_dir, exist_ok=True) + + if repo_id is None: + raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}") + + allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) + if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}") + + all_start_time = time.time() + async with aiohttp.ClientSession() as session: + file_list = await fetch_file_list(session, repo_id, revision) + filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"])) + file_progress: Dict[str, RepoFileProgressEvent] = {} + def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int): + start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time() + downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes + speed = downloaded_this_session / (time.time() - start_time) + eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "in_progress", start_time) + on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)) + if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}") + for file in filtered_file_list: + downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0 + file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time()) + + semaphore = asyncio.Semaphore(max_parallel_downloads) + async def download_with_semaphore(file): + async with semaphore: + await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes)) + if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list]) + final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time) + on_progress.trigger_all(shard, final_repo_progress) + return target_dir, final_repo_progress + +def new_shard_downloader() -> ShardDownloader: + return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader())) + +class SingletonShardDownloader(ShardDownloader): + def __init__(self, shard_downloader: ShardDownloader): + self.shard_downloader = shard_downloader + self.active_downloads: Dict[Shard, asyncio.Task] = {} + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, 
RepoProgressEvent]]: + return self.shard_downloader.on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name)) + try: return await self.active_downloads[shard] + finally: + if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard] + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + return await self.shard_downloader.get_shard_download_status(inference_engine_name) + +class CachedShardDownloader(ShardDownloader): + def __init__(self, shard_downloader: ShardDownloader): + self.shard_downloader = shard_downloader + self.cache: Dict[tuple[str, Shard], Path] = {} + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self.shard_downloader.on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + if (inference_engine_name, shard) in self.cache: + if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}") + return self.cache[(inference_engine_name, shard)] + if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}") + target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name) + self.cache[(inference_engine_name, shard)] = target_dir + return target_dir + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + return await self.shard_downloader.get_shard_download_status(inference_engine_name) + +class NewShardDownloader(ShardDownloader): + def __init__(self): + self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]() + + @property + def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]: + return self._on_progress + + async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path: + target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress) + return target_dir + + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: + if DEBUG >= 2: print("Getting shard download status for", inference_engine_name) + downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True) + if DEBUG >= 6: print("Downloaded shards:", downloads) + if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)]) + return [d for d in downloads if not isinstance(d, Exception)] + diff --git a/exo/download/shard_download.py b/exo/download/shard_download.py index 89d3582e4..21ea412c2 100644 --- a/exo/download/shard_download.py +++ b/exo/download/shard_download.py @@ -27,7 +27,7 @@ def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent pass @abstractmethod - async def get_shard_download_status(self) -> Optional[Dict[str, float]]: + async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]: """Get the download status of shards. 
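`new_shard_downloader()` is the composition point: `SingletonShardDownloader` deduplicates concurrent `ensure_shard` calls for the same shard, and `CachedShardDownloader` memoizes completed `(engine, shard)` lookups. A minimal usage sketch, with the model and engine names borrowed from the test added below:

```python
import asyncio
from exo.download.new_shard_download import new_shard_downloader
from exo.inference.shard import Shard

async def main():
  downloader = new_shard_downloader()
  # Progress events now carry the shard (RepoProgressEvent gained a `shard` field).
  downloader.on_progress.register("cli").on_next(
    lambda shard, event: print(f"{event.status}: {event.downloaded_bytes}/{event.total_bytes} bytes")
  )
  shard = Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=15, n_layers=16)
  path = await downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine")
  print("model files in", path)

if __name__ == "__main__":
  asyncio.run(main())
```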
diff --git a/exo/download/shard_download.py b/exo/download/shard_download.py
index 89d3582e4..21ea412c2 100644
--- a/exo/download/shard_download.py
+++ b/exo/download/shard_download.py
@@ -27,7 +27,7 @@ def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent
     pass

   @abstractmethod
-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
     """Get the download status of shards.

     Returns:
@@ -45,5 +45,5 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
   def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
     return AsyncCallbackSystem()

-  async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
+  async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
     return None
diff --git a/exo/download/test_new_shard_download.py b/exo/download/test_new_shard_download.py
new file mode 100644
index 000000000..15357a0a2
--- /dev/null
+++ b/exo/download/test_new_shard_download.py
@@ -0,0 +1,15 @@
+from exo.download.new_shard_download import download_shard, NewShardDownloader
+from exo.inference.shard import Shard
+from pathlib import Path
+import asyncio
+
+async def test_new_shard_download():
+  shard_downloader = NewShardDownloader()
+  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
+  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
+  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
+  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})
+
+if __name__ == "__main__":
+  asyncio.run(test_new_shard_download())
+
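For reference, the rate and ETA math shared by `download_shard` and `calculate_repo_progress` reduces to a session-relative estimate. A standalone sketch with invented numbers:

```python
import time
from datetime import timedelta

# Invented figures: 300 MB of a 1 GB repo fetched in the 30 s since start.
all_start_time = time.time() - 30
downloaded, downloaded_this_session, total = 300_000_000, 300_000_000, 1_000_000_000

elapsed = time.time() - all_start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0  # ~10 MB/s
eta = timedelta(seconds=(total - downloaded) / speed) if speed > 0 else timedelta(0)  # ~70 s
status = "not_started" if downloaded == 0 else "complete" if downloaded == total else "in_progress"
print(f"{speed / 1e6:.1f} MB/s, ETA {eta}, {status}")
```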
diff --git a/exo/inference/mlx/test_non_blocking.py b/exo/inference/mlx/test_non_blocking.py
index 64eedfdde..4eb2e9fb3 100644
--- a/exo/inference/mlx/test_non_blocking.py
+++ b/exo/inference/mlx/test_non_blocking.py
@@ -2,7 +2,7 @@
 import time
 import numpy as np
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
 from exo.models import build_base_shard
 from collections import deque
@@ -10,7 +10,7 @@
 async def test_non_blocking():
   # Setup
-  shard_downloader = HFShardDownloader()
+  shard_downloader = NewShardDownloader()
   engine = MLXDynamicShardInferenceEngine(shard_downloader)
   _shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
   shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
diff --git a/exo/inference/test_inference_engine.py b/exo/inference/test_inference_engine.py
index 7a69eafce..956a11629 100644
--- a/exo/inference/test_inference_engine.py
+++ b/exo/inference/test_inference_engine.py
@@ -1,6 +1,6 @@
 from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
-from exo.download.hf.hf_shard_download import HFShardDownloader
 from exo.inference.inference_engine import InferenceEngine
+from exo.download.new_shard_download import NewShardDownloader
 from exo.inference.shard import Shard
 from exo.helpers import DEBUG
 import os
@@ -44,13 +44,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e

   assert np.array_equal(next_resp_full, resp4)

-
-asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
+asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))

 if os.getenv("RUN_TINYGRAD", default="0") == "1":
   import tinygrad
   import os
   from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
   tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
-  asyncio.run(
-    test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
-  )
+  asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))
diff --git a/exo/inference/tokenizers.py b/exo/inference/tokenizers.py
index 4dccaf664..865765d5f 100644
--- a/exo/inference/tokenizers.py
+++ b/exo/inference/tokenizers.py
@@ -1,12 +1,11 @@
 import traceback
-from aiofiles import os as aios
 from os import PathLike
-from pathlib import Path
+from aiofiles import os as aios
 from typing import Union
 from transformers import AutoTokenizer, AutoProcessor
 import numpy as np
-from exo.download.hf.hf_helpers import get_local_snapshot_dir
 from exo.helpers import DEBUG
+from exo.download.new_shard_download import ensure_downloads_dir


 class DummyTokenizer:
@@ -24,25 +23,25 @@ def decode(self, tokens):
     return "dummy" * len(tokens)


-async def resolve_tokenizer(model_id: str):
-  if model_id == "dummy":
+async def resolve_tokenizer(repo_id: Union[str, PathLike]):
+  if repo_id == "dummy":
     return DummyTokenizer()
-  local_path = await get_local_snapshot_dir(model_id)
+  local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
   if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
   try:
     if local_path and await aios.path.exists(local_path):
-      if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
+      if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
       return await _resolve_tokenizer(local_path)
   except:
-    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
+    if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
     if DEBUG >= 5: traceback.print_exc()
-  return await _resolve_tokenizer(model_id)
+  return await _resolve_tokenizer(repo_id)


-async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
+async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
   try:
-    if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
-    processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
+    processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
     if not hasattr(processor, 'eos_token_id'):
       processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
     if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
       processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
     return processor
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())

   try:
-    if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
-    return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
+    if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
+    return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
   except Exception as e:
-    if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
+    if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
     if DEBUG >= 4: print(traceback.format_exc())
-    raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
+    raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")
diff --git a/exo/main.py b/exo/main.py
index d58224ca7..2dd75bf2e 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -23,19 +23,18 @@
 from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
-from exo.download.hf.hf_shard_download import HFShardDownloader
+from exo.download.shard_download import ShardDownloader, NoopShardDownloader
+from exo.download.download_progress import RepoProgressEvent
+from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, exo_home, seed_models
 from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
 from exo.inference.shard import Shard
 from exo.inference.inference_engine import get_inference_engine, InferenceEngine
 from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
 import uvloop
 from contextlib import asynccontextmanager
 import concurrent.futures
-import socket
 import resource
 import psutil

@@ -117,8 +116,7 @@ def configure_uvloop():
 system_info = get_system_info()
 print(f"Detected system: {system_info}")

-shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
-                                                      max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
+shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
 inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
 print(f"Inference engine name after selection: {inference_engine_name}")

@@ -175,10 +173,10 @@ def configure_uvloop():
   None,
   inference_engine,
   discovery,
+  shard_downloader,
   partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
   max_generate_tokens=args.max_generate_tokens,
   topology_viz=topology_viz,
-  shard_downloader=shard_downloader,
   default_sample_temperature=args.default_temp
 )
 server = GRPCServer(node, args.node_host, args.node_port)
@@ -223,6 +221,7 @@ def preemptively_start_download(request_id: str, opaque_status: str):
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
   global last_broadcast_time
   current_time = time.time()
+  if
event.status == "not_started": return if event.status == "complete" or current_time - last_broadcast_time >= 0.1: last_broadcast_time = current_time asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()}))) @@ -322,13 +321,13 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n async def main(): loop = asyncio.get_running_loop() - # Check HuggingFace directory permissions - hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access() - if DEBUG >= 1: print(f"Model storage directory: {hf_home}") + # Check exo directory permissions + home, has_read, has_write = exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access() + if DEBUG >= 1: print(f"exo home directory: {home}") print(f"{has_read=}, {has_write=}") if not has_read or not has_write: print(f""" - WARNING: Limited permissions for model storage directory: {hf_home}. + WARNING: Limited permissions for exo home directory: {home}. This may prevent model downloads from working correctly. {"❌ No read access" if not has_read else ""} {"❌ No write access" if not has_write else ""} @@ -337,9 +336,9 @@ async def main(): if not args.models_seed_dir is None: try: models_seed_dir = clean_path(args.models_seed_dir) - await move_models_to_hf(models_seed_dir) + await seed_models(models_seed_dir) except Exception as e: - print(f"Error moving models to .cache/huggingface: {e}") + print(f"Error seeding models: {e}") def restore_cursor(): if platform.system() != "Windows": diff --git a/exo/models.py b/exo/models.py index 29019e91b..74865aca3 100644 --- a/exo/models.py +++ b/exo/models.py @@ -175,8 +175,11 @@ "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite", "deepseek-coder-v2.5": "Deepseek Coder V2.5", "deepseek-v3": "Deepseek V3", + "deepseek-v3-3bit": "Deepseek V3 (3-bit)", "deepseek-r1": "Deepseek R1", + "deepseek-r1-3bit": "Deepseek R1 (3-bit)", "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)", + "qwen-2.5-0.5b": "Qwen 2.5 0.5B", "qwen-2.5-1.5b": "Qwen 2.5 1.5B", "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B", "qwen-2.5-3b": "Qwen 2.5 3B", @@ -232,6 +235,9 @@ def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]: return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None) +def get_pretty_name(model_id: str) -> Optional[str]: + return pretty_name.get(model_id, None) + def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]: repo = get_repo(model_id, inference_engine_classname) n_layers = model_cards.get(model_id, {}).get("layers", 0) @@ -239,7 +245,12 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional return None return Shard(model_id, 0, 0, n_layers) -def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]: +def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]: + base_shard = build_base_shard(model_id, inference_engine_classname) + if base_shard is None: return None + return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers) + +def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]: if not supported_inference_engine_lists: return list(model_cards.keys()) diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 36c771b0b..cd60dd0fa 100644 --- a/exo/orchestration/node.py +++ 
b/exo/orchestration/node.py @@ -13,9 +13,9 @@ from exo import DEBUG from exo.helpers import AsyncCallbackSystem from exo.viz.topology_viz import TopologyViz -from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.download.download_progress import RepoProgressEvent from exo.inference.inference_engine import get_inference_engine, InferenceEngine -from exo.download.hf.hf_shard_download import HFShardDownloader +from exo.download.shard_download import ShardDownloader class Node: def __init__( @@ -24,16 +24,17 @@ def __init__( server: Server, inference_engine: InferenceEngine, discovery: Discovery, + shard_downloader: ShardDownloader, partitioning_strategy: PartitioningStrategy = None, max_generate_tokens: int = 1024, default_sample_temperature: float = 0.0, topology_viz: Optional[TopologyViz] = None, - shard_downloader: Optional[HFShardDownloader] = None, ): self.id = _id self.inference_engine = inference_engine self.server = server self.discovery = discovery + self.shard_downloader = shard_downloader self.partitioning_strategy = partitioning_strategy self.peers: List[PeerHandle] = {} self.topology: Topology = Topology() @@ -52,7 +53,6 @@ def __init__( self._on_opaque_status.register("node_status").on_next(self.on_node_status) self.node_download_progress: Dict[str, RepoProgressEvent] = {} self.topology_inference_engines_pool: List[List[str]] = [] - self.shard_downloader = shard_downloader self.outstanding_requests = {} async def start(self, wait_for_peers: int = 0) -> None: diff --git a/exo/orchestration/test_node.py b/exo/orchestration/test_node.py index 8bee4b2ba..6f992bbeb 100644 --- a/exo/orchestration/test_node.py +++ b/exo/orchestration/test_node.py @@ -5,7 +5,7 @@ from .node import Node from exo.networking.peer_handle import PeerHandle - +from exo.download.shard_download import NoopShardDownloader class TestNode(unittest.IsolatedAsyncioTestCase): def setUp(self): @@ -22,7 +22,7 @@ def setUp(self): mock_peer2.id.return_value = "peer2" self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2]) - self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery) + self.node = Node("test_node", self.mock_server, self.mock_inference_engine, self.mock_discovery, NoopShardDownloader()) async def asyncSetUp(self): await self.node.start() diff --git a/exo/tinychat/index.css b/exo/tinychat/index.css index d94a5a236..1971795af 100644 --- a/exo/tinychat/index.css +++ b/exo/tinychat/index.css @@ -843,4 +843,4 @@ main { font-size: 0.8em; color: var(--secondary-color-transparent); font-family: monospace; -} \ No newline at end of file +} diff --git a/exo/tinychat/index.js b/exo/tinychat/index.js index dff5fc97a..c38876673 100644 --- a/exo/tinychat/index.js +++ b/exo/tinychat/index.js @@ -75,12 +75,12 @@ document.addEventListener("alpine:init", () => { while (true) { try { await this.populateSelector(); - // Wait 5 seconds before next poll - await new Promise(resolve => setTimeout(resolve, 5000)); + // Wait 15 seconds before next poll + await new Promise(resolve => setTimeout(resolve, 15000)); } catch (error) { console.error('Model polling error:', error); // If there's an error, wait before retrying - await new Promise(resolve => setTimeout(resolve, 5000)); + await new Promise(resolve => setTimeout(resolve, 15000)); } } }, @@ -637,6 +637,9 @@ document.addEventListener("alpine:init", () => { const vizElement = this.$refs.topologyViz; vizElement.innerHTML = ''; // Clear existing visualization + 
// Helper function to truncate node ID + const truncateNodeId = (id) => id.substring(0, 8); + // Create nodes from object Object.entries(topologyData.nodes).forEach(([nodeId, node]) => { const nodeElement = document.createElement('div'); @@ -647,14 +650,14 @@ document.addEventListener("alpine:init", () => { const peerConnectionsHtml = peerConnections.map(peer => `
- To ${peer.to_id}: ${peer.description} + To ${truncateNodeId(peer.to_id)}: ${peer.description}
`).join(''); nodeElement.innerHTML = `
- ${node.model} + ${node.model} [${truncateNodeId(nodeId)}]
${node.chip} diff --git a/exo/viz/test_topology_viz.py b/exo/viz/test_topology_viz.py index e57de1ae3..a10ac43b7 100644 --- a/exo/viz/test_topology_viz.py +++ b/exo/viz/test_topology_viz.py @@ -5,7 +5,7 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from exo.topology.partitioning_strategy import Partition -from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent +from exo.download.download_progress import RepoProgressEvent def create_hf_repo_progress_event( diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 734fe69dd..0a9f15971 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -4,7 +4,7 @@ from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second from exo.topology.topology import Topology from exo.topology.partitioning_strategy import Partition -from exo.download.hf.hf_helpers import RepoProgressEvent +from exo.download.download_progress import RepoProgressEvent from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES from rich.console import Console, Group from rich.text import Text diff --git a/extra/download_hf.py b/extra/download_hf.py deleted file mode 100644 index 8cf712e55..000000000 --- a/extra/download_hf.py +++ /dev/null @@ -1,50 +0,0 @@ -import argparse -import asyncio -from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent - -DEFAULT_ALLOW_PATTERNS = [ - "*.json", - "*.py", - "tokenizer.model", - "*.tiktoken", - "*.txt", - "*.safetensors", -] -# Always ignore `.git` and `.cache/huggingface` folders in commits -DEFAULT_IGNORE_PATTERNS = [ - ".git", - ".git/*", - "*/.git", - "**/.git/**", - ".cache/huggingface", - ".cache/huggingface/*", - "*/.cache/huggingface", - "**/.cache/huggingface/**", -] - - -async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None): - async def progress_callback(event: RepoProgressEvent): - print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes") - print(f"Estimated time remaining: {event.overall_eta}") - print("File Progress:") - for file_path, progress in event.file_progress.items(): - status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status] - eta_str = str(progress.eta) - print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, " - f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}") - print("\n") - - await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.") - parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')") - parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)") - parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')") - parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')") - - args = parser.parse_args() - - asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns)) diff --git a/test/test_hf.py b/test/test_hf.py deleted file mode 100644 index 0477d132b..000000000 --- a/test/test_hf.py +++ /dev/null @@ -1,26 +0,0 @@ -import os -import sys 
- -# Add the project root to the Python path -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, project_root) - -import asyncio -from exo.download.hf.hf_helpers import get_weight_map - -async def test_get_weight_map(): - repo_ids = [ - "mlx-community/quantized-gemma-2b", - "mlx-community/Meta-Llama-3.1-8B-4bit", - "mlx-community/Meta-Llama-3.1-70B-4bit", - "mlx-community/Meta-Llama-3.1-405B-4bit", - ] - for repo_id in repo_ids: - weight_map = await get_weight_map(repo_id) - assert weight_map is not None, "Weight map should not be None" - assert isinstance(weight_map, dict), "Weight map should be a dictionary" - assert len(weight_map) > 0, "Weight map should not be empty" - print(f"OK: {repo_id}") - -if __name__ == "__main__": - asyncio.run(test_get_weight_map())
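With extra/download_hf.py and test/test_hf.py deleted, the tree no longer ships a standalone download CLI. If one is still wanted, a rough equivalent can be sketched on top of the new downloader. This is a hedged sketch, not a shipped script: the progress wiring mirrors exo/download/test_new_shard_download.py, build_full_shard comes from the exo/models.py hunk above, and the downloaded_bytes/total_bytes fields on the progress event are assumed to match the old RepoProgressEvent.

```python
# Hedged sketch: a minimal stand-in for the deleted extra/download_hf.py,
# built on the new downloader API shown in this diff.
import argparse
import asyncio
from exo.download.new_shard_download import NewShardDownloader
from exo.models import build_full_shard

async def main(model_id: str, engine_classname: str):
  downloader = NewShardDownloader()
  # Progress callback registration mirrors exo/download/test_new_shard_download.py;
  # event field names are assumptions.
  downloader.on_progress.register("cli").on_next(
    lambda shard, event: print(f"{shard.model_id}: {event.downloaded_bytes}/{event.total_bytes} bytes")
  )
  shard = build_full_shard(model_id, engine_classname)  # full-layer shard, added in exo/models.py above
  if shard is None:
    raise SystemExit(f"Unknown model or unsupported engine: {model_id} / {engine_classname}")
  path = await downloader.ensure_shard(shard, engine_classname)
  print(f"Files available under: {path}")

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Download a full model with the new shard downloader.")
  parser.add_argument("--model-id", required=True, help="exo model id, e.g. 'llama-3.2-1b'")
  parser.add_argument("--engine", default="MLXDynamicShardInferenceEngine")
  args = parser.parse_args()
  asyncio.run(main(args.model_id, args.engine))
```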