from collections.abc import AsyncGenerator

from huggingface_hub import AsyncInferenceClient, HfApi
+from pydantic import SecretStr

from llama_stack.apis.common.content_types import (
    InterleavedContent,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
    build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
-    OpenAICompletionToLlamaStackMixin,
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
    completion_request_to_prompt_model_input_info,
@@ -73,26 +74,49 @@ def build_hf_repo_model_entries():


class _HfAdapter(
+    OpenAIMixin,
    Inference,
-    OpenAIChatCompletionToLlamaStackMixin,
-    OpenAICompletionToLlamaStackMixin,
    ModelsProtocolPrivate,
):
-    client: AsyncInferenceClient
+    url: str
+    api_key: SecretStr
+
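+    # Raw huggingface_hub client; renamed from `client`, presumably so it does not
+    # shadow the OpenAI-compatible `self.client` that OpenAIMixin provides.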
+    hf_client: AsyncInferenceClient
    max_tokens: int
    model_id: str

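+    # Assumption: this flag tells OpenAIMixin to substitute a generated
+    # completion id for the empty one TGI returns.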
+    overwrite_completion_id = True  # TGI always returns id=""
+
    def __init__(self) -> None:
        self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
        self.huggingface_repo_to_llama_model_id = {
            model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
        }

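+    # Hooks used by OpenAIMixin to construct its OpenAI-compatible client.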
+    def get_api_key(self):
+        return self.api_key.get_secret_value()
+
+    def get_base_url(self):
+        return self.url
+
    async def shutdown(self) -> None:
        pass

+    async def list_models(self) -> list[Model] | None:
+        models = []
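+        # NOTE: self.client here is the OpenAI-compatible client from
+        # OpenAIMixin, not the raw hf_client.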
+        async for model in self.client.models.list():
+            models.append(
+                Model(
+                    identifier=model.id,
+                    provider_resource_id=model.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+            )
+        return models
+
    async def register_model(self, model: Model) -> Model:
-        model = await self.register_helper.register_model(model)
        if model.provider_resource_id != self.model_id:
            raise ValueError(
                f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@@ -176,7 +200,7 @@ async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator
        params = await self._get_params_for_completion(request)

        async def _generate_and_convert_to_openai_compat():
-            s = await self.client.text_generation(**params)
+            s = await self.hf_client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token
                finish_reason = None
@@ -194,7 +218,7 @@ async def _generate_and_convert_to_openai_compat():

    async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params_for_completion(request)
-        r = await self.client.text_generation(**params)
+        r = await self.hf_client.text_generation(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r.details.finish_reason,
@@ -241,7 +265,7 @@ async def chat_completion(

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
-        r = await self.client.text_generation(**params)
+        r = await self.hf_client.text_generation(**params)

        choice = OpenAICompatCompletionChoice(
            finish_reason=r.details.finish_reason,
@@ -256,7 +280,7 @@ async def _stream_chat_completion(self, request: ChatCompletionRequest) -> Async
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
-            s = await self.client.text_generation(**params)
+            s = await self.hf_client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token

@@ -308,18 +332,21 @@ async def initialize(self, config: TGIImplConfig) -> None:
        if not config.url:
            raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
        log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
-        endpoint_info = await self.client.get_endpoint_info()
+        self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
+        endpoint_info = await self.hf_client.get_endpoint_info()
        self.max_tokens = endpoint_info["max_total_tokens"]
        self.model_id = endpoint_info["model_id"]
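+        # Expose TGI's OpenAI-compatible /v1 route to OpenAIMixin; TGI does not
+        # require a real API key, so a placeholder is used.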
+        self.url = f"{config.url.rstrip('/')}/v1"
+        self.api_key = SecretStr("NO_KEY")


class InferenceAPIAdapter(_HfAdapter):
    async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
-        endpoint_info = await self.client.get_endpoint_info()
+        self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
+        endpoint_info = await self.hf_client.get_endpoint_info()
        self.max_tokens = endpoint_info["max_total_tokens"]
        self.model_id = endpoint_info["model_id"]
+        # TODO: how do we set url for this?


class InferenceEndpointAdapter(_HfAdapter):
@@ -331,6 +358,7 @@ async def initialize(self, config: InferenceEndpointImplConfig) -> None:
        endpoint.wait(timeout=60)

        # Initialize the adapter
-        self.client = endpoint.async_client
+        self.hf_client = endpoint.async_client
        self.model_id = endpoint.repository
        self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
+        # TODO: how do we set url for this?