Skip to content

Commit

Permalink
fix visual bug where frontend would show the full hf repo size, but i…
Browse files Browse the repository at this point in the history
…n some cases that includes redundant files so we should use the model index in those cases too
  • Loading branch information
AlexCheema committed Jan 27, 2025
1 parent 2158606 commit b349e48
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 32 deletions.
13 changes: 4 additions & 9 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import ensure_downloads_dir, delete_model
from exo.download.new_shard_download import delete_model
import tempfile
from exo.apputil import create_animation_mp4
from collections import defaultdict
Expand Down Expand Up @@ -278,7 +278,7 @@ async def handle_model_support(self, request):
await response.prepare(request)
downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
for (path, d) in downloads:
model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
model_data = { d.shard.model_id: { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response
Expand Down Expand Up @@ -551,11 +551,6 @@ async def handle_delete_model(self, request):
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)

except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)

async def handle_get_initial_models(self, request):
model_data = {}
for model_id in get_supported_models([[self.inference_engine_classname]]):
Expand Down Expand Up @@ -606,7 +601,7 @@ async def handle_post_download(self, request):
model_name = data.get("model")
if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
shard = build_full_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

Expand Down
5 changes: 4 additions & 1 deletion exo/download/download_progress.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Callable, Coroutine, Any, Literal
from exo.inference.shard import Shard
from dataclasses import dataclass
from datetime import timedelta

Expand Down Expand Up @@ -30,6 +31,7 @@ def from_dict(cls, data):

@dataclass
class RepoProgressEvent:
shard: Shard
repo_id: str
repo_revision: str
completed_files: int
Expand All @@ -44,7 +46,7 @@ class RepoProgressEvent:

def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
"shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
"file_progress": {k: v.to_dict()
for k, v in self.file_progress.items()}, "status": self.status
Expand All @@ -54,6 +56,7 @@ def to_dict(self):
def from_dict(cls, data):
if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])

return cls(**data)

Expand Down
35 changes: 22 additions & 13 deletions exo/download/new_shard_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import get_supported_models, build_base_shard
from exo.models import get_supported_models, build_full_shard
import os
import aiofiles.os as aios
import aiohttp
Expand All @@ -18,10 +18,18 @@
import json
import traceback
import shutil
import tempfile

def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))

def exo_tmp() -> Path:
return Path(tempfile.gettempdir())/"exo"

async def ensure_exo_tmp() -> Path:
await aios.makedirs(exo_tmp(), exist_ok=True)
return exo_tmp()

async def has_exo_home_read_access() -> bool:
try: return await aios.access(exo_home(), os.R_OK)
except OSError: return False
Expand Down Expand Up @@ -90,17 +98,17 @@ async def download_file(session: aiohttp.ClientSession, repo_id: str, revision:
await aios.rename(temp_file.name, target_dir/path)
return target_dir/path

def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
all_total_bytes = sum([p.total for p in file_progress.values()])
all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
elapsed_time = time.time() - all_start_time
all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
target_dir = await ensure_exo_tmp()/repo_id.replace("/", "--")
async with aiohttp.ClientSession() as session:
index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
Expand All @@ -110,12 +118,13 @@ async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str)
try:
weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
return get_allow_patterns(weight_map, shard)
except Exception as e:
if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}: {e}")
except:
if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
if DEBUG >= 1: traceback.print_exc()
return ["*"]

async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 6 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
repo_id = get_repo(shard.model_id, inference_engine_classname)
revision = "main"
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
Expand All @@ -124,8 +133,8 @@ async def download_shard(shard: Shard, inference_engine_classname: str, on_progr
if repo_id is None:
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")

allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

all_start_time = time.time()
async with aiohttp.ClientSession() as session:
Expand All @@ -137,8 +146,8 @@ def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
speed = curr_bytes / (time.time() - start_time)
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
for file in filtered_file_list:
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())
Expand All @@ -148,7 +157,7 @@ async def download_with_semaphore(file):
async with semaphore:
await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress)
return target_dir, final_repo_progress

Expand Down Expand Up @@ -208,7 +217,7 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
downloads = await asyncio.gather(*[download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
if DEBUG >= 6: print("Downloaded shards:", downloads)
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
return [d for d in downloads if not isinstance(d, Exception)]
Expand Down
7 changes: 1 addition & 6 deletions exo/download/test_new_shard_download.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from exo.download.new_shard_download import download_shard, NewShardDownloader
from exo.inference.shard import Shard
from exo.models import get_model_id
from pathlib import Path
import asyncio

Expand All @@ -11,10 +10,6 @@ async def test_new_shard_download():
download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})

async def test_helpers():
print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
print(get_model_id("fjsekljfd"))

if __name__ == "__main__":
asyncio.run(test_helpers())
asyncio.run(test_new_shard_download())

8 changes: 5 additions & 3 deletions exo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,6 @@
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)

def get_model_id(repo_id: str) -> Optional[str]:
return next((model_id for model_id, card in model_cards.items() if repo_id in card["repo"].values()), None)

def get_pretty_name(model_id: str) -> Optional[str]:
return pretty_name.get(model_id, None)

Expand All @@ -248,6 +245,11 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
return None
return Shard(model_id, 0, 0, n_layers)

def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
base_shard = build_base_shard(model_id, inference_engine_classname)
if base_shard is None: return None
return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)

def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
if not supported_inference_engine_lists:
return list(model_cards.keys())
Expand Down

0 comments on commit b349e48

Please sign in to comment.