
Commit 1696db6

working worker bee
1 parent 24d354a commit 1696db6

File tree: 5 files changed, +54 -62 lines

ai_worker/main.py

+16 -50

```diff
@@ -2,11 +2,9 @@
 import asyncio
 import json
 import multiprocessing
-import os
 from typing import Optional
 import logging as log

-import httpx
 import psutil
 import sseclient
 import websockets
@@ -17,9 +15,10 @@
 from fastapi.testclient import TestClient
 from starlette.responses import Response

+from gguf_loader.main import get_size
+
 APP_NAME= "gputopia"
-DEFAULT_COORDINATOR = "https://gputopia.ai/api/v1"
-DEFAULT_BASE_URL = "https://gputopia.ai/models"
+DEFAULT_COORDINATOR = "wss://gputopia.ai/api/v1"


 class Req(BaseModel):
@@ -31,8 +30,6 @@ class Config(BaseSettings):
     model_config = SettingsConfigDict(env_prefix=APP_NAME +'_worker', case_sensitive=False)
     auth_key: str = ""
     coordinator_url: str = DEFAULT_COORDINATOR
-    model_base_url: str = DEFAULT_BASE_URL
-    model_dir: str = os.path.expanduser('~/.ai-models')


 class WorkerMain:
@@ -53,13 +50,13 @@ async def run(self):

     async def guess_layers(self, model_path):
         # todo: read model file and compare to gpu resources
-        return 30
+        return 20

     async def load_model(self, name):
         if name == self.llama_model:
             return
         model_path = await self.get_model(name)
-        settings = LlamaSettings(model=model_path, n_gpu_layers=self.guess_layers(model_path), seed=-1, embedding=True, cache=True, port=8181)
+        settings = LlamaSettings(model=model_path, n_gpu_layers=await self.guess_layers(model_path), seed=-1, embedding=True, cache=True, port=8181)
         self.llama = create_llama_app(settings)
         self.llama_cli = TestClient(self.llama)

@@ -112,56 +109,26 @@ async def run_ws(self, ws):
         ws.send(res.body.decode("urf-8"))

     async def get_model(self, name):
-        ret = self.get_local_model(name)
-        if ret:
-            return ret
         return await self.download_model(name)

-    def get_local_model(self, name):
-        dest = self.model_file_for(name)
-        if os.path.getsize(dest) > 0:
-            return dest
-        return None
-
-    def model_file_for(self, name):
-        return self.conf.model_dir + "/" + name.replace("/", ".")
-
     async def download_model(self, name):
-        url = self.conf.model_base_url + "/" + name.replace("/", ".")
-
-        async with httpx.AsyncClient() as client:
-            r = await client.head(url)
-            size = r.headers.get('Content-Length')
-            if not size:
-                params = self.get_model_params(name)
-                bits = self.get_model_bits(name)
-                # 70b * 4 bit = 35gb (roughly)
-                size = params * bits / 8
-
-            assert size, "unable to estimate model size, not downloading"
-
-            await self.free_up_space(size)
-
-            dest = self.model_file_for(name)
-
-            done = 0
-            with open(dest + ".tmp", "wb") as f:
-                async with client.stream("GET", url) as r:
-                    async for chunk in r.aiter_bytes():
-                        f.write(chunk)
-                        done += len(chunk)
-                        self.report_pct(name, done/size)
-            os.replace(dest + ".tmp", dest)
-            self.report_done(name)
-
-            return dest
+        # uses hf cache, so no need to handle here
+        from gguf_loader.main import download_gguf
+        size = get_size(name)
+        await self.free_up_space(size)
+        loop = asyncio.get_running_loop()
+        path = await loop.run_in_executor(None, lambda: download_gguf(name))
+        return path

     def report_done(self, name):
         print("\r", name, 100)

     def report_pct(self, name, pct):
         print("\r", name, pct, end='')

+    async def free_up_space(self, size):
+        pass
+

 def main():
     parser = argparse.ArgumentParser()
@@ -181,7 +148,6 @@ def main():

     conf = Config(**{k: v for k, v in vars(args).items() if v is not None})

-
     wm = WorkerMain(conf)

-    asyncio.run(wm.main())
+    asyncio.run(wm.run())
```
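The rewritten `download_model` hands the blocking Hugging Face download to a thread pool via `run_in_executor`, so the worker's event loop stays responsive while a model streams in. A minimal sketch of that pattern (the `slow_download` helper is a hypothetical stand-in, not the repo's `download_gguf`):

```python
import asyncio
import time


def slow_download(name: str) -> str:
    # Hypothetical stand-in for a blocking call such as download_gguf();
    # it blocks its thread, so it must not run directly on the event loop.
    time.sleep(1)
    return f"/tmp/{name}.gguf"


async def fetch_model(name: str) -> str:
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor; the coroutine awaits
    # while the blocking download runs in a worker thread.
    return await loop.run_in_executor(None, lambda: slow_download(name))


if __name__ == "__main__":
    print(asyncio.run(fetch_model("example-model")))
```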

gguf_loader/main.py

+21 -10

```diff
@@ -22,7 +22,12 @@ def convert_to_gguf(file):
     return dest


-def download_gguf(name):
+def get_size(name):
+    typ, hf, fil = pick_file(name)
+    return fil["size"]
+
+
+def pick_file(name):
     parts = name.split(":", 1)
     if len(parts) == 1:
         hf, filt = parts[0], ""
@@ -50,26 +55,32 @@ def download_gguf(name):
             # this is all heuristics, but, imo it can be more than good enough
             raise ValueError("Need ggml or gguf")

-        base = os.path.basename(ggml[0]["name"])
+        return "ggml", hf, ggml[0]

-        log.debug("downloading...")
+    if len(gguf) > 1:
+        raise ValueError("Multiple files match, please specify a better filter")

-        # use hf so we get a nice cache
-        path = hf_hub_download(repo_id=hf, filename=base, resume_download=True)
+    return "gguf", hf, gguf[0]

-        return convert_to_gguf(path)

-        if len(gguf) > 1:
-            raise ValueError("Multiple files match, please specify a better filter")
+def download_gguf(name):
+    typ, repo_id, fil = pick_file(name)
+    if typ == "ggml":
+        base = os.path.basename(fil["name"])
+        log.debug("downloading...")
+        # use hf so we get a nice cache
+        path = hf_hub_download(repo_id=repo_id, filename=base, resume_download=True)
+        return convert_to_gguf(path)

-    base = os.path.basename(gguf[0]["name"])
+    base = os.path.basename(fil["name"])
     log.debug("downloading...")
-    return hf_hub_download(repo_id=hf, filename=base)
+    return hf_hub_download(repo_id=repo_id, filename=base)


 # Load environment variables from .env file
 load_dotenv()

+
 # Get AWS credentials from environment variables

```
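Splitting `pick_file` out of `download_gguf` lets callers ask for a model's size before fetching it, which is what the worker now does ahead of `free_up_space`. A hedged usage sketch, assuming the `gguf_loader` package is on the path and the `repo_id:filter` naming used in the tests:

```python
# Usage sketch; imports the repo's own gguf_loader module.
from gguf_loader.main import get_size, download_gguf

name = "TheBloke/WizardLM-7B-uncensored-GGML:q4_K_M"

# pick_file() resolves the matching repo file first, so the size is known
# (assumed here to be bytes from the repo metadata) before any transfer starts.
size = get_size(name)
print("approx download size:", size)

# download_gguf() fetches through the Hugging Face cache and converts
# ggml files to gguf when needed, returning a local file path.
path = download_gguf(name)
print(path)
```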
pytest.ini

+2 -0

```diff
@@ -0,0 +1,2 @@
+[pytest]
+asyncio_mode = auto
```
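`asyncio_mode = auto` tells pytest-asyncio to collect bare `async def` tests (such as the new `test_wm` below) without per-test markers. A small illustration, assuming pytest-asyncio is installed:

```python
import asyncio


# With asyncio_mode = auto, pytest-asyncio collects this bare async test.
# Under the default "strict" mode it would need @pytest.mark.asyncio.
async def test_event_loop_available():
    await asyncio.sleep(0)
    assert True
```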
Binary file not shown.

tests/test_conn.py

+15 -2

```diff
@@ -1,7 +1,7 @@
 import json

 from ai_worker.main import WorkerMain, Config
-from gguf_loader.main import download_gguf, main as loader_main
+from gguf_loader.main import download_gguf, main as loader_main, get_size

 try:
     from pynvml.smi import nvidia_smi
@@ -21,8 +21,21 @@ def test_conn_str():
     assert js["vram"]


+async def test_wm():
+    wm = WorkerMain(Config())
+    await wm.load_model("TheBloke/WizardLM-7B-uncensored-GGML:q4_K_M")
+    res = wm.llama_cli.post("/v1/chat/completions", json=dict(
+        model=wm.llama_model,
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": "hello"},
+        ]
+    ))
+    assert res
+
 def test_download_model():
-    download_gguf("TheBloke/WizardLM-7B-uncensored-GGML:q4_K_M")
+    assert get_size("TheBloke/WizardLM-7B-uncensored-GGML:q4_K_M") > 0
+    assert download_gguf("TheBloke/WizardLM-7B-uncensored-GGML:q4_K_M")


 def test_download_main(capsys):
```
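`test_wm` only asserts that a response object came back. If a stricter check is wanted later, the OpenAI-style body returned by the llama-cpp-python server could be inspected; a hypothetical tightening, not part of this commit:

```python
# Hypothetical helper, not in the commit: tighten the assertion on the
# TestClient response, assuming the usual OpenAI-style chat completion body.
def assert_chat_reply(res):
    assert res.status_code == 200
    body = res.json()
    # expect at least one choice carrying a non-empty assistant message
    assert body["choices"][0]["message"]["content"]
```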
