27
27
#==================== Cache class definition =========================#
28
28
#=====================================================================#
29
29
30
- executor = ThreadPoolExecutor (max_workers = 2 )
31
-
32
- def response_text (cache_resp ):
33
- return cache_resp ['data' ]
34
-
35
- def response_hitquery (cache_resp ):
36
- return cache_resp ['hitQuery' ]
37
30
38
31
# noinspection PyMethodMayBeStatic
39
32
class Cache :
@@ -80,11 +73,16 @@ def close():
80
73
modelcache_log .error (e )
81
74
82
75
def save_query_resp (self , query_resp_dict , ** kwargs ):
83
- self .data_manager .save_query_resp (query_resp_dict , ** kwargs )
76
+ asyncio .create_task (asyncio .to_thread (
77
+ self .data_manager .save_query_resp ,
78
+ query_resp_dict , ** kwargs
79
+ ))
84
80
85
81
def save_query_info (self ,result , model , query , delta_time_log ):
86
- self .data_manager .save_query_resp (result , model = model , query = json .dumps (query , ensure_ascii = False ),
87
- delta_time = delta_time_log )
82
+ asyncio .create_task (asyncio .to_thread (
83
+ self .data_manager .save_query_resp ,
84
+ result , model = model , query = json .dumps (query , ensure_ascii = False ), delta_time = delta_time_log
85
+ ))
88
86
89
87
async def handle_request (self , param_dict : dict ):
90
88
# param parsing
@@ -103,7 +101,7 @@ async def handle_request(self, param_dict: dict):
103
101
result = {"errorCode" : 102 ,
104
102
"errorDesc" : "type exception, should one of ['query', 'insert', 'remove', 'register']" ,
105
103
"cacheHit" : False , "delta_time" : 0 , "hit_query" : '' , "answer" : '' }
106
- self .data_manager . save_query_resp (result , model = model , query = '' , delta_time = 0 )
104
+ self .save_query_resp (result , model = model , query = '' , delta_time = 0 )
107
105
return result
108
106
except Exception as e :
109
107
return {"errorCode" : 103 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 , "hit_query" : '' ,
@@ -120,14 +118,14 @@ async def handle_request(self, param_dict: dict):
120
118
elif request_type == 'insert' :
121
119
return await self .handle_insert (chat_info , model )
122
120
elif request_type == 'remove' :
123
- return self .handle_remove (model , param_dict )
121
+ return await self .handle_remove (model , param_dict )
124
122
elif request_type == 'register' :
125
- return self .handle_register (model )
123
+ return await self .handle_register (model )
126
124
else :
127
125
return {"errorCode" : 400 , "errorDesc" : "bad request" }
128
126
129
- def handle_register (self , model ):
130
- response = adapter .ChatCompletion .create_register (
127
+ async def handle_register (self , model ):
128
+ response = await adapter .ChatCompletion .create_register (
131
129
model = model ,
132
130
cache_obj = self
133
131
)
@@ -137,10 +135,10 @@ def handle_register(self, model):
137
135
result = {"errorCode" : 502 , "errorDesc" : "" , "response" : response , "writeStatus" : "exception" }
138
136
return result
139
137
140
- def handle_remove (self , model , param_dict ):
138
+ async def handle_remove (self , model , param_dict ):
141
139
remove_type = param_dict .get ("remove_type" )
142
140
id_list = param_dict .get ("id_list" , [])
143
- response = adapter .ChatCompletion .create_remove (
141
+ response = await adapter .ChatCompletion .create_remove (
144
142
model = model ,
145
143
remove_type = remove_type ,
146
144
id_list = id_list ,
@@ -191,12 +189,12 @@ async def handle_query(self, model, query):
191
189
result = {"errorCode" : 201 , "errorDesc" : response , "cacheHit" : False , "delta_time" : delta_time ,
192
190
"hit_query" : '' , "answer" : '' }
193
191
else :
194
- answer = response_text ( response )
195
- hit_query = response_hitquery ( response )
192
+ answer = response [ 'data' ]
193
+ hit_query = response [ 'hitQuery' ]
196
194
result = {"errorCode" : 0 , "errorDesc" : '' , "cacheHit" : True , "delta_time" : delta_time ,
197
195
"hit_query" : hit_query , "answer" : answer }
198
196
delta_time_log = round (time .time () - start_time , 2 )
199
- executor . submit ( self .save_query_info , result , model , query , delta_time_log )
197
+ self .save_query_info ( result , model , query , delta_time_log )
200
198
except Exception as e :
201
199
result = {"errorCode" : 202 , "errorDesc" : str (e ), "cacheHit" : False , "delta_time" : 0 ,
202
200
"hit_query" : '' , "answer" : '' }
@@ -265,7 +263,9 @@ async def init(
265
263
#==================================================#
266
264
267
265
# switching based on embedding_model
268
- if embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MPNET_BASE_V2 :
266
+ if (embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MPNET_BASE_V2
267
+ or embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MINILM_L6_V2
268
+ or embedding_model == EmbeddingModel .HUGGINGFACE_ALL_MINILM_L12_V2 ):
269
269
query_pre_embedding_func = query_with_role
270
270
insert_pre_embedding_func = query_with_role
271
271
post_process_messages_func = first
@@ -287,8 +287,8 @@ async def init(
287
287
288
288
# add more configurations for other embedding models as needed
289
289
else :
290
- modelcache_log .error (f"Please add configuration for { embedding_model } in modelcache/__init__ .py." )
291
- raise CacheError (f"Please add configuration for { embedding_model } in modelcache/__init__ .py." )
290
+ modelcache_log .error (f"Please add configuration for { embedding_model } in modelcache/cache .py." )
291
+ raise CacheError (f"Please add configuration for { embedding_model } in modelcache/cache .py." )
292
292
293
293
# ====================== Data manager ==============================#
294
294
@@ -300,7 +300,7 @@ async def init(
300
300
config = vector_config ,
301
301
metric_type = similarity_metric_type ,
302
302
),
303
- eviction = 'ARC' ,
303
+ memory_cache_policy = 'ARC' ,
304
304
max_size = 10000 ,
305
305
normalize = normalize ,
306
306
)
0 commit comments