
Commit f4ab154

feat: add dynamic model registration support to TGI inference (#3417)
# What does this PR do?

Adds dynamic model registration support to TGI inference, and adds a new `overwrite_completion_id` feature to `OpenAIMixin` to deal with TGI always returning `id=""`.

## Test Plan

tgi: `docker run --gpus all --shm-size 1g -p 8080:80 -v /data:/data ghcr.io/huggingface/text-generation-inference --model-id Qwen/Qwen3-0.6B`

stack: `TGI_URL=http://localhost:8080 uv run llama stack build --image-type venv --distro ci-tests --run`

test: `./scripts/integration-tests.sh --stack-config http://localhost:8321 --setup tgi --subdirs inference --pattern openai`
1 parent ab32173 commit f4ab154
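The `overwrite_completion_id` flag is an ordinary class attribute on `OpenAIMixin`, so any provider whose backend returns unusable completion ids can opt in with one line plus the two connection hooks the mixin requires. A minimal sketch, assuming a hypothetical `MyAdapter` (the real change is to `_HfAdapter` in the diff below):

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class MyAdapter(OpenAIMixin):
    # Backend always returns id="", so have the mixin stamp responses
    # with a client-side generated "cltsd-<uuid>" id instead.
    overwrite_completion_id = True

    def get_api_key(self) -> str:
        return "NO_KEY"  # hypothetical: backend does not require a key

    def get_base_url(self) -> str:
        return "http://localhost:8080/v1"  # hypothetical OpenAI-compatible endpoint
```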

File tree

14 files changed: +12218 -20 lines

llama_stack/providers/remote/inference/tgi/tgi.py

Lines changed: 43 additions & 15 deletions
@@ -8,6 +8,7 @@
 from collections.abc import AsyncGenerator

 from huggingface_hub import AsyncInferenceClient, HfApi
+from pydantic import SecretStr

 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -33,6 +34,7 @@
     ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
+from llama_stack.apis.models.models import ModelType
 from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -41,16 +43,15 @@
     build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
-    OpenAICompletionToLlamaStackMixin,
     get_sampling_options,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_model_input_info,
     completion_request_to_prompt_model_input_info,
@@ -73,26 +74,49 @@ def build_hf_repo_model_entries():


 class _HfAdapter(
+    OpenAIMixin,
     Inference,
-    OpenAIChatCompletionToLlamaStackMixin,
-    OpenAICompletionToLlamaStackMixin,
     ModelsProtocolPrivate,
 ):
-    client: AsyncInferenceClient
+    url: str
+    api_key: SecretStr
+
+    hf_client: AsyncInferenceClient
     max_tokens: int
     model_id: str

+    overwrite_completion_id = True  # TGI always returns id=""
+
     def __init__(self) -> None:
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.huggingface_repo_to_llama_model_id = {
             model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
         }

+    def get_api_key(self):
+        return self.api_key.get_secret_value()
+
+    def get_base_url(self):
+        return self.url
+
     async def shutdown(self) -> None:
         pass

+    async def list_models(self) -> list[Model] | None:
+        models = []
+        async for model in self.client.models.list():
+            models.append(
+                Model(
+                    identifier=model.id,
+                    provider_resource_id=model.id,
+                    provider_id=self.__provider_id__,
+                    metadata={},
+                    model_type=ModelType.llm,
+                )
+            )
+        return models
+
     async def register_model(self, model: Model) -> Model:
-        model = await self.register_helper.register_model(model)
         if model.provider_resource_id != self.model_id:
             raise ValueError(
                 f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@@ -176,7 +200,7 @@ async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator
         params = await self._get_params_for_completion(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.text_generation(**params)
+            s = await self.hf_client.text_generation(**params)
             async for chunk in s:
                 token_result = chunk.token
                 finish_reason = None
@@ -194,7 +218,7 @@ async def _generate_and_convert_to_openai_compat():

     async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params_for_completion(request)
-        r = await self.client.text_generation(**params)
+        r = await self.hf_client.text_generation(**params)

         choice = OpenAICompatCompletionChoice(
             finish_reason=r.details.finish_reason,
@@ -241,7 +265,7 @@ async def chat_completion(

     async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)
-        r = await self.client.text_generation(**params)
+        r = await self.hf_client.text_generation(**params)

         choice = OpenAICompatCompletionChoice(
             finish_reason=r.details.finish_reason,
@@ -256,7 +280,7 @@ async def _stream_chat_completion(self, request: ChatCompletionRequest) -> Async
         params = await self._get_params(request)

         async def _generate_and_convert_to_openai_compat():
-            s = await self.client.text_generation(**params)
+            s = await self.hf_client.text_generation(**params)
             async for chunk in s:
                 token_result = chunk.token

@@ -308,18 +332,21 @@ async def initialize(self, config: TGIImplConfig) -> None:
         if not config.url:
             raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
-        endpoint_info = await self.client.get_endpoint_info()
+        self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
+        endpoint_info = await self.hf_client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
+        self.url = f"{config.url.rstrip('/')}/v1"
+        self.api_key = SecretStr("NO_KEY")


 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
-        self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
-        endpoint_info = await self.client.get_endpoint_info()
+        self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
+        endpoint_info = await self.hf_client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
+        # TODO: how do we set url for this?


 class InferenceEndpointAdapter(_HfAdapter):
@@ -331,6 +358,7 @@ async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         endpoint.wait(timeout=60)

         # Initialize the adapter
-        self.client = endpoint.async_client
+        self.hf_client = endpoint.async_client
         self.model_id = endpoint.repository
         self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
+        # TODO: how do we set url for this?
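In effect, `get_base_url()` and `get_api_key()` hand `OpenAIMixin` a plain OpenAI-compatible endpoint: TGI's `/v1` route with a placeholder key. A rough stand-alone equivalent of what the adapter's client ends up doing, assuming TGI is running locally as in the test plan above:

```python
import asyncio

from openai import AsyncOpenAI

# Mirrors what initialize() computes:
#   self.url = f"{config.url.rstrip('/')}/v1"
#   self.api_key = SecretStr("NO_KEY")
client = AsyncOpenAI(base_url="http://localhost:8080/v1", api_key="NO_KEY")


async def main() -> None:
    # TGI exposes its single served model on /v1/models; list_models()
    # turns entries like this into Llama Stack Model objects.
    models = await client.models.list()
    print([m.id for m in models.data])

    resp = await client.chat.completions.create(
        model="Qwen/Qwen3-0.6B",
        messages=[{"role": "user", "content": "Hello, world!"}],
    )
    print(repr(resp.id))  # TGI returns "", hence overwrite_completion_id = True


asyncio.run(main())
```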

llama_stack/providers/utils/inference/openai_mixin.py

Lines changed: 30 additions & 3 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import uuid
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator
 from typing import Any
@@ -43,6 +44,12 @@ class OpenAIMixin(ABC):
     The model_store is set in routing_tables/common.py during provider initialization.
     """

+    # Allow subclasses to control whether the 'id' field in OpenAI responses
+    # is overwritten with a client-side generated id.
+    #
+    # This is useful for providers that do not return a unique id in the response.
+    overwrite_completion_id: bool = False
+
     @abstractmethod
     def get_api_key(self) -> str:
         """
@@ -110,6 +117,23 @@ async def _get_provider_model_id(self, model: str) -> str:
             raise ValueError(f"Model {model} has no provider_resource_id")
         return model_obj.provider_resource_id

+    async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
+        if not self.overwrite_completion_id:
+            return resp
+
+        new_id = f"cltsd-{uuid.uuid4()}"
+        if stream:
+
+            async def _gen():
+                async for chunk in resp:
+                    chunk.id = new_id
+                    yield chunk
+
+            return _gen()
+        else:
+            resp.id = new_id
+            return resp
+
     async def openai_completion(
         self,
         model: str,
@@ -147,7 +171,7 @@ async def openai_completion(
             extra_body["guided_choice"] = guided_choice

         # TODO: fix openai_completion to return type compatible with OpenAI's API response
-        return await self.client.completions.create(  # type: ignore[no-any-return]
+        resp = await self.client.completions.create(
            **await prepare_openai_completion_params(
                model=await self._get_provider_model_id(model),
                prompt=prompt,
@@ -171,6 +195,8 @@ async def openai_completion(
            extra_body=extra_body,
        )

+        return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
+
     async def openai_chat_completion(
         self,
         model: str,
@@ -200,8 +226,7 @@ async def openai_chat_completion(
         """
         Direct OpenAI chat completion API call.
         """
-        # Type ignore because return types are compatible
-        return await self.client.chat.completions.create(  # type: ignore[no-any-return]
+        resp = await self.client.chat.completions.create(
            **await prepare_openai_completion_params(
                model=await self._get_provider_model_id(model),
                messages=messages,
@@ -229,6 +254,8 @@ async def openai_chat_completion(
            )
        )

+        return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
+
     async def openai_embeddings(
         self,
         model: str,

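`_maybe_overwrite_id` handles the two response shapes differently: a non-streaming object has its `id` reassigned in place, while a stream is wrapped in a new async generator so every chunk carries the same generated id. A self-contained sketch of that streaming pattern, using stand-in chunk objects instead of the real OpenAI response types:

```python
import asyncio
import uuid
from dataclasses import dataclass


@dataclass
class Chunk:
    id: str
    text: str


async def provider_stream():
    # Stand-in for a backend stream that returns id="" on every chunk.
    for text in ("Hello", ", ", "world!"):
        yield Chunk(id="", text=text)


async def overwrite_stream_ids(stream):
    # Same idea as the streaming branch of _maybe_overwrite_id:
    # pick one id up front and stamp it onto each chunk as it passes through.
    new_id = f"cltsd-{uuid.uuid4()}"
    async for chunk in stream:
        chunk.id = new_id
        yield chunk


async def main() -> None:
    async for chunk in overwrite_stream_ids(provider_stream()):
        print(chunk.id, chunk.text)


asyncio.run(main())
```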
tests/integration/inference/test_openai_completion.py

Lines changed: 1 addition & 2 deletions
@@ -48,7 +48,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         "remote::nvidia",
         "remote::runpod",
         "remote::sambanova",
-        "remote::tgi",
         "remote::vertexai",
         # {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
         # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
@@ -96,6 +95,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         "remote::vertexai",
         # Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
         # the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
+        "remote::tgi",  # TGI ignores n param silently
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

@@ -110,7 +110,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::cerebras",
         "remote::databricks",
         "remote::runpod",
-        "remote::tgi",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+{
+  "request": {
+    "method": "POST",
+    "url": "http://localhost:8080/v1/v1/chat/completions",
+    "headers": {},
+    "body": {
+      "model": "Qwen/Qwen3-0.6B",
+      "messages": [
+        {
+          "role": "user",
+          "content": "Hello, world!"
+        }
+      ],
+      "stream": false
+    },
+    "endpoint": "/v1/chat/completions",
+    "model": "Qwen/Qwen3-0.6B"
+  },
+  "response": {
+    "body": {
+      "__type__": "openai.types.chat.chat_completion.ChatCompletion",
+      "__data__": {
+        "id": "",
+        "choices": [
+          {
+            "finish_reason": "stop",
+            "index": 0,
+            "logprobs": null,
+            "message": {
+              "content": "<think>\nOkay, the user just said \"Hello, world!\" so I need to respond in a friendly way. My prompt says to respond in the same style, so I should start with \"Hello, world!\" but maybe add some helpful information. Let me think. Since the user is probably testing or just sharing, a simple \"Hello, world!\" with a question would be best for user interaction. I'll make sure to keep it positive and open-ended.\n</think>\n\nHello, world! \ud83d\ude0a What do you need today?",
+              "refusal": null,
+              "role": "assistant",
+              "annotations": null,
+              "audio": null,
+              "function_call": null,
+              "tool_calls": null
+            }
+          }
+        ],
+        "created": 1757550395,
+        "model": "Qwen/Qwen3-0.6B",
+        "object": "chat.completion",
+        "service_tier": null,
+        "system_fingerprint": "3.3.5-dev0-sha-1b90c50",
+        "usage": {
+          "completion_tokens": 108,
+          "prompt_tokens": 12,
+          "total_tokens": 120,
+          "completion_tokens_details": null,
+          "prompt_tokens_details": null
+        }
+      }
+    },
+    "is_streaming": false
+  }
+}
