Skip to content

Commit

Permalink
rewrite ShardDownloader, simplify significantly
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jan 27, 2025
1 parent a3766f5 commit b89495f
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 57 deletions.
51 changes: 11 additions & 40 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,24 @@
import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from exo.models import build_base_shard, model_cards, get_repo, get_model_id, get_supported_models, get_pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform
from exo.download.shard_download import RepoProgressEvent

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx

import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
Expand Down Expand Up @@ -277,41 +276,12 @@ async def handle_healthcheck(self, request):

async def handle_model_support(self, request):
try:
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
})
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
await response.prepare(request)

async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]

if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()

download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False

model_data = {
model_name: {
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
"total_downloaded": total_downloaded
}
}

await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

# Process all models in parallel
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])

downloads = await self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname)
for (path, d) in downloads:
model_data = { get_model_id(d.repo_id): { "downloaded": d.downloaded_bytes == d.total_bytes, "download_percentage": 100 if d.downloaded_bytes == d.total_bytes else 100 * float(d.downloaded_bytes) / float(d.total_bytes), "total_size": d.total_bytes, "total_downloaded": d.downloaded_bytes } }
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response

Expand Down Expand Up @@ -348,6 +318,7 @@ async def handle_get_download_progress(self, request):
progress_data = {}
for node_id, progress_event in self.node.node_download_progress.items():
if isinstance(progress_event, RepoProgressEvent):
if progress_event.status != "in_progress": continue
progress_data[node_id] = progress_event.to_dict()
else:
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
Expand Down Expand Up @@ -611,9 +582,9 @@ async def handle_delete_model(self, request):

async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
for model_id in get_supported_models([[self.inference_engine_classname]]):
model_data[model_id] = {
"name": get_pretty_name(model_id),
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
Expand Down
3 changes: 2 additions & 1 deletion exo/download/download_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ class RepoFileProgressEvent:
speed: int
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float

def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
}

@classmethod
Expand Down
1 change: 0 additions & 1 deletion exo/download/hf/hf_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ async def download_file(
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")

if downloaded_size == total_size:
print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
Expand Down
180 changes: 180 additions & 0 deletions exo/download/hf/new_shard_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from exo.inference.shard import Shard
from exo.models import get_repo
from pathlib import Path
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import get_supported_models, build_base_shard
import os
import aiofiles.os as aios
import aiohttp
import aiofiles
from urllib.parse import urljoin
from typing import Optional, Callable, Union, Tuple, Dict
import time
from datetime import timedelta
import asyncio
import json

def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))

async def ensure_downloads_dir() -> Path:
downloads_dir = exo_home()/"downloads"
await aios.makedirs(downloads_dir, exist_ok=True)
return downloads_dir

async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url

headers = await get_auth_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")

async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
if (target_dir/path).exists(): return target_dir/path
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, path)
headers = await get_auth_headers()
async with session.get(url, headers=headers) as r:
assert r.status == 200, r.status
length = int(r.headers.get('content-length', 0))
n_read = 0
async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
while chunk := await r.content.read(16384): on_progress(n_read := n_read + await temp_file.write(chunk), length)
await aios.rename(temp_file.name, target_dir/path)
return target_dir/path

def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
all_total_bytes = sum([p.total for p in file_progress.values()])
all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
elapsed_time = time.time() - all_start_time
all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)

async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
async with aiohttp.ClientSession() as session:
index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f:
index_data = json.loads(await f.read())
return index_data.get("weight_map")

async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
return get_allow_patterns(weight_map, shard)

async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
repo_id = get_repo(shard.model_id, inference_engine_classname)
revision = "main"
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
if not skip_download: await aios.makedirs(target_dir, exist_ok=True)

if repo_id is None:
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")

allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

all_start_time = time.time()
async with aiohttp.ClientSession() as session:
file_list = await fetch_file_list(session, repo_id, revision)
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
file_progress: Dict[str, RepoFileProgressEvent] = {}
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
speed = curr_bytes / (time.time() - start_time)
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed)
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
for file in filtered_file_list:
downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())

semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file):
async with semaphore:
await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress)
return target_dir, final_repo_progress

def new_shard_downloader() -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))

class SingletonShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: Dict[Shard, asyncio.Task] = {}

@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self.shard_downloader.on_progress

async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
try: return await self.active_downloads[shard]
finally:
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
return await self.shard_downloader.get_shard_download_status(inference_engine_name)

class CachedShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: Dict[tuple[str, Shard], Path] = {}

@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self.shard_downloader.on_progress

async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
if (inference_engine_name, shard) in self.cache:
if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
return self.cache[(inference_engine_name, shard)]
if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
self.cache[(inference_engine_name, shard)] = target_dir
return target_dir

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
return await self.shard_downloader.get_shard_download_status(inference_engine_name)

class NewShardDownloader(ShardDownloader):
def __init__(self):
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress

async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
return target_dir

async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
if DEBUG >= 6: print("Downloaded shards:", downloads)
if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
return [d for d in downloads if not isinstance(d, Exception)]

20 changes: 20 additions & 0 deletions exo/download/hf/test_new_shard_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from exo.download.hf.new_shard_download import download_shard, NewShardDownloader
from exo.inference.shard import Shard
from exo.models import get_model_id
from pathlib import Path
import asyncio

async def test_new_shard_download():
shard_downloader = NewShardDownloader()
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})

async def test_helpers():
print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
print(get_model_id("fjsekljfd"))

if __name__ == "__main__":
asyncio.run(test_helpers())

4 changes: 2 additions & 2 deletions exo/download/shard_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent
pass

@abstractmethod
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
"""Get the download status of shards.
Returns:
Expand All @@ -45,5 +45,5 @@ async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return AsyncCallbackSystem()

async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
async def get_shard_download_status(self, inference_engine_name: str) -> Optional[Dict[str, float]]:
return None
7 changes: 3 additions & 4 deletions exo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.download.hf.new_shard_download import new_shard_downloader
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
Expand Down Expand Up @@ -117,8 +117,7 @@ def configure_uvloop():
system_info = get_system_info()
print(f"Detected system: {system_info}")

shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
shard_downloader: ShardDownloader = new_shard_downloader() if args.inference_engine != "dummy" else NoopShardDownloader()
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")

Expand Down Expand Up @@ -175,10 +174,10 @@ def configure_uvloop():
None,
inference_engine,
discovery,
shard_downloader,
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
max_generate_tokens=args.max_generate_tokens,
topology_viz=topology_viz,
shard_downloader=shard_downloader,
default_sample_temperature=args.default_temp
)
server = GRPCServer(node, args.node_host, args.node_port)
Expand Down
Loading

0 comments on commit b89495f

Please sign in to comment.