2
2
import asyncio
3
3
import json
4
4
import multiprocessing
5
- import os
6
5
from typing import Optional
7
6
import logging as log
8
7
9
- import httpx
10
8
import psutil
11
9
import sseclient
12
10
import websockets
17
15
from fastapi .testclient import TestClient
18
16
from starlette .responses import Response
19
17
18
+ from gguf_loader .main import get_size
19
+
20
20
APP_NAME = "gputopia"
21
- DEFAULT_COORDINATOR = "https://gputopia.ai/api/v1"
22
- DEFAULT_BASE_URL = "https://gputopia.ai/models"
21
+ DEFAULT_COORDINATOR = "wss://gputopia.ai/api/v1"
23
22
24
23
25
24
class Req (BaseModel ):
@@ -31,8 +30,6 @@ class Config(BaseSettings):
31
30
model_config = SettingsConfigDict (env_prefix = APP_NAME + '_worker' , case_sensitive = False )
32
31
auth_key : str = ""
33
32
coordinator_url : str = DEFAULT_COORDINATOR
34
- model_base_url : str = DEFAULT_BASE_URL
35
- model_dir : str = os .path .expanduser ('~/.ai-models' )
36
33
37
34
38
35
class WorkerMain :
@@ -53,13 +50,13 @@ async def run(self):
53
50
54
51
async def guess_layers(self, model_path):
    """Return how many model layers to offload to the GPU.

    TODO: inspect the model file at *model_path* and compare it against
    the available GPU resources instead of using a fixed count.
    """
    return 20
57
54
58
55
async def load_model (self , name ):
59
56
if name == self .llama_model :
60
57
return
61
58
model_path = await self .get_model (name )
62
- settings = LlamaSettings (model = model_path , n_gpu_layers = self .guess_layers (model_path ), seed = - 1 , embedding = True , cache = True , port = 8181 )
59
+ settings = LlamaSettings (model = model_path , n_gpu_layers = await self .guess_layers (model_path ), seed = - 1 , embedding = True , cache = True , port = 8181 )
63
60
self .llama = create_llama_app (settings )
64
61
self .llama_cli = TestClient (self .llama )
65
62
@@ -112,56 +109,26 @@ async def run_ws(self, ws):
112
109
# Fix: "urf-8" is not a registered codec name — bytes.decode("urf-8") raises
# LookupError at runtime; the intended encoding is "utf-8".
# NOTE(review): if `ws` is a `websockets` connection, send() is a coroutine
# and would need an `await` — confirm against the caller; left as-is here.
ws.send(res.body.decode("utf-8"))
113
110
114
111
async def get_model(self, name):
    """Resolve model *name* to a local file path, downloading it if necessary.

    Delegates unconditionally to download_model, which relies on the HF
    cache, so no local-file check is needed here.
    """
    return await self.download_model(name)
119
113
120
- def get_local_model (self , name ):
121
- dest = self .model_file_for (name )
122
- if os .path .getsize (dest ) > 0 :
123
- return dest
124
- return None
125
-
126
- def model_file_for (self , name ):
127
- return self .conf .model_dir + "/" + name .replace ("/" , "." )
128
-
129
114
async def download_model(self, name):
    """Fetch the GGUF model *name* and return its local file path.

    The download goes through the Hugging Face cache, so repeated calls
    for the same model reuse the already-downloaded file.
    """
    # Imported here (rather than at module top) to defer loading the helper.
    from gguf_loader.main import download_gguf

    # Make room before pulling the file (free_up_space is currently a no-op).
    await self.free_up_space(get_size(name))

    # download_gguf does blocking I/O, so run it in the default executor
    # instead of stalling the event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, download_gguf, name)
158
122
159
123
def report_done(self, name):
    """Print a final progress line showing *name* at 100 percent."""
    print("\r ", name, 100)
161
125
162
126
def report_pct(self, name, pct):
    """Print in-place progress for *name* as fraction *pct* (no newline,
    so the carriage return overwrites the previous update)."""
    print("\r ", name, pct, end='')
164
128
129
async def free_up_space(self, size):
    """Ensure at least *size* bytes are available for a new model download.

    NOTE(review): currently a stub — no eviction or disk check is performed.
    """
    return None
131
+
165
132
166
133
def main ():
167
134
parser = argparse .ArgumentParser ()
@@ -181,7 +148,6 @@ def main():
181
148
182
149
conf = Config (** {k : v for k , v in vars (args ).items () if v is not None })
183
150
184
-
185
151
wm = WorkerMain (conf )
186
152
187
- asyncio .run (wm .main ())
153
+ asyncio .run (wm .run ())
0 commit comments