diff --git a/.github/workflows/stackit.yml b/.github/workflows/stackit.yml
index d62f998c4d..3c639912ce 100644
--- a/.github/workflows/stackit.yml
+++ b/.github/workflows/stackit.yml
@@ -22,7 +22,7 @@ concurrency:
 env:
   PYTHONUNBUFFERED: "1"
   FORCE_COLOR: "1"
-  STACKIT: ${{ secrets.STACKIT_API_KEY }}
+  STACKIT_API_KEY: ${{ secrets.STACKIT_API_KEY }}

 jobs:
   run:
diff --git a/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py b/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
index e2ec1bd86d..9b93a03ae2 100644
--- a/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
+++ b/integrations/stackit/src/haystack_integrations/components/generators/stackit/chat/chat_generator.py
@@ -1,11 +1,12 @@
 # SPDX-FileCopyrightText: 2025-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Union

 from haystack import component, default_to_dict
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import StreamingCallbackT
+from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
 from haystack.utils import serialize_callable
 from haystack.utils.auth import Secret
@@ -44,6 +45,7 @@ def __init__(
         api_base_url: Optional[str] = "https://api.openai-compat.model-serving.eu01.onstackit.cloud/v1",
         generation_kwargs: Optional[Dict[str, Any]] = None,
         *,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         timeout: Optional[float] = None,
         max_retries: Optional[int] = None,
         http_client_kwargs: Optional[Dict[str, Any]] = None,
@@ -74,6 +76,9 @@ def __init__(
                 events as they become available, with the stream terminated by a data: [DONE] message.
             - `safe_prompt`: Whether to inject a safety prompt before all conversations.
             - `random_seed`: The seed to use for random sampling.
+        :param tools:
+            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
+            list of `Tool` objects or a `Toolset` instance.
         :param timeout:
             Timeout for STACKIT client calls. If not set, it defaults to either the
             `OPENAI_TIMEOUT` environment variable, or 30 seconds.
@@ -93,6 +98,7 @@
             generation_kwargs=generation_kwargs,
             timeout=timeout,
             max_retries=max_retries,
+            tools=tools,
             http_client_kwargs=http_client_kwargs,
         )

@@ -108,7 +114,6 @@ def to_dict(self) -> Dict[str, Any]:
         # if we didn't implement the to_dict method here then the to_dict method of the superclass would be used
         # which would serialize some fields that we don't want to serialize (e.g. the ones we don't have in
         # the __init__)
-        # it would be hard to maintain the compatibility as superclass changes
         return default_to_dict(
             self,
             model=self.model,
@@ -116,6 +121,7 @@
             api_base_url=self.api_base_url,
             generation_kwargs=self.generation_kwargs,
             api_key=self.api_key.to_dict(),
+            tools=serialize_tools_or_toolset(self.tools),
             timeout=self.timeout,
             max_retries=self.max_retries,
             http_client_kwargs=self.http_client_kwargs,
diff --git a/integrations/stackit/tests/test_stackit_chat_generator.py b/integrations/stackit/tests/test_stackit_chat_generator.py
index b50c4f7c3c..0959b5bb3d 100644
--- a/integrations/stackit/tests/test_stackit_chat_generator.py
+++ b/integrations/stackit/tests/test_stackit_chat_generator.py
@@ -5,7 +5,8 @@
 import pytest
 import pytz
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import ChatMessage, StreamingChunk
+from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
+from haystack.tools import Tool
 from haystack.utils.auth import Secret
 from openai import OpenAIError
 from openai.types import CompletionUsage
@@ -23,6 +24,24 @@ def chat_messages():
     ]


+def weather(city: str):
+    """Get weather for a given city."""
+    return f"The weather in {city} is sunny and 32°C"
+
+
+@pytest.fixture
+def tools():
+    tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
+    tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters=tool_parameters,
+        function=weather,
+    )
+
+    return [tool]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
@@ -254,3 +273,59 @@ def __call__(self, chunk: StreamingChunk) -> None:

         assert callback.counter > 1
         assert "Paris" in callback.responses
+
+    @pytest.mark.skipif(
+        not os.environ.get("STACKIT_API_KEY", None),
+        reason="Export an env var called STACKIT_API_KEY containing the STACKIT API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_with_tools_and_response(self, tools):
+        """
+        Integration test that the STACKITChatGenerator component can run with tools and get a response.
+        """
+        initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
+        component = STACKITChatGenerator(
+            # Only model that supports tool calls at the moment.
+            # The call does run, but for some reason the tool call is put into
+            # chat_completion.choices[0].message.content instead of chat_completion.choices[0].message.tool_calls.
+            # NOTE: If you only induce one tool call it works as expected; with multiple tool calls
+            # the model stores the result in the content field.
+            model="cortecs/Llama-3.3-70B-Instruct-FP8-Dynamic",
+            tools=tools,
+        )
+        results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"})
+
+        assert len(results["replies"]) == 1
+
+        # Find the message with tool calls
+        tool_message = results["replies"][0]
+
+        assert isinstance(tool_message, ChatMessage)
+        tool_calls = tool_message.tool_calls
+        assert len(tool_calls) == 2
+        assert tool_message.is_from(ChatRole.ASSISTANT)
+
+        for tool_call in tool_calls:
+            assert tool_call.id is not None
+            assert isinstance(tool_call, ToolCall)
+            assert tool_call.tool_name == "weather"
+
+        arguments = [tool_call.arguments for tool_call in tool_calls]
+        assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
+        assert tool_message.meta["finish_reason"] == "tool_calls"
+
+        new_messages = [
+            initial_messages[0],
+            tool_message,
+            ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]),
+            ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]),
+        ]
+        # Pass the tool results back to the model to get the final response
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert final_message.is_from(ChatRole.ASSISTANT)
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
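For reference, a minimal usage sketch of the new tools parameter, mirroring the tool defined in the integration test. This is a sketch rather than part of the PR: it assumes STACKIT_API_KEY is exported and that the package exposes STACKITChatGenerator from haystack_integrations.components.generators.stackit.

from haystack.dataclasses import ChatMessage
from haystack.tools import Tool
from haystack_integrations.components.generators.stackit import STACKITChatGenerator


def weather(city: str) -> str:
    """Toy tool function: returns a canned weather string for `city`."""
    return f"The weather in {city} is sunny and 32°C"


weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=weather,
)

generator = STACKITChatGenerator(
    # Per the test comment above, currently the only served model with tool-call support.
    model="cortecs/Llama-3.3-70B-Instruct-FP8-Dynamic",
    tools=[weather_tool],
)

# With a single induced tool call the model behaves as expected and returns a ToolCall.
reply = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])["replies"][0]
print(reply.tool_calls)

Since to_dict() now routes tools through serialize_tools_or_toolset, a generator configured this way should also survive to_dict()/from_dict() round trips in serialized pipelines.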