33 changes: 24 additions & 9 deletions src/smolagents/models.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import logging
import os
@@ -589,6 +590,9 @@ class VLLMModel(Model):
This can be a path or model identifier from the Hugging Face model hub.
model_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to forward to the vLLM LLM instantiation, such as `revision`, `max_model_len`, etc.
sampling_params (`dict[str, Any]`, *optional*):
Default sampling parameters (e.g., max_tokens, top_p) to be used for generation.
These can be overridden at runtime by passing kwargs to `generate()`.
**kwargs:
Additional keyword arguments to forward to the underlying vLLM model generate call.
"""
@@ -597,15 +601,22 @@ def __init__(
self,
model_id,
model_kwargs: dict[str, Any] | None = None,
sampling_params: dict[str, Any] | None = None,
**kwargs,
):
if not _is_package_available("vllm"):
raise ModuleNotFoundError("Please install 'vllm' extra to use VLLMModel: `pip install 'smolagents[vllm]'`")

from vllm import LLM # type: ignore
from vllm import (
LLM, # type: ignore
SamplingParams, # type: ignore
)
from vllm.transformers_utils.tokenizer import get_tokenizer # type: ignore

self._valid_sampling_keys = set(inspect.signature(SamplingParams).parameters.keys())

self.model_kwargs = model_kwargs or {}
self.sampling_params = sampling_params or {}
super().__init__(**kwargs)
self.model_id = model_id
self.model = LLM(model=model_id, **self.model_kwargs)
@@ -665,17 +676,21 @@ def generate(
tokenize=False,
)

sampling_params = SamplingParams(
n=kwargs.get("n", 1),
temperature=kwargs.get("temperature", 0.0),
max_tokens=kwargs.get("max_tokens", 2048),
stop=prepared_stop_sequences,
structured_outputs=structured_outputs,
)
sampling_kwargs = {
"n": kwargs.get("n", 1),
"temperature": kwargs.get("temperature", 0.0),
"max_tokens": kwargs.get("max_tokens", 2048),
"stop": prepared_stop_sequences,
"structured_outputs": structured_outputs,
**self.sampling_params,
**kwargs,
}

sampling_params = {key: value for key, value in sampling_kwargs.items() if key in self._valid_sampling_keys}

out = self.model.generate(
prompt,
sampling_params=sampling_params,
sampling_params=SamplingParams(**sampling_params),
**completion_kwargs,
)

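The merge above gives per-call kwargs precedence over the constructor-level `sampling_params`, which in turn override the hard-coded defaults, and any keys not accepted by `SamplingParams` are dropped before instantiation. A standalone sketch of that merge-and-filter logic, using a stand-in `valid_keys` set rather than vLLM's real signature:

```python
from typing import Any


def merge_sampling_kwargs(
    defaults: dict[str, Any],
    instance_params: dict[str, Any],
    call_kwargs: dict[str, Any],
    valid_keys: set[str],
) -> dict[str, Any]:
    """Merge with increasing precedence (defaults < instance < call), then drop unknown keys."""
    merged = {**defaults, **instance_params, **call_kwargs}
    return {key: value for key, value in merged.items() if key in valid_keys}


# Stand-in values mirroring the defaults used in generate() above.
valid_keys = {"n", "temperature", "max_tokens", "top_p", "stop"}
defaults = {"n": 1, "temperature": 0.0, "max_tokens": 2048}
merged = merge_sampling_kwargs(defaults, {"top_p": 0.9}, {"temperature": 0.8, "bogus": 1}, valid_keys)
print(merged)  # {'n': 1, 'temperature': 0.8, 'max_tokens': 2048, 'top_p': 0.9}
```
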
68 changes: 68 additions & 0 deletions tests/test_models.py
@@ -35,6 +35,7 @@
Model,
OpenAIModel,
TransformersModel,
VLLMModel,
get_clean_message_list,
get_tool_call_from_text,
get_tool_json_schema,
@@ -667,6 +668,73 @@ def test_init(self, patching):
assert mocks["transformers.AutoProcessor.from_pretrained"].call_args.kwargs == {"trust_remote_code": True}


class TestVLLMModel:
@pytest.mark.parametrize(
"sampling_params, generate_kwargs, expected_call_params",
[
(
{"temperature": 0.5, "top_p": 0.9},
{},
{"temperature": 0.5, "top_p": 0.9, "n": 1, "max_tokens": 2048},
),
(
{"temperature": 0.5, "top_p": 0.9},
{"temperature": 0.8, "frequency_penalty": 1},
{"temperature": 0.8, "top_p": 0.9, "frequency_penalty": 1, "n": 1, "max_tokens": 2048},
),
(
{},
{},
{"temperature": 0.0, "n": 1, "max_tokens": 2048},
),
(
{"invalid_key": "foo"},
{"another_invalid": "bar", "temperature": 0.7},
{"temperature": 0.7, "n": 1, "max_tokens": 2048},
),
],
)
def test_sampling_params_precedence(self, sampling_params, generate_kwargs, expected_call_params):
with (
patch("smolagents.models._is_package_available", return_value=True),
patch("vllm.LLM") as MockLLM,
patch("vllm.transformers_utils.tokenizer.get_tokenizer") as MockTokenizer,
patch("vllm.SamplingParams") as MockSamplingParams,
patch("inspect.signature") as MockSignature,
):
MockSignature.return_value.parameters.keys.return_value = {
"n",
"temperature",
"max_tokens",
"stop",
"structured_outputs",
"top_p",
"frequency_penalty",
}

model = VLLMModel(model_id="test-model", sampling_params=sampling_params)

model.model = MockLLM.return_value
model.tokenizer = MockTokenizer.return_value
model.tokenizer.apply_chat_template.return_value = "Test prompt"

mock_out = MagicMock()
mock_out[0].outputs[0].text = "Test response"
mock_out[0].prompt_token_ids = [1, 2, 3]
mock_out[0].outputs[0].token_ids = [4, 5]
model.model.generate.return_value = mock_out

expected_call_params["stop"] = []
expected_call_params["structured_outputs"] = None

messages = [ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello"}])]
model.generate(messages, **generate_kwargs)

MockSamplingParams.assert_called_once_with(
**expected_call_params,
)


def test_get_clean_message_list_basic():
messages = [
ChatMessage(role=MessageRole.USER, content=[{"type": "text", "text": "Hello!"}]),