22import json
33import copy
44import asyncio
5- import aiohttp
65import aiofiles
76import termcolor
87import os
1312from fastapi import APIRouter , HTTPException , Query , Header
1413from fastapi .responses import Response , StreamingResponse
1514
15+ from itertools import chain
16+
1617from refact_utils .scripts import env
1718from refact_utils .finetune .utils import running_models_and_loras
1819from refact_utils .third_party .utils .models import available_third_party_models
@@ -250,8 +251,9 @@ def _select_default_model(models: List[str]) -> str:
250251 # completion models
251252 completion_models = {}
252253 for model_name in running_models .get ("completion" , []):
253- if model_info := self ._model_assigner .models_db .get (_get_base_model_info (model_name )):
254- completion_models [model_name ] = self ._model_assigner .to_completion_model_record (model_info )
254+ base_model_name = _get_base_model_info (model_name )
255+ if model_info := self ._model_assigner .models_db .get (base_model_name ):
256+ completion_models [model_name ] = self ._model_assigner .to_completion_model_record (base_model_name , model_info )
255257 elif model := available_third_party_models ().get (model_name ):
256258 completion_models [model_name ] = model .to_completion_model_record ()
257259 else :
@@ -261,8 +263,9 @@ def _select_default_model(models: List[str]) -> str:
261263 # chat models
262264 chat_models = {}
263265 for model_name in running_models .get ("chat" , []):
264- if model_info := self ._model_assigner .models_db .get (_get_base_model_info (model_name )):
265- chat_models [model_name ] = self ._model_assigner .to_chat_model_record (model_info )
266+ base_model_name = _get_base_model_info (model_name )
267+ if model_info := self ._model_assigner .models_db .get (base_model_name ):
268+ chat_models [model_name ] = self ._model_assigner .to_chat_model_record (base_model_name , model_info )
266269 elif model := available_third_party_models ().get (model_name ):
267270 chat_models [model_name ] = model .to_chat_model_record ()
268271 else :
@@ -337,6 +340,18 @@ def _parse_client_version(user_agent: str = Header(None)) -> Optional[Tuple[int,
337340
338341 @staticmethod
339342 def _to_deprecated_caps_format (data : Dict [str , Any ]):
# Build the legacy "models_dict_patch" for the deprecated caps format:
# old clients expect per-model n_ctx / supports_tools hints merged via a
# patch dict rather than read from the new per-section model records.
models_dict_patch = {}
for model_name, model_record in chain(
    data["completion"]["models"].items(),
    # BUG FIX: the second iterable was data["completion"]["models"] again,
    # so chat models were never scanned and supports_tools (a chat-model
    # attribute) could never reach the patch dict.
    data["chat"]["models"].items(),
):
    dict_patch = {}
    # Only carry over keys that are actually present and truthy,
    # so models without hints get no (empty) patch entry.
    if n_ctx := model_record.get("n_ctx"):
        dict_patch["n_ctx"] = n_ctx
    if supports_tools := model_record.get("supports_tools"):
        dict_patch["supports_tools"] = supports_tools
    if dict_patch:
        models_dict_patch[model_name] = dict_patch
340355 return {
341356 "cloud_name" : data ["cloud_name" ],
342357 "endpoint_template" : data ["completion" ]["endpoint" ],
@@ -349,7 +364,7 @@ def _to_deprecated_caps_format(data: Dict[str, Any]):
349364 "code_completion_default_model" : data ["completion" ]["default_model" ],
350365 "multiline_code_completion_default_model" : data ["completion" ]["default_multiline_model" ],
351366 "code_chat_default_model" : data ["chat" ]["default_model" ],
352- "models_dict_patch" : {}, # NOTE: this actually should have n_ctx, but we're skiping it
367+ "models_dict_patch" : models_dict_patch ,
353368 "default_embeddings_model" : data ["embedding" ]["default_model" ],
354369 "endpoint_embeddings_template" : "v1/embeddings" ,
355370 "endpoint_embeddings_style" : "openai" ,
0 commit comments