diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index e05ccde06..c2ceea399 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -21,7 +21,13 @@ import numpy as np import base64 from io import BytesIO -import mlx.core as mx +import platform + +if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": + import mlx.core as mx +else: + import numpy as mx + import tempfile from exo.download.hf.hf_shard_download import HFShardDownloader import shutil @@ -29,6 +35,7 @@ from exo.apputil import create_animation_mp4 from collections import defaultdict + class Message: def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None): self.role = role @@ -42,7 +49,6 @@ def to_dict(self): return data - class ChatCompletionRequest: def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None): self.model = model @@ -133,16 +139,24 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None): messages = remap_messages(_messages) - chat_template_args = { - "conversation": [m.to_dict() for m in messages], - "tokenize": False, - "add_generation_prompt": True - } - if tools: chat_template_args["tools"] = tools - - prompt = tokenizer.apply_chat_template(**chat_template_args) - print(f"!!! Prompt: {prompt}") - return prompt + chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True} + if tools: + chat_template_args["tools"] = tools + + try: + prompt = tokenizer.apply_chat_template(**chat_template_args) + if DEBUG >= 3: print(f"!!! Prompt: {prompt}") + return prompt + except UnicodeEncodeError: + # Handle Unicode encoding by ensuring everything is UTF-8 + chat_template_args["conversation"] = [ + {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v + for k, v in m.to_dict().items()} + for m in messages + ] + prompt = tokenizer.apply_chat_template(**chat_template_args) + if DEBUG >= 3: print(f"!!! 
Prompt (UTF-8 encoded): {prompt}") + return prompt def parse_message(data: dict): @@ -166,8 +180,17 @@ def __init__(self, request_id: str, timestamp: int, prompt: str): self.timestamp = timestamp self.prompt = prompt + class ChatGPTAPI: - def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None): + def __init__( + self, + node: Node, + inference_engine_classname: str, + response_timeout: int = 90, + on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, + default_model: Optional[str] = None, + system_prompt: Optional[str] = None + ): self.node = node self.inference_engine_classname = inference_engine_classname self.response_timeout = response_timeout @@ -209,18 +232,22 @@ def __init__(self, node: Node, inference_engine_classname: str, response_timeout cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options}) cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options}) - + # Add static routes if "__compiled__" not in globals(): self.static_dir = Path(__file__).parent.parent/"tinychat" self.app.router.add_get("/", self.handle_root) self.app.router.add_static("/", self.static_dir, name="static") - self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images') + + # Always add images route, regardless of compilation status + self.images_dir = get_exo_images_dir() + self.images_dir.mkdir(parents=True, exist_ok=True) + self.app.router.add_static('/images/', self.images_dir, name='static_images') self.app.middlewares.append(self.timeout_middleware) self.app.middlewares.append(self.log_request) async def handle_quit(self, request): - if DEBUG>=1: print("Received quit signal") + if DEBUG >= 1: print("Received quit signal") response = web.json_response({"detail": "Quit signal received"}, status=200) await response.prepare(request) await response.write_eof() @@ -250,61 +277,48 @@ async def handle_healthcheck(self, request): async def handle_model_support(self, request): try: - response = web.StreamResponse( - status=200, - reason='OK', - headers={ - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - 'Connection': 'keep-alive', - } - ) - await response.prepare(request) + response = web.StreamResponse(status=200, reason='OK', headers={ + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + }) + await response.prepare(request) + + async def process_model(model_name, pretty): + if model_name in model_cards: + model_info = model_cards[model_name] + + if self.inference_engine_classname in model_info.get("repo", {}): + shard = build_base_shard(model_name, self.inference_engine_classname) + if shard: + downloader = HFShardDownloader(quick_check=True) + downloader.current_shard = shard + downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname) + status = await downloader.get_shard_download_status() + + download_percentage = status.get("overall") if status else None + total_size = status.get("total_size") if status else None + total_downloaded = status.get("total_downloaded") if status else False - async def process_model(model_name, pretty): - if model_name in model_cards: - model_info = model_cards[model_name] - - if self.inference_engine_classname in model_info.get("repo", {}): - shard = 
build_base_shard(model_name, self.inference_engine_classname) - if shard: - downloader = HFShardDownloader(quick_check=True) - downloader.current_shard = shard - downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname) - status = await downloader.get_shard_download_status() - - download_percentage = status.get("overall") if status else None - total_size = status.get("total_size") if status else None - total_downloaded = status.get("total_downloaded") if status else False - - model_data = { - model_name: { - "name": pretty, - "downloaded": download_percentage == 100 if download_percentage is not None else False, - "download_percentage": download_percentage, - "total_size": total_size, - "total_downloaded": total_downloaded - } - } - - await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) - - # Process all models in parallel - await asyncio.gather(*[ - process_model(model_name, pretty) - for model_name, pretty in pretty_name.items() - ]) - - await response.write(b"data: [DONE]\n\n") - return response + model_data = { + model_name: { + "name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size, + "total_downloaded": total_downloaded + } + } + + await response.write(f"data: {json.dumps(model_data)}\n\n".encode()) + + # Process all models in parallel + await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()]) + + await response.write(b"data: [DONE]\n\n") + return response except Exception as e: - print(f"Error in handle_model_support: {str(e)}") - traceback.print_exc() - return web.json_response( - {"detail": f"Server error: {str(e)}"}, - status=500 - ) + print(f"Error in handle_model_support: {str(e)}") + traceback.print_exc() + return web.json_response({"detail": f"Server error: {str(e)}"}, status=500) async def handle_get_models(self, request): models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()] @@ -466,7 +480,6 @@ async def handle_post_chat_completions(self, request): if DEBUG >= 2: traceback.print_exc() return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) - async def handle_post_image_generations(self, request): data = await request.json() @@ -479,7 +492,7 @@ async def handle_post_image_generations(self, request): shard = build_base_shard(model, self.inference_engine_classname) if DEBUG >= 2: print(f"shard: {shard}") if not shard: - return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400) + return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400) request_id = str(uuid.uuid4()) callback_id = f"chatgpt-api-wait-response-{request_id}" @@ -491,77 +504,85 @@ async def handle_post_image_generations(self, request): img = None await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout) - - response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",}) + response = web.StreamResponse(status=200, reason='OK', headers={ + 'Content-Type': 'application/octet-stream', + "Cache-Control": "no-cache", + }) await 
response.prepare(request) def get_progress_bar(current_step, total_steps, bar_length=50): # Calculate the percentage of completion - percent = float(current_step) / total_steps + percent = float(current_step)/total_steps # Calculate the number of hashes to display - arrow = '-' * int(round(percent * bar_length) - 1) + '>' - spaces = ' ' * (bar_length - len(arrow)) - + arrow = '-'*int(round(percent*bar_length) - 1) + '>' + spaces = ' '*(bar_length - len(arrow)) + # Create the progress bar string progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})' return progress_bar async def stream_image(_request_id: str, result, is_finished: bool): - if isinstance(result, list): - await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n') + if isinstance(result, list): + await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n') - elif isinstance(result, np.ndarray): + elif isinstance(result, np.ndarray): + try: im = Image.fromarray(np.array(result)) - images_folder = get_exo_images_dir() # Save the image to a file image_filename = f"{_request_id}.png" - image_path = images_folder / image_filename + image_path = self.images_dir/image_filename im.save(image_path) - image_url = request.app.router['static_images'].url_for(filename=image_filename) - base_url = f"{request.scheme}://{request.host}" - # Construct the full URL correctly - full_image_url = base_url + str(image_url) - await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + # Get URL for the saved image + try: + image_url = request.app.router['static_images'].url_for(filename=image_filename) + base_url = f"{request.scheme}://{request.host}" + full_image_url = base_url + str(image_url) + + await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + except KeyError as e: + if DEBUG >= 2: print(f"Error getting image URL: {e}") + # Fallback to direct file path if URL generation fails + await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + if is_finished: await response.write_eof() - + + except Exception as e: + if DEBUG >= 2: print(f"Error processing image: {e}") + if DEBUG >= 2: traceback.print_exc() + await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n') stream_task = None + def on_result(_request_id: str, result, is_finished: bool): - nonlocal stream_task - stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished)) - return _request_id == request_id and is_finished + nonlocal stream_task + stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished)) + return _request_id == request_id and is_finished await callback.wait(on_result, timeout=self.response_timeout*10) - + if stream_task: - # Wait for the stream task to complete before returning - await stream_task + # Wait for the stream task to complete before returning + await stream_task return response except Exception as e: - if DEBUG >= 2: traceback.print_exc() - return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) - + if DEBUG >= 2: traceback.print_exc() + return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500) + async def 
handle_delete_model(self, request): try: model_name = request.match_info.get('model_name') if DEBUG >= 2: print(f"Attempting to delete model: {model_name}") if not model_name or model_name not in model_cards: - return web.json_response( - {"detail": f"Invalid model name: {model_name}"}, - status=400 - ) + return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400) shard = build_base_shard(model_name, self.inference_engine_classname) if not shard: - return web.json_response( - {"detail": "Could not build shard for model"}, - status=400 - ) + return web.json_response({"detail": "Could not build shard for model"}, status=400) repo_id = get_repo(shard.model_id, self.inference_engine_classname) if DEBUG >= 2: print(f"Repo ID for model: {repo_id}") @@ -576,38 +597,28 @@ async def handle_delete_model(self, request): if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...") try: shutil.rmtree(cache_dir) - return web.json_response({ - "status": "success", - "message": f"Model {model_name} deleted successfully", - "path": str(cache_dir) - }) + return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)}) except Exception as e: - return web.json_response({ - "detail": f"Failed to delete model files: {str(e)}" - }, status=500) + return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500) else: - return web.json_response({ - "detail": f"Model files not found at {cache_dir}" - }, status=404) + return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404) except Exception as e: - print(f"Error in handle_delete_model: {str(e)}") - traceback.print_exc() - return web.json_response({ - "detail": f"Server error: {str(e)}" - }, status=500) + print(f"Error in handle_delete_model: {str(e)}") + traceback.print_exc() + return web.json_response({"detail": f"Server error: {str(e)}"}, status=500) async def handle_get_initial_models(self, request): model_data = {} for model_name, pretty in pretty_name.items(): - model_data[model_name] = { - "name": pretty, - "downloaded": None, # Initially unknown - "download_percentage": None, # Change from 0 to null - "total_size": None, - "total_downloaded": None, - "loading": True # Add loading state - } + model_data[model_name] = { + "name": pretty, + "downloaded": None, # Initially unknown + "download_percentage": None, # Change from 0 to null + "total_size": None, + "total_downloaded": None, + "loading": True # Add loading state + } return web.json_response(model_data) async def handle_create_animation(self, request): @@ -633,17 +644,9 @@ async def handle_create_animation(self, request): if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}") # Create the animation - create_animation_mp4( - replacement_image_path, - output_path, - device_name, - prompt_text - ) + create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text) - return web.json_response({ - "status": "success", - "output_path": output_path - }) + return web.json_response({"status": "success", "output_path": output_path}) except Exception as e: if DEBUG >= 2: traceback.print_exc() @@ -659,10 +662,7 @@ async def handle_post_download(self, request): if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400) 
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname)) - return web.json_response({ - "status": "success", - "message": f"Download started for model: {model_name}" - }) + return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"}) except Exception as e: if DEBUG >= 2: traceback.print_exc() return web.json_response({"error": str(e)}, status=500) @@ -676,10 +676,7 @@ async def handle_get_topology(self, request): return web.json_response({}) except Exception as e: if DEBUG >= 2: traceback.print_exc() - return web.json_response( - {"detail": f"Error getting topology: {str(e)}"}, - status=500 - ) + return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500) async def handle_token(self, request_id: str, token: int, is_finished: bool): await self.token_queues[request_id].put((token, is_finished)) @@ -693,15 +690,14 @@ async def run(self, host: str = "0.0.0.0", port: int = 52415): def base64_decode(self, base64_string): #decode and reshape image if base64_string.startswith('data:image'): - base64_string = base64_string.split(',')[1] + base64_string = base64_string.split(',')[1] image_data = base64.b64decode(base64_string) img = Image.open(BytesIO(image_data)) - W, H = (dim - dim % 64 for dim in (img.width, img.height)) + W, H = (dim - dim%64 for dim in (img.width, img.height)) if W != img.width or H != img.height: - if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}") - img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter + if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}") + img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter img = mx.array(np.array(img)) - img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1 + img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1 img = img[None] return img - diff --git a/exo/apputil/anim.py b/exo/apputil/anim.py index 654ca4703..b64aace3c 100644 --- a/exo/apputil/anim.py +++ b/exo/apputil/anim.py @@ -2,6 +2,7 @@ import os import numpy as np import cv2 +import sys def draw_rounded_rectangle(draw, coords, radius, fill): left, top, right, bottom = coords @@ -80,14 +81,20 @@ def create_animation_mp4( font = ImageFont.load_default() promptfont = ImageFont.load_default() + # Get the base directory for images when running as a bundled app + if hasattr(sys, '_MEIPASS'): + base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages") + else: + base_dir = os.path.join(os.path.dirname(__file__), "baseimages") + # Process first frame - base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png")) + base_img = Image.open(os.path.join(base_dir, "image1.png")) draw = ImageDraw.Draw(base_img) draw_centered_text_rounded(draw, device_name, font, device_coords) frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps # Process second frame with typing animation - base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png")) + base_img2 = Image.open(os.path.join(base_dir, "image2.png")) for i in range(len(prompt_text) + 1): current_frame = base_img2.copy() draw = ImageDraw.Draw(current_frame) @@ -101,7 +108,7 @@ def create_animation_mp4( # Create blur sequence replacement_img = Image.open(replacement_image_path) - base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png")) + base_img = Image.open(os.path.join(base_dir, 
"image3.png")) blur_steps = [int(80 * (1 - i/8)) for i in range(9)] for i, blur_amount in enumerate(blur_steps): @@ -123,7 +130,7 @@ def create_animation_mp4( frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps # Create and add final frame (image4) - final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png")) + final_base = Image.open(os.path.join(base_dir, "image4.png")) draw = ImageDraw.Draw(final_base) draw_centered_text_rounded(draw, device_name, font, device_coords) @@ -158,4 +165,4 @@ def create_animation_mp4( out.write(frame_array) out.release() - print(f"Video saved successfully to {output_path}") \ No newline at end of file + print(f"Video saved successfully to {output_path}") diff --git a/exo/helpers.py b/exo/helpers.py index da286bbea..5f3554f90 100644 --- a/exo/helpers.py +++ b/exo/helpers.py @@ -7,7 +7,8 @@ import platform import psutil import uuid -import netifaces +from scapy.all import get_if_addr, get_if_list +import re import subprocess from pathlib import Path import tempfile @@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str: def get_all_ip_addresses_and_interfaces(): try: ip_addresses = [] - for interface in netifaces.interfaces(): - ifaddresses = netifaces.ifaddresses(interface) - if netifaces.AF_INET in ifaddresses: - for link in ifaddresses[netifaces.AF_INET]: - ip = link['addr'] - ip_addresses.append((ip, interface)) + for interface in get_if_list(): + ip = get_if_addr(interface) + # Include all addresses, including loopback + # Filter out link-local addresses + if not ip.startswith('169.254.') and not ip.startswith('0.0.'): + # Remove "\\Device\\NPF_" prefix from interface name + simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface) + ip_addresses.append((ip, simplified_interface)) return list(set(ip_addresses)) except: if DEBUG >= 1: print("Failed to get all IP addresses. 
Defaulting to localhost.") return [("localhost", "lo")] + async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]: try: # Use the shared subprocess_pool - output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run( - ['system_profiler', 'SPNetworkDataType', '-json'], - capture_output=True, - text=True, - close_fds=True - ).stdout) + output = await asyncio.get_running_loop().run_in_executor( + subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout + ) data = json.loads(output) @@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]: return None + async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]: # On macOS, try to get interface type using networksetup if psutil.MACOS: @@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]: if macos_type is not None: return macos_type # Local container/virtual interfaces - if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or - 'bridge' in ifname): + if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname): return (7, "Container Virtual") # Loopback interface @@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]: # Other physical interfaces return (2, "Other") + async def shutdown(signal, loop, server): """Gracefully shutdown the server and close the asyncio loop.""" print(f"Received exit signal {signal.name}...") @@ -353,18 +355,16 @@ async def get_mac_system_info() -> Tuple[str, str, int]: return "Unknown Model", "Unknown Chip", 0 def get_exo_home() -> Path: - if os.name == "nt": # Check if the OS is Windows - docs_folder = Path(os.environ["USERPROFILE"]) / "Documents" - else: - docs_folder = Path.home() / "Documents" - exo_folder = docs_folder / "Exo" - if not exo_folder.exists(): - exo_folder.mkdir() + if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents" + else: docs_folder = Path.home()/"Documents" + if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True) + exo_folder = docs_folder/"Exo" + if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True) return exo_folder + def get_exo_images_dir() -> Path: exo_home = get_exo_home() - images_dir = exo_home / "Images" - if not images_dir.exists(): - images_dir.mkdir() + images_dir = exo_home/"Images" + if not images_dir.exists(): images_dir.mkdir(exist_ok=True) return images_dir diff --git a/exo/inference/debug_inference_engine.py b/exo/inference/debug_inference_engine.py index c1ae4b40c..d81ac0237 100644 --- a/exo/inference/debug_inference_engine.py +++ b/exo/inference/debug_inference_engine.py @@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt) token_full = await inference_engine_1.sample(resp_full) - next_resp_full = await inference_engine_1.infer_tensor( + next_resp_full, _ = await inference_engine_1.infer_tensor( "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=token_full, ) - resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt) - resp2 = await inference_engine_2.infer_tensor( + resp1, 
_ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt) + resp2, _ = await inference_engine_2.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp1, ) token2 = await inference_engine_2.sample(resp2) - resp3 = await inference_engine_1.infer_tensor( + resp3, _ = await inference_engine_1.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), input_data=token2, ) - resp4 = await inference_engine_2.infer_tensor( + resp4, _ = await inference_engine_2.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32), input_data=resp3, diff --git a/exo/inference/dummy_inference_engine.py b/exo/inference/dummy_inference_engine.py index daf4b677e..98cf88945 100644 --- a/exo/inference/dummy_inference_engine.py +++ b/exo/inference/dummy_inference_engine.py @@ -25,9 +25,9 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> async def decode(self, shard: Shard, tokens: np.ndarray) -> str: return self.tokenizer.decode(tokens) - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: await self.ensure_shard(shard) - return input_data + 1 if self.shard.is_last_layer() else input_data + return input_data + 1 if self.shard.is_last_layer() else input_data, None async def ensure_shard(self, shard: Shard): if self.shard == shard: return diff --git a/exo/inference/inference_engine.py b/exo/inference/inference_engine.py index 97cd6aa57..b62000371 100644 --- a/exo/inference/inference_engine.py +++ b/exo/inference/inference_engine.py @@ -5,6 +5,7 @@ from typing import Tuple, Optional from abc import ABC, abstractmethod from .shard import Shard +from exo.download.shard_download import ShardDownloader class InferenceEngine(ABC): @@ -13,7 +14,7 @@ class InferenceEngine(ABC): @abstractmethod async def encode(self, shard: Shard, prompt: str) -> np.ndarray: pass - + @abstractmethod async def sample(self, x: np.ndarray) -> np.ndarray: pass @@ -23,7 +24,7 @@ async def decode(self, shard: Shard, tokens: np.ndarray) -> str: pass @abstractmethod - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: pass @abstractmethod @@ -32,14 +33,14 @@ async def load_checkpoint(self, shard: Shard, path: str): async def save_checkpoint(self, shard: Shard, path: str): pass - + async def save_session(self, key, value): self.session[key] = value - + async def clear_session(self): self.session.empty() - - async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray: + + async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: tokens = await self.encode(shard, prompt) if shard.model_id != 'stable-diffusion-2-1-base': x = tokens.reshape(1, -1) @@ -49,13 +50,15 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen return output_data, inference_state + inference_engine_classes = { "mlx": "MLXDynamicShardInferenceEngine", 
"tinygrad": "TinygradDynamicShardInferenceEngine", "dummy": "DummyInferenceEngine", } -def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'): + +def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader): if DEBUG >= 2: print(f"get_inference_engine called with: {inference_engine_name}") if inference_engine_name == "mlx": diff --git a/exo/inference/mlx/sharded_inference_engine.py b/exo/inference/mlx/sharded_inference_engine.py index 51bde44a4..3bcd5d3ca 100644 --- a/exo/inference/mlx/sharded_inference_engine.py +++ b/exo/inference/mlx/sharded_inference_engine.py @@ -67,15 +67,15 @@ async def load_checkpoint(self, shard: Shard, path: str): await self.ensure_shard(shard) self.model.load_weights(path) - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: await self.ensure_shard(shard) loop = asyncio.get_running_loop() state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {} x = mx.array(input_data) if self.model.model_type != 'StableDiffusionPipeline': - output_data = self.model(x, **state, **inference_state) + output_data = self.model(x, **state, **(inference_state or {})) else: - output_data, inference_state = self.model(x, **state, **inference_state) + output_data, inference_state = self.model(x, **state, **(inference_state or {})) output_data = np.array(output_data, copy=False) return output_data, inference_state diff --git a/exo/inference/test_dummy_inference_engine.py b/exo/inference/test_dummy_inference_engine.py index cfd33df6e..fad5178e1 100644 --- a/exo/inference/test_dummy_inference_engine.py +++ b/exo/inference/test_dummy_inference_engine.py @@ -1,22 +1,16 @@ import pytest -import json import numpy as np from exo.inference.dummy_inference_engine import DummyInferenceEngine from exo.inference.shard import Shard -class MockShardDownloader: - async def ensure_shard(self, shard): - pass - - @pytest.mark.asyncio async def test_dummy_inference_specific(): - engine = DummyInferenceEngine(MockShardDownloader()) + engine = DummyInferenceEngine() test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1) test_prompt = "This is a test prompt" - result = await engine.infer_prompt("test_request", test_shard, test_prompt) + result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt) print(f"Inference result shape: {result.shape}") @@ -26,20 +20,20 @@ async def test_dummy_inference_specific(): @pytest.mark.asyncio async def test_dummy_inference_engine(): # Initialize the DummyInferenceEngine - engine = DummyInferenceEngine(MockShardDownloader()) + engine = DummyInferenceEngine() # Create a test shard shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1) # Test infer_prompt - output = await engine.infer_prompt("test_id", shard, "Test prompt") + output, _ = await engine.infer_prompt("test_id", shard, "Test prompt") assert isinstance(output, np.ndarray), "Output should be a numpy array" assert output.ndim == 2, "Output should be 2-dimensional" # Test infer_tensor input_tensor = np.array([[1, 2, 3]]) - output = await engine.infer_tensor("test_id", shard, input_tensor) + output, _ = await engine.infer_tensor("test_id", shard, input_tensor) assert isinstance(output, np.ndarray), 
"Output should be a numpy array" assert output.ndim == 2, "Output should be 2-dimensional" diff --git a/exo/inference/test_inference_engine.py b/exo/inference/test_inference_engine.py index db69aebf4..7a69eafce 100644 --- a/exo/inference/test_inference_engine.py +++ b/exo/inference/test_inference_engine.py @@ -11,30 +11,30 @@ # An inference engine should work the same for any number of Shards, as long as the Shards are continuous. async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int): prompt = "In a single word only, what is the last name of the current president of the USA?" - resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt) + resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt) token_full = await inference_engine_1.sample(resp_full) token_full = token_full.reshape(1, -1) - next_resp_full = await inference_engine_1.infer_tensor( + next_resp_full, _ = await inference_engine_1.infer_tensor( "A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), input_data=token_full, ) pp = n_layers // 2 - resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt) - resp2 = await inference_engine_2.infer_tensor( + resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt) + resp2, _ = await inference_engine_2.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers), input_data=resp1, ) tokens2 = await inference_engine_1.sample(resp2) tokens2 = tokens2.reshape(1, -1) - resp3 = await inference_engine_1.infer_tensor( + resp3, _ = await inference_engine_1.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), input_data=tokens2, ) - resp4 = await inference_engine_2.infer_tensor( + resp4, _ = await inference_engine_2.infer_tensor( "B", shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers), input_data=resp3, diff --git a/exo/inference/tinygrad/inference.py b/exo/inference/tinygrad/inference.py index 214cfd3d4..6543f0b80 100644 --- a/exo/inference/tinygrad/inference.py +++ b/exo/inference/tinygrad/inference.py @@ -104,7 +104,7 @@ async def save_checkpoint(self, shard: Shard, path: str): state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model) safe_save(state_dict, path) - async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray: + async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]: await self.ensure_shard(shard) def wrap_infer(): x = Tensor(input_data) diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index eea315b8f..96b476e38 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -12,7 +12,13 @@ from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from exo.helpers import DEBUG import json -import mlx.core as mx 
+import platform + +if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": + import mlx.core as mx +else: + import numpy as mx + class GRPCPeerHandle(PeerHandle): def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities): @@ -101,7 +107,7 @@ async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional n_layers=shard.n_layers, ), request_id=request_id, - inference_state=self.serialize_inference_state(inference_state) + inference_state=None if inference_state is None else self.serialize_inference_state(inference_state) ) await self.stub.SendPrompt(request) @@ -115,7 +121,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O ), tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)), request_id=request_id, - inference_state=self.serialize_inference_state(inference_state) + inference_state=None if inference_state is None else self.serialize_inference_state(inference_state) ) response =await self.stub.SendTensor(request) @@ -123,7 +129,7 @@ async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: O return None return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape) - + async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.ExampleRequest( shard=node_service_pb2.Shard( @@ -145,7 +151,7 @@ async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarr return loss, grads else: return loss - + async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]: request = node_service_pb2.TensorRequest( shard=node_service_pb2.Shard( @@ -170,10 +176,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: topology = Topology() for node_id, capabilities in response.nodes.items(): device_capabilities = DeviceCapabilities( - model=capabilities.model, - chip=capabilities.chip, - memory=capabilities.memory, - flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8) + model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8) ) topology.update_node(node_id, device_capabilities) for node_id, peer_connections in response.peer_graph.items(): @@ -197,28 +200,20 @@ def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.I proto_inference_state = node_service_pb2.InferenceState() other_data = {} for k, v in inference_state.items(): - if isinstance(v, mx.array): - np_array = np.array(v) - tensor_data = node_service_pb2.Tensor( - tensor_data=np_array.tobytes(), - shape=list(np_array.shape), - dtype=str(np_array.dtype) - ) - proto_inference_state.tensor_data[k].CopyFrom(tensor_data) - elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v): - tensor_list = node_service_pb2.TensorList() - for tensor in v: - np_array = np.array(tensor) - tensor_data = node_service_pb2.Tensor( - tensor_data=np_array.tobytes(), - shape=list(np_array.shape), - dtype=str(np_array.dtype) - ) - tensor_list.tensors.append(tensor_data) - proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list) - else: - # For non-tensor data, we'll still use JSON - 
other_data[k] = v + if isinstance(v, mx.array): + np_array = np.array(v) + tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype)) + proto_inference_state.tensor_data[k].CopyFrom(tensor_data) + elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v): + tensor_list = node_service_pb2.TensorList() + for tensor in v: + np_array = np.array(tensor) + tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype)) + tensor_list.tensors.append(tensor_data) + proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list) + else: + # For non-tensor data, we'll still use JSON + other_data[k] = v if other_data: proto_inference_state.other_data_json = json.dumps(other_data) return proto_inference_state diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index 83cc0f01f..fbde1dc38 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -3,13 +3,19 @@ import numpy as np from asyncio import CancelledError +import platform + from . import node_service_pb2 from . import node_service_pb2_grpc from exo import DEBUG from exo.inference.shard import Shard from exo.orchestration import Node import json -import mlx.core as mx + +if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64": + import mlx.core as mx +else: + import numpy as mx class GRPCServer(node_service_pb2_grpc.NodeServiceServicer): @@ -60,7 +66,7 @@ async def SendPrompt(self, request, context): ) prompt = request.prompt request_id = request.request_id - inference_state = self.deserialize_inference_state(request.inference_state) + inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state) result = await self.node.process_prompt(shard, prompt, request_id, inference_state) if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None @@ -76,13 +82,13 @@ async def SendTensor(self, request, context): tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape) request_id = request.request_id - inference_state = self.deserialize_inference_state(request.inference_state) + inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state) result = await self.node.process_tensor(shard, tensor, request_id, inference_state) if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}") tensor_data = result.tobytes() if result is not None else None return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor() - + async def SendExample(self, request, context): shard = Shard( model_id=request.shard.model_id, @@ -104,7 +110,7 @@ async def SendExample(self, request, context): else: loss = await self.node.process_example(shard, example, target, length, train, request_id) return node_service_pb2.Loss(loss=loss, grads=None) - + async def CollectTopology(self, request, context): max_depth = request.max_depth visited = set(request.visited) @@ -120,12 +126,7 @@ async def CollectTopology(self, request, context): for node_id, cap in topology.nodes.items() } peer_graph = { - node_id: node_service_pb2.PeerConnections( - connections=[ - 
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) - for conn in connections - ] - ) + node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections]) for node_id, connections in topology.peer_graph.items() } if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}") @@ -139,7 +140,7 @@ async def SendResult(self, request, context): if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}") result = list(result) if len(img.tensor_data) > 0: - result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape) + result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape) self.node.on_token.trigger_all(request_id, result, is_finished) return node_service_pb2.Empty() @@ -153,21 +154,18 @@ async def SendOpaqueStatus(self, request, context): async def HealthCheck(self, request, context): return node_service_pb2.HealthCheckResponse(is_healthy=True) - def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict: + def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict: inference_state = {} - + for k, tensor_data in inference_state_proto.tensor_data.items(): - np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape) - inference_state[k] = mx.array(np_array) - + np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape) + inference_state[k] = mx.array(np_array) + for k, tensor_list in inference_state_proto.tensor_list_data.items(): - inference_state[k] = [ - mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) - for tensor in tensor_list.tensors - ] - + inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors] + if inference_state_proto.other_data_json: - other_data = json.loads(inference_state_proto.other_data_json) - inference_state.update(other_data) - + other_data = json.loads(inference_state_proto.other_data_json) + inference_state.update(other_data) + return inference_state diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py index 9a10c126e..14780410c 100644 --- a/exo/orchestration/node.py +++ b/exo/orchestration/node.py @@ -326,7 +326,7 @@ async def _process_example( loss, grad = await self.inference_engine.train(request_id, shard, example, target, length) else: self.outstanding_requests[request_id] = "preprocessing" - step = await self.inference_engine.infer_tensor(request_id, shard, example) + step, _ = await self.inference_engine.infer_tensor(request_id, shard, example) self.outstanding_requests[request_id] = "waiting" loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1)) self.outstanding_requests[request_id] = "training" @@ -342,7 +342,7 @@ async def _process_example( loss = await self.inference_engine.evaluate(request_id, shard, example, target, length) else: self.outstanding_requests[request_id] = "preprocessing" - step = await self.inference_engine.infer_tensor(request_id, shard, example) + step, _ = await self.inference_engine.infer_tensor(request_id, shard, example) self.outstanding_requests[request_id] = "waiting" loss = await self.forward_example(shard, step, target, length, 
train, request_id, self.get_partition_index(offset = 1)) self.outstanding_requests.pop(request_id) diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index e962ecdf1..1358923a4 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -151,6 +151,8 @@ async def device_capabilities() -> DeviceCapabilities: return await mac_device_capabilities() elif psutil.LINUX: return await linux_device_capabilities() + elif psutil.WINDOWS: + return await windows_device_capabilities() else: return DeviceCapabilities( model="Unknown Device", @@ -187,6 +189,8 @@ async def linux_device_capabilities() -> DeviceCapabilities: if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}") + pynvml.nvmlShutdown() + return DeviceCapabilities( model=f"Linux Box ({gpu_name})", chip=gpu_name, @@ -194,13 +198,24 @@ async def linux_device_capabilities() -> DeviceCapabilities: flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)), ) elif Device.DEFAULT == "AMD": - # TODO AMD support + # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi) + from pyrsmi import rocml + + rocml.smi_initialize() + gpu_name = rocml.smi_get_device_name(0).upper() + gpu_memory_info = rocml.smi_get_device_memory_total(0) + + if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}") + + rocml.smi_shutdown() + return DeviceCapabilities( - model="Linux Box (AMD)", - chip="Unknown AMD", - memory=psutil.virtual_memory().total // 2**20, + model=f"Linux Box ({gpu_name})", + chip=gpu_name, + memory=gpu_memory_info.total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0), ) + else: return DeviceCapabilities( model=f"Linux Box (Device: {Device.DEFAULT})", @@ -208,3 +223,74 @@ async def linux_device_capabilities() -> DeviceCapabilities: memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0), ) + + +async def windows_device_capabilities() -> DeviceCapabilities: + import psutil + + def get_gpu_info(): + import win32com.client # install pywin32 + + wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2") + gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController") + + gpu_info = [] + for gpu in gpus: + info = { + "Name": gpu.Name, + "AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow) + "DriverVersion": gpu.DriverVersion, + "VideoProcessor": gpu.VideoProcessor + } + gpu_info.append(info) + + return gpu_info + + gpus_info = get_gpu_info() + gpu_names = [gpu['Name'] for gpu in gpus_info] + + contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names) + contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names) + + if contains_nvidia: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper() + gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name + gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}") + + return DeviceCapabilities( + model=f"Windows Box ({gpu_name})", + chip=gpu_name, + memory=gpu_memory_info.total // 2**20, + flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)), + ) + elif contains_amd: + # For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi) + from pyrsmi import rocml + + rocml.smi_initialize() + gpu_name = rocml.smi_get_device_name(0).upper() + gpu_memory_info = 
rocml.smi_get_device_memory_total(0) + + if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}") + + rocml.smi_shutdown() + + return DeviceCapabilities( + model=f"Windows Box ({gpu_name})", + chip=gpu_name, + memory=gpu_memory_info.total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) + else: + return DeviceCapabilities( + model=f"Windows Box (Device: Unknown)", + chip=f"Unknown Chip (Device(s): {gpu_names})", + memory=psutil.virtual_memory().total // 2**20, + flops=DeviceFlops(fp32=0, fp16=0, int8=0), + ) diff --git a/scripts/build_exo.py b/scripts/build_exo.py index 78a612496..618dca48c 100644 --- a/scripts/build_exo.py +++ b/scripts/build_exo.py @@ -6,6 +6,9 @@ def run(): site_packages = site.getsitepackages()[0] + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages") + command = [ f"{sys.executable}", "-m", "nuitka", "exo/main.py", "--company-name=exolabs", @@ -15,7 +18,8 @@ def run(): "--standalone", "--output-filename=exo", "--python-flag=no_site", - "--onefile" + "--onefile", + f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages" ] if sys.platform == "darwin": @@ -23,7 +27,7 @@ def run(): "--macos-app-name=exo", "--macos-app-mode=gui", "--macos-app-version=0.0.1", - "--macos-signed-app-name=com.exolabs.exo", + "--macos-signed-app-name=net.exolabs.exo", "--include-distribution-meta=mlx", "--include-module=mlx._reprlib_fix", "--include-module=mlx._os_warning", diff --git a/setup.py b/setup.py index 4b3720a28..93028202b 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ import sys import platform +import subprocess from setuptools import find_packages, setup @@ -11,7 +12,6 @@ "grpcio==1.68.0", "grpcio-tools==1.68.0", "Jinja2==3.1.4", - "netifaces==0.11.0", "numpy==2.0.0", "nuitka==2.5.1", "nvidia-ml-py==12.560.30", @@ -23,6 +23,7 @@ "pydantic==2.9.2", "requests==2.32.3", "rich==13.7.1", + "scapy==2.6.1", "tenacity==9.0.0", "tqdm==4.66.4", "transformers==4.46.3", @@ -32,19 +33,51 @@ ] extras_require = { - "formatting": [ - "yapf==0.40.2", - ], + "formatting": ["yapf==0.40.2",], "apple_silicon": [ "mlx==0.21.1", "mlx-lm==0.20.4", ], + "windows": ["pywin32==308",], + "nvidia-gpu": ["nvidia-ml-py==12.560.30",], + "amd-gpu": ["pyrsmi==0.2.0"], } # Check if running on macOS with Apple Silicon if sys.platform.startswith("darwin") and platform.machine() == "arm64": install_requires.extend(extras_require["apple_silicon"]) +# Check if running Windows +if sys.platform.startswith("win32"): + install_requires.extend(extras_require["windows"]) + + +def _add_gpu_requires(): + global install_requires + # Add Nvidia-GPU + try: + out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False) + if out.returncode == 0: + install_requires.extend(extras_require["nvidia-gpu"]) + except subprocess.CalledProcessError: + pass + + # Add AMD-GPU + # This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows + try: + out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False) + if out.returncode == 0: + install_requires.extend(extras_require["amd-gpu"]) + except: + out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False) + if out.returncode == 0: + install_requires.extend(extras_require["amd-gpu"]) + finally: + pass + + +_add_gpu_requires() + setup( name="exo", version="0.0.1", diff --git 
a/test/test_tokenizers.py b/test/test_tokenizers.py index 3635357f4..7c12902be 100644 --- a/test/test_tokenizers.py +++ b/test/test_tokenizers.py @@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False): strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id])) assert text == strip_tokens(decoded) == strip_tokens(reconstructed) -ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit"] +ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"] ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")") models = [] for model_id in model_cards:
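Note (not part of the patch): the interface change running through the hunks above is that `InferenceEngine.infer_prompt` and `infer_tensor` now return a `(output, inference_state)` tuple instead of a bare array, so every call site unpacks two values. A minimal sketch of the new calling convention, mirroring the updated dummy-engine test; it assumes an exo checkout with this patch applied:

```python
import asyncio
import numpy as np

from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard


async def main():
  engine = DummyInferenceEngine()  # no shard downloader argument needed after this patch
  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)

  # infer_prompt now yields (output, inference_state); the state may be None
  output, state = await engine.infer_prompt("req-1", shard, "Test prompt")
  assert isinstance(output, np.ndarray) and output.ndim == 2

  # infer_tensor accepts an optional inference_state and returns the updated one;
  # the dummy engine simply increments the tokens on its last layer
  next_output, state = await engine.infer_tensor("req-1", shard, output, inference_state=state)
  print(next_output.shape, state)


if __name__ == "__main__":
  asyncio.run(main())
```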