
Commit

Merge remote-tracking branch 'origin/main' into runners2
AlexCheema committed Jan 22, 2025
2 parents 6b8cd05 + 07ceb19 commit 461e4f3
Showing 17 changed files with 402 additions and 286 deletions.
300 changes: 148 additions & 152 deletions exo/api/chatgpt_api.py

Large diffs are not rendered by default.

17 changes: 12 additions & 5 deletions exo/apputil/anim.py
@@ -2,6 +2,7 @@
import os
import numpy as np
import cv2
+ import sys

def draw_rounded_rectangle(draw, coords, radius, fill):
left, top, right, bottom = coords
@@ -80,14 +81,20 @@ def create_animation_mp4(
font = ImageFont.load_default()
promptfont = ImageFont.load_default()

+ # Get the base directory for images when running as a bundled app
+ if hasattr(sys, '_MEIPASS'):
+   base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
+ else:
+   base_dir = os.path.join(os.path.dirname(__file__), "baseimages")

# Process first frame
- base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
+ base_img = Image.open(os.path.join(base_dir, "image1.png"))
draw = ImageDraw.Draw(base_img)
draw_centered_text_rounded(draw, device_name, font, device_coords)
frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps

# Process second frame with typing animation
- base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
+ base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
for i in range(len(prompt_text) + 1):
current_frame = base_img2.copy()
draw = ImageDraw.Draw(current_frame)
@@ -101,7 +108,7 @@

# Create blur sequence
replacement_img = Image.open(replacement_image_path)
- base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
+ base_img = Image.open(os.path.join(base_dir, "image3.png"))
blur_steps = [int(80 * (1 - i/8)) for i in range(9)]

for i, blur_amount in enumerate(blur_steps):
@@ -123,7 +130,7 @@
frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps

# Create and add final frame (image4)
- final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
+ final_base = Image.open(os.path.join(base_dir, "image4.png"))
draw = ImageDraw.Draw(final_base)

draw_centered_text_rounded(draw, device_name, font, device_coords)
@@ -158,4 +165,4 @@ def create_animation_mp4(
out.write(frame_array)

out.release()
print(f"Video saved successfully to {output_path}")
print(f"Video saved successfully to {output_path}")
50 changes: 25 additions & 25 deletions exo/helpers.py
@@ -7,7 +7,8 @@
import platform
import psutil
import uuid
- import netifaces
+ from scapy.all import get_if_addr, get_if_list
+ import re
import subprocess
from pathlib import Path
import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
- for interface in netifaces.interfaces():
-   ifaddresses = netifaces.ifaddresses(interface)
-   if netifaces.AF_INET in ifaddresses:
-     for link in ifaddresses[netifaces.AF_INET]:
-       ip = link['addr']
-       ip_addresses.append((ip, interface))
+ for interface in get_if_list():
+   ip = get_if_addr(interface)
+   # Include all addresses, including loopback
+   # Filter out link-local addresses
+   if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
+     # Remove "\\Device\\NPF_" prefix from interface name
+     simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
+     ip_addresses.append((ip, simplified_interface))
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]


async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
- output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
-   ['system_profiler', 'SPNetworkDataType', '-json'],
-   capture_output=True,
-   text=True,
-   close_fds=True
- ).stdout)
+ output = await asyncio.get_running_loop().run_in_executor(
+   subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
+ )

data = json.loads(output)

@@ -276,15 +277,15 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:

return None


async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
macos_type = await get_macos_interface_type(ifname)
if macos_type is not None: return macos_type

# Local container/virtual interfaces
- if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
-     'bridge' in ifname):
+ if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")

# Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")


async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -353,18 +355,16 @@ async def get_mac_system_info() -> Tuple[str, str, int]:
return "Unknown Model", "Unknown Chip", 0

def get_exo_home() -> Path:
if os.name == "nt": # Check if the OS is Windows
docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else:
docs_folder = Path.home() / "Documents"
exo_folder = docs_folder / "Exo"
if not exo_folder.exists():
exo_folder.mkdir()
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder


def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
if not images_dir.exists():
images_dir.mkdir()
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir
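
The interface-enumeration rewrite above swaps the netifaces dependency for scapy. Unlike netifaces, which returns every AF_INET address per interface, scapy's get_if_addr returns a single IPv4 address per interface ("0.0.0.0" when none is assigned), which is why the new code filters the '0.0.' and link-local '169.254.' prefixes. A quick way to inspect what it yields (illustrative only; assumes scapy is installed):

from scapy.all import get_if_addr, get_if_list

for iface in get_if_list():
  # One IPv4 address per interface; "0.0.0.0" means no address assigned.
  print(iface, get_if_addr(iface))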
10 changes: 5 additions & 5 deletions exo/inference/debug_inference_engine.py
@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)

- next_resp_full = await inference_engine_1.infer_tensor(
+ next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=token_full,
)

- resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
- resp2 = await inference_engine_2.infer_tensor(
+ resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
+ resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp1,
)
token2 = await inference_engine_2.sample(resp2)
- resp3 = await inference_engine_1.infer_tensor(
+ resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
input_data=token2,
)
- resp4 = await inference_engine_2.infer_tensor(
+ resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp3,
4 changes: 2 additions & 2 deletions exo/inference/dummy_inference_engine.py
@@ -25,9 +25,9 @@ async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) ->
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
return self.tokenizer.decode(tokens)

- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
- return input_data + 1 if self.shard.is_last_layer() else input_data
+ return input_data + 1 if self.shard.is_last_layer() else input_data, None

async def ensure_shard(self, shard: Shard):
if self.shard == shard: return
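
A subtlety in the new dummy return value: the conditional expression binds tighter than the comma, so the changed line returns a two-element tuple, i.e. it parses as:

return (input_data + 1 if self.shard.is_last_layer() else input_data), None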
17 changes: 10 additions & 7 deletions exo/inference/inference_engine.py
@@ -5,6 +5,7 @@
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
+ from exo.download.shard_download import ShardDownloader


class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass

@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -23,7 +24,7 @@ async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
pass

@abstractmethod
- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
pass

@abstractmethod
@@ -32,14 +33,14 @@ async def load_checkpoint(self, shard: Shard, path: str):

async def save_checkpoint(self, shard: Shard, path: str):
pass

async def save_session(self, key, value):
self.session[key] = value

async def clear_session(self):
self.session.empty()
- async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> np.ndarray:
+
+ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
@@ -49,13 +50,15 @@ async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inferen

return output_data, inference_state


inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
}

- def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
+
+ def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":
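
The recurring change across these files is the InferenceEngine contract: infer_tensor (and infer_prompt) now accept an optional inference_state dict and return an (output, state) pair instead of a bare array. A minimal sketch of a conforming engine, assuming the abstract methods visible in this excerpt (EchoEngine is a hypothetical illustration, not part of the codebase):

import numpy as np
from typing import Optional

class EchoEngine(InferenceEngine):
  async def encode(self, shard, prompt: str) -> np.ndarray:
    return np.frombuffer(prompt.encode(), dtype=np.uint8).reshape(1, -1)

  async def sample(self, x: np.ndarray) -> np.ndarray:
    return x

  async def decode(self, shard, tokens: np.ndarray) -> str:
    return tokens.astype(np.uint8).tobytes().decode(errors="ignore")

  async def load_checkpoint(self, shard, path: str):
    pass

  async def ensure_shard(self, shard):
    pass

  async def infer_tensor(self, request_id: str, shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
    state = dict(inference_state or {})
    state["calls"] = state.get("calls", 0) + 1  # thread state back to the caller
    return input_data, state  # (output, new_state), per the updated contract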
6 changes: 3 additions & 3 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -67,15 +67,15 @@ async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
self.model.load_weights(path)

- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
if self.model.model_type != 'StableDiffusionPipeline':
- output_data = self.model(x, **state, **inference_state)
+ output_data = self.model(x, **state, **(inference_state or {}))
else:
- output_data, inference_state = self.model(x, **state, **inference_state)
+ output_data, inference_state = self.model(x, **state, **(inference_state or {}))
output_data = np.array(output_data, copy=False)
return output_data, inference_state

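
The **(inference_state or {}) guard matters because inference_state defaults to None and Python requires a mapping after **. A two-line illustration:

def f(**kwargs): return kwargs

state = None
# f(**state)        # TypeError: argument after ** must be a mapping, not NoneType
f(**(state or {}))  # safe: falls back to an empty dict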
16 changes: 5 additions & 11 deletions exo/inference/test_dummy_inference_engine.py
@@ -1,22 +1,16 @@
import pytest
- import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard


- class MockShardDownloader:
-   async def ensure_shard(self, shard):
-     pass


@pytest.mark.asyncio
async def test_dummy_inference_specific():
- engine = DummyInferenceEngine(MockShardDownloader())
+ engine = DummyInferenceEngine()
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
test_prompt = "This is a test prompt"

- result = await engine.infer_prompt("test_request", test_shard, test_prompt)
+ result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)

print(f"Inference result shape: {result.shape}")

@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
@pytest.mark.asyncio
async def test_dummy_inference_engine():
# Initialize the DummyInferenceEngine
- engine = DummyInferenceEngine(MockShardDownloader())
+ engine = DummyInferenceEngine()

# Create a test shard
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)

# Test infer_prompt
- output = await engine.infer_prompt("test_id", shard, "Test prompt")
+ output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")

assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"

# Test infer_tensor
input_tensor = np.array([[1, 2, 3]])
- output = await engine.infer_tensor("test_id", shard, input_tensor)
+ output, _ = await engine.infer_tensor("test_id", shard, input_tensor)

assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
12 changes: 6 additions & 6 deletions exo/inference/test_inference_engine.py
@@ -11,30 +11,30 @@
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
prompt = "In a single word only, what is the last name of the current president of the USA?"
- resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
+ resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)
token_full = token_full.reshape(1, -1)
- next_resp_full = await inference_engine_1.infer_tensor(
+ next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
input_data=token_full,
)

pp = n_layers // 2
- resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
- resp2 = await inference_engine_2.infer_tensor(
+ resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
+ resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp1,
)
tokens2 = await inference_engine_1.sample(resp2)
tokens2 = tokens2.reshape(1, -1)
- resp3 = await inference_engine_1.infer_tensor(
+ resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
input_data=tokens2,
)
- resp4 = await inference_engine_2.infer_tensor(
+ resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp3,
2 changes: 1 addition & 1 deletion exo/inference/tinygrad/inference.py
@@ -104,7 +104,7 @@ async def save_checkpoint(self, shard: Shard, path: str):
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
safe_save(state_dict, path)

- async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> np.ndarray:
+ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
def wrap_infer():
x = Tensor(input_data)
