
Commit e432d41

NathanHB, aymeric-roucher, and albertvillanova authored

Adds VLLMModel (huggingface#337)

Co-authored-by: Aymeric <[email protected]>
Co-authored-by: Albert Villanova del Moral <[email protected]>
1 parent 23eaf93 commit e432d41

File tree

6 files changed: +170 −89 lines changed

docs/source/en/reference/models.mdx (+17)
@@ -167,3 +167,20 @@ print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
 > You must have `mlx-lm` installed on your machine. Please run `pip install smolagents[mlx-lm]` if it's not the case.
 
 [[autodoc]] MLXModel
+
+### VLLMModel
+
+Model to use [vLLM](https://docs.vllm.ai/) for fast LLM inference and serving.
+
+```python
+from smolagents import VLLMModel
+
+model = VLLMModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
+
+print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
+```
+
+> [!TIP]
+> You must have `vllm` installed on your machine. Please run `pip install smolagents[vllm]` if it's not the case.
+
+[[autodoc]] VLLMModel

pyproject.toml (+4)
@@ -71,6 +71,10 @@ vision = [
     "helium",
     "selenium",
 ]
+vllm = [
+    "vllm",
+    "torch"
+]
 all = [
     "smolagents[audio,docker,e2b,gradio,litellm,mcp,mlx-lm,openai,telemetry,transformers,vision]",
 ]

src/smolagents/agents.py (−3)
@@ -61,7 +61,6 @@
     AgentParsingError,
     make_init_file,
     parse_code_blobs,
-    parse_json_tool_call,
     truncate_content,
 )
 

@@ -190,7 +189,6 @@ def __init__(
         model: Callable[[List[Dict[str, str]]], ChatMessage],
         prompt_templates: Optional[PromptTemplates] = None,
         max_steps: int = 20,
-        tool_parser: Optional[Callable] = None,
         add_base_tools: bool = False,
         verbosity_level: LogLevel = LogLevel.INFO,
         grammar: Optional[Dict[str, str]] = None,

@@ -207,7 +205,6 @@ def __init__(
         self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
         self.max_steps = max_steps
         self.step_number = 0
-        self.tool_parser = tool_parser or parse_json_tool_call
         self.grammar = grammar
         self.planning_interval = planning_interval
         self.state = {}
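
With `tool_parser` gone from the agent constructor, tool-call parsing is now configured on the model instead, via the `tool_name_key` and `tool_arguments_key` arguments that `Model.__init__` accepts in this commit. A minimal sketch of the new configuration point; the key values shown are the defaults from this diff:

```python
# Hypothetical sketch: parsing keys now live on the Model, not the agent.
from smolagents import VLLMModel

model = VLLMModel(
    model_id="HuggingFaceTB/SmolLM-135M-Instruct",
    tool_name_key="name",            # default; key of the tool name in the JSON blob
    tool_arguments_key="arguments",  # default; key of the tool arguments
)
```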

src/smolagents/models.py (+144 −56)
@@ -15,7 +15,6 @@
 import json
 import logging
 import os
-import random
 import uuid
 import warnings
 from copy import deepcopy

@@ -26,7 +25,7 @@
 from huggingface_hub.utils import is_torch_available
 
 from .tools import Tool
-from .utils import _is_package_available, encode_image_base64, make_image_url
+from .utils import _is_package_available, encode_image_base64, make_image_url, parse_json_blob
 
 
 if TYPE_CHECKING:
@@ -236,10 +235,34 @@ def get_clean_message_list(
     return output_message_list
 
 
+def get_tool_call_chat_message_from_text(text: str, tool_name_key: str, tool_arguments_key: str) -> ChatMessage:
+    tool_call_dictionary, text = parse_json_blob(text)
+    try:
+        tool_name = tool_call_dictionary[tool_name_key]
+    except Exception as e:
+        raise ValueError(
+            f"Key {tool_name_key=} not found in the generated tool call. Got keys: {list(tool_call_dictionary.keys())} instead"
+        ) from e
+    tool_arguments = tool_call_dictionary.get(tool_arguments_key, None)
+    return ChatMessage(
+        role="assistant",
+        content=text,
+        tool_calls=[
+            ChatMessageToolCall(
+                id=uuid.uuid4(),
+                type="function",
+                function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
+            )
+        ],
+    )
+
+
 class Model:
-    def __init__(self, **kwargs):
+    def __init__(self, tool_name_key: str = "name", tool_arguments_key: str = "arguments", **kwargs):
         self.last_input_token_count = None
         self.last_output_token_count = None
+        self.tool_name_key = tool_name_key
+        self.tool_arguments_key = tool_arguments_key
         self.kwargs = kwargs
 
     def _prepare_completion_kwargs(
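
This new helper centralizes the JSON tool-call extraction that `MLXModel._to_message` and the inline `TransformersModel` parsing used to do separately (both removed further down). A minimal sketch of how it behaves, assuming it is imported from `smolagents.models`; the sample completion text is made up:

```python
# Hedged sketch: feed a raw completion containing a JSON tool call to the new helper.
from smolagents.models import get_tool_call_chat_message_from_text

text = 'I will search.\n{"name": "web_search", "arguments": {"query": "vLLM"}}'
message = get_tool_call_chat_message_from_text(
    text, tool_name_key="name", tool_arguments_key="arguments"
)

call = message.tool_calls[0]
print(call.function.name)       # web_search
print(call.function.arguments)  # {'query': 'vLLM'}
```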
@@ -465,6 +488,104 @@ def __call__(
         return message
 
 
+class VLLMModel(Model):
+    """Model to use [vLLM](https://docs.vllm.ai/) for fast LLM inference and serving.
+
+    Parameters:
+        model_id (`str`):
+            The Hugging Face model ID to be used for inference.
+            This can be a path or model identifier from the Hugging Face model hub.
+    """
+
+    def __init__(self, model_id, **kwargs):
+        if not _is_package_available("vllm"):
+            raise ModuleNotFoundError("Please install 'vllm' extra to use VLLMModel: `pip install 'smolagents[vllm]'`")
+
+        from vllm import LLM
+        from vllm.transformers_utils.tokenizer import get_tokenizer
+
+        super().__init__(**kwargs)
+
+        self.model_id = model_id
+        self.model = LLM(model=model_id)
+        self.tokenizer = get_tokenizer(model_id)
+        self._is_vlm = False  # VLLMModel does not support vision models yet.
+
+    def cleanup(self):
+        import gc
+
+        import torch
+        from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+        destroy_model_parallel()
+        if self.model is not None:
+            # taken from https://github.com/vllm-project/vllm/issues/1908#issuecomment-2076870351
+            del self.model.llm_engine.model_executor.driver_worker
+        self.model = None
+        gc.collect()
+        destroy_distributed_environment()
+        torch.cuda.empty_cache()
+
+    def __call__(
+        self,
+        messages: List[Dict[str, str]],
+        stop_sequences: Optional[List[str]] = None,
+        grammar: Optional[str] = None,
+        tools_to_call_from: Optional[List[Tool]] = None,
+        **kwargs,
+    ) -> ChatMessage:
+        from vllm import SamplingParams
+
+        completion_kwargs = self._prepare_completion_kwargs(
+            messages=messages,
+            flatten_messages_as_text=(not self._is_vlm),
+            stop_sequences=stop_sequences,
+            grammar=grammar,
+            tools_to_call_from=tools_to_call_from,
+            **kwargs,
+        )
+        messages = completion_kwargs.pop("messages")
+        prepared_stop_sequences = completion_kwargs.pop("stop", [])
+        tools = completion_kwargs.pop("tools", None)
+        completion_kwargs.pop("tool_choice", None)
+
+        if tools_to_call_from is not None:
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tools=tools,
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+        else:
+            prompt = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+            )
+
+        sampling_params = SamplingParams(
+            n=kwargs.get("n", 1),
+            temperature=kwargs.get("temperature", 0.0),
+            max_tokens=kwargs.get("max_tokens", 2048),
+            stop=prepared_stop_sequences,
+        )
+
+        out = self.model.generate(
+            prompt,
+            sampling_params=sampling_params,
+        )
+        output = out[0].outputs[0].text
+        self.last_input_token_count = len(out[0].prompt_token_ids)
+        self.last_output_token_count = len(out[0].outputs[0].token_ids)
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(output, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": out, "completion_kwargs": completion_kwargs}
+            return chat_message
+        else:
+            return ChatMessage(
+                role="assistant", content=output, raw={"out": out, "completion_kwargs": completion_kwargs}
+            )
+
+
 class MLXModel(Model):
     """A class to interact with models loaded using MLX on Apple silicon.
 
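
Beyond the docs snippet, the new class also records token counts and exposes a `cleanup()` method for tearing down the vLLM engine. A usage sketch mirroring the docs example above; the model choice is illustrative and loading it requires the hardware to match:

```python
# Hedged usage sketch for VLLMModel; requires `pip install 'smolagents[vllm]'`.
from smolagents import VLLMModel

model = VLLMModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")

message = model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"])
print(message.content)
print(model.last_input_token_count, model.last_output_token_count)

model.cleanup()  # releases the vLLM engine and frees GPU memory
```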
@@ -523,27 +644,7 @@ def __init__(
         self.stream_generate = mlx_lm.stream_generate
         self.tool_name_key = tool_name_key
         self.tool_arguments_key = tool_arguments_key
-
-    def _to_message(self, text, tools_to_call_from):
-        if tools_to_call_from:
-            # solution for extracting tool JSON without assuming a specific model output format
-            maybe_json = "{" + text.split("{", 1)[-1][::-1].split("}", 1)[-1][::-1] + "}"
-            parsed_text = json.loads(maybe_json)
-            tool_name = parsed_text.get(self.tool_name_key, None)
-            tool_arguments = parsed_text.get(self.tool_arguments_key, None)
-            if tool_name:
-                return ChatMessage(
-                    role="assistant",
-                    content="",
-                    tool_calls=[
-                        ChatMessageToolCall(
-                            id=uuid.uuid4(),
-                            type="function",
-                            function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                        )
-                    ],
-                )
-        return ChatMessage(role="assistant", content=text)
+        self._is_vlm = False  # mlx-lm doesn't support vision models
 
     def __call__(
         self,
@@ -554,7 +655,7 @@ def __call__(
         **kwargs,
     ) -> ChatMessage:
         completion_kwargs = self._prepare_completion_kwargs(
-            flatten_messages_as_text=True,  # mlx-lm doesn't support vision models
+            flatten_messages_as_text=(not self._is_vlm),
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
@@ -583,9 +684,19 @@ def __call__(
                 stop_sequence_start = text.rfind(stop_sequence)
                 if stop_sequence_start != -1:
                     text = text[:stop_sequence_start]
-                    return self._to_message(text, tools_to_call_from)
+                    found_stop_sequence = True
+                    break
+            if found_stop_sequence:
+                break
 
-        return self._to_message(text, tools_to_call_from)
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(text, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": text, "completion_kwargs": completion_kwargs}
+            return chat_message
+        else:
+            return ChatMessage(
+                role="assistant", content=text, raw={"out": text, "completion_kwargs": completion_kwargs}
+            )
 
 
 class TransformersModel(Model):
@@ -779,38 +890,14 @@ def __call__(
         if stop_sequences is not None:
             output = remove_stop_sequences(output, stop_sequences)
 
-        if tools_to_call_from is None:
-            return ChatMessage(
-                role="assistant",
-                content=output,
-                raw={"out": out, "completion_kwargs": completion_kwargs},
-            )
+        if tools_to_call_from:
+            chat_message = get_tool_call_chat_message_from_text(output, self.tool_name_key, self.tool_arguments_key)
+            chat_message.raw = {"out": out, "completion_kwargs": completion_kwargs}
+            return chat_message
         else:
-            if "Action:" in output:
-                output = output.split("Action:", 1)[1].strip()
-            try:
-                start_index = output.index("{")
-                end_index = output.rindex("}")
-                output = output[start_index : end_index + 1]
-            except Exception as e:
-                raise Exception("No json blob found in output!") from e
-
-            try:
-                parsed_output = json.loads(output)
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Tool call '{output}' has an invalid JSON structure: {e}")
-            tool_name = parsed_output.get("name")
-            tool_arguments = parsed_output.get("arguments")
             return ChatMessage(
                 role="assistant",
-                content="",
-                tool_calls=[
-                    ChatMessageToolCall(
-                        id="".join(random.choices("0123456789", k=5)),
-                        type="function",
-                        function=ChatMessageToolCallDefinition(name=tool_name, arguments=tool_arguments),
-                    )
-                ],
+                content=output,
                 raw={"out": out, "completion_kwargs": completion_kwargs},
             )
 
@@ -1051,6 +1138,7 @@ def create_client(self):
     "HfApiModel",
     "LiteLLMModel",
     "OpenAIServerModel",
+    "VLLMModel",
     "AzureOpenAIServerModel",
    "ChatMessage",
]

src/smolagents/utils.py (+4 −29)
@@ -26,7 +26,7 @@
 from functools import lru_cache
 from io import BytesIO
 from textwrap import dedent
-from typing import TYPE_CHECKING, Any, Dict, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Tuple
 
 
 if TYPE_CHECKING:
@@ -140,13 +140,14 @@ def make_json_serializable(obj: Any) -> Any:
     return str(obj)
 
 
-def parse_json_blob(json_blob: str) -> Dict[str, str]:
+def parse_json_blob(json_blob: str) -> Tuple[Dict[str, str], str]:
+    "Extracts the JSON blob from the input and returns the JSON data and the rest of the input."
     try:
         first_accolade_index = json_blob.find("{")
         last_accolade_index = [a.start() for a in list(re.finditer("}", json_blob))][-1]
         json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
         json_data = json.loads(json_blob, strict=False)
-        return json_data
+        return json_data, json_blob[:first_accolade_index]
     except json.JSONDecodeError as e:
         place = e.pos
         if json_blob[place - 1 : place + 2] == "},\n":
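
Because `parse_json_blob` now returns a tuple, existing call sites must unpack two values. A quick sketch of the changed contract, assuming the function is imported from `smolagents.utils`:

```python
# Hedged sketch of the new return type: (parsed JSON, text preceding the blob).
from smolagents.utils import parse_json_blob

data, prefix = parse_json_blob('{"name": "final_answer", "arguments": "42"}')
print(data["name"])  # final_answer
print(repr(prefix))  # '' — nothing precedes the blob in this input
```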
@@ -158,8 +159,6 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
             f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
             f"'{json_blob[place - 4 : place + 5]}'."
         )
-    except Exception as e:
-        raise ValueError(f"Error in parsing the JSON blob: {e}")
 
 
 def parse_code_blobs(text: str) -> str:
@@ -219,30 +218,6 @@ def parse_code_blobs(text: str) -> str:
         )
 
 
-def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
-    json_blob = json_blob.replace("```json", "").replace("```", "")
-    tool_call = parse_json_blob(json_blob)
-    tool_name_key, tool_arguments_key = None, None
-    for possible_tool_name_key in ["action", "tool_name", "tool", "name", "function"]:
-        if possible_tool_name_key in tool_call:
-            tool_name_key = possible_tool_name_key
-    for possible_tool_arguments_key in [
-        "action_input",
-        "tool_arguments",
-        "tool_args",
-        "parameters",
-    ]:
-        if possible_tool_arguments_key in tool_call:
-            tool_arguments_key = possible_tool_arguments_key
-    if tool_name_key is not None:
-        if tool_arguments_key is not None:
-            return tool_call[tool_name_key], tool_call[tool_arguments_key]
-        else:
-            return tool_call[tool_name_key], None
-    error_msg = "No tool name key found in tool call!" + f" Tool call: {json_blob}"
-    raise AgentParsingError(error_msg)
-
-
 MAX_LENGTH_TRUNCATE_CONTENT = 20000