Commit
rewrite ShardDownloader, simplify significantly
1 parent a3766f5 · commit b89495f
Showing 11 changed files with 236 additions and 57 deletions.
exo/download/hf/new_shard_download.py (new file)
@@ -0,0 +1,180 @@
from exo.inference.shard import Shard
from exo.models import get_repo
from pathlib import Path
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import get_supported_models, build_base_shard
import os
import aiofiles.os as aios
import aiohttp
import aiofiles
from urllib.parse import urljoin
from typing import Optional, Callable, Union, Tuple, Dict
import time
from datetime import timedelta
import asyncio
import json

def exo_home() -> Path:
  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))


async def ensure_downloads_dir() -> Path:
  downloads_dir = exo_home()/"downloads"
  await aios.makedirs(downloads_dir, exist_ok=True)
  return downloads_dir

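# Note: the cache root honors the EXO_HOME environment variable, so downloads
# can be redirected to another volume, e.g. EXO_HOME=/mnt/models (hypothetical path).
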
async def fetch_file_list(session, repo_id, revision, path=""):
  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
  url = f"{api_url}/{path}" if path else api_url

  headers = await get_auth_headers()
  async with session.get(url, headers=headers) as response:
    if response.status == 200:
      data = await response.json()
      files = []
      for item in data:
        if item["type"] == "file":
          files.append({"path": item["path"], "size": item["size"]})
        elif item["type"] == "directory":
          subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
          files.extend(subfiles)
      return files
    else:
      raise Exception(f"Failed to fetch file list: {response.status}")

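# fetch_file_list walks the Hugging Face tree API recursively, flattening
# directories into a single list. Each entry the endpoint returns looks roughly
# like {"type": "file", "path": "model-00001-of-00002.safetensors", "size": 123}
# (illustrative values), which is reduced here to just path and size.
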
async def download_file(session: aiohttp.ClientSession, repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Optional[Callable[[int, int], None]] = None) -> Path:
  if (target_dir/path).exists(): return target_dir/path
  await aios.makedirs((target_dir/path).parent, exist_ok=True)  # nested repo paths need their parent dirs to exist
  base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
  url = urljoin(base_url, path)
  headers = await get_auth_headers()
  async with session.get(url, headers=headers) as r:
    assert r.status == 200, r.status
    length = int(r.headers.get('content-length', 0))
    n_read = 0
    async with aiofiles.tempfile.NamedTemporaryFile(dir=target_dir, delete=False) as temp_file:
      while chunk := await r.content.read(16384):
        n_read += await temp_file.write(chunk)
        if on_progress: on_progress(n_read, length)  # guard: on_progress is optional
    await aios.rename(temp_file.name, target_dir/path)  # move into place only once fully written
    return target_dir/path

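# Completed files are detected purely by existence: the temp-file-plus-rename
# step guarantees a file only appears under its final name once fully written,
# so the early-return skip at the top is safe across interrupted runs.
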
def calculate_repo_progress(repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
  all_total_bytes = sum([p.total for p in file_progress.values()])
  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
  elapsed_time = time.time() - all_start_time
  all_speed = all_downloaded_bytes / elapsed_time if elapsed_time > 0 else 0
  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
  status = "not_started" if all_downloaded_bytes == 0 else "complete" if all_downloaded_bytes == all_total_bytes else "in_progress"
  return RepoProgressEvent(repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes, all_total_bytes, all_speed, all_eta, file_progress, status)

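# all_downloaded_bytes is passed twice on purpose: RepoProgressEvent appears to
# track both total bytes on disk and bytes downloaded this session, and this
# rewrite does not distinguish the two (an assumption based on the field order).
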
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
  async with aiohttp.ClientSession() as session:
    index_file = await download_file(session, repo_id, revision, "model.safetensors.index.json", target_dir)
    async with aiofiles.open(index_file, 'r') as f:
      index_data = json.loads(await f.read())
    return index_data.get("weight_map")

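# The safetensors index maps each tensor name to the file that stores it, e.g.
# {"weight_map": {"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors"}}
# (illustrative entry), which lets a node fetch only the files covering its layers.
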
async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> list[str]:
  weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
  return get_allow_patterns(weight_map, shard)

async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 6, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
  if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
  repo_id = get_repo(shard.model_id, inference_engine_classname)
  if repo_id is None:  # check before first use: repo_id.replace would raise AttributeError on None
    raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
  revision = "main"
  target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
  if not skip_download: await aios.makedirs(target_dir, exist_ok=True)

  allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname) if not skip_download else None
  if DEBUG >= 3: print(f"Downloading {shard.model_id=} with {allow_patterns=}")

  all_start_time = time.time()
  async with aiohttp.ClientSession() as session:
    file_list = await fetch_file_list(session, repo_id, revision)
    filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
    file_progress: Dict[str, RepoFileProgressEvent] = {}
    def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
      start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
      elapsed = time.time() - start_time
      speed = curr_bytes / elapsed if elapsed > 0 else 0  # guard against a zero-elapsed first callback
      eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(0)
      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, curr_bytes, total_bytes, speed, eta, "in_progress", start_time)
      on_progress.trigger_all(shard, calculate_repo_progress(repo_id, revision, file_progress, all_start_time))
      if DEBUG >= 2: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
    for file in filtered_file_list:
      downloaded_bytes = (await aios.stat(target_dir/file["path"])).st_size if await aios.path.exists(target_dir/file["path"]) else 0
      file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "not_started" if downloaded_bytes == 0 else "complete" if downloaded_bytes == file["size"] else "in_progress", time.time())

    semaphore = asyncio.Semaphore(max_parallel_downloads)
    async def download_with_semaphore(file):
      async with semaphore:
        await download_file(session, repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
    if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
    final_repo_progress = calculate_repo_progress(repo_id, revision, file_progress, all_start_time)
    on_progress.trigger_all(shard, final_repo_progress)
    return target_dir, final_repo_progress

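# Per-file transfers run concurrently but are capped by an asyncio.Semaphore,
# so at most max_parallel_downloads (default 6) HTTP streams are open at once.
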
def new_shard_downloader() -> ShardDownloader:
  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader()))

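# The factory stacks three layers: NewShardDownloader does the actual HTTP work,
# CachedShardDownloader memoizes resolved paths per (engine, shard), and
# SingletonShardDownloader collapses concurrent requests for the same shard
# into a single in-flight task.
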
class SingletonShardDownloader(ShardDownloader):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard_downloader = shard_downloader
    self.active_downloads: Dict[Shard, asyncio.Task] = {}

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self.shard_downloader.on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
    try: return await self.active_downloads[shard]
    finally:
      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]

  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
    return await self.shard_downloader.get_shard_download_status(inference_engine_name)

class CachedShardDownloader(ShardDownloader):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard_downloader = shard_downloader
    self.cache: Dict[tuple[str, Shard], Path] = {}

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self.shard_downloader.on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    if (inference_engine_name, shard) in self.cache:
      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
      return self.cache[(inference_engine_name, shard)]
    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
    self.cache[(inference_engine_name, shard)] = target_dir
    return target_dir

  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
    return await self.shard_downloader.get_shard_download_status(inference_engine_name)

class NewShardDownloader(ShardDownloader):
  def __init__(self):
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress)
    return target_dir

  async def get_shard_download_status(self, inference_engine_name: str) -> list[tuple[Path, RepoProgressEvent]]:
    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
    downloads = await asyncio.gather(*[download_shard(build_base_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])], return_exceptions=True)
    if DEBUG >= 6: print("Downloaded shards:", downloads)
    if any(isinstance(d, Exception) for d in downloads) and DEBUG >= 1: print("Error downloading shards:", [d for d in downloads if isinstance(d, Exception)])
    return [d for d in downloads if not isinstance(d, Exception)]

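# Status reporting reuses download_shard with skip_download=True: it walks every
# supported model for the engine and computes on-disk progress from the repo
# file listing, without downloading any weights.
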
New test file:
@@ -0,0 +1,20 @@
from exo.download.hf.new_shard_download import download_shard, NewShardDownloader
from exo.inference.shard import Shard
from exo.models import get_model_id
from pathlib import Path
import asyncio

async def test_new_shard_download():
  shard_downloader = NewShardDownloader()
  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
  download_statuses = await shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine")
  print({k: v for k, v in download_statuses if v.downloaded_bytes > 0})

async def test_helpers():
  print(get_model_id("mlx-community/Llama-3.3-70B-Instruct-4bit"))
  print(get_model_id("fjsekljfd"))


if __name__ == "__main__":
  asyncio.run(test_helpers())
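# Note: the __main__ block above runs test_helpers() only; switch the entrypoint
# to asyncio.run(test_new_shard_download()) to exercise a real download end to end.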