diff --git a/documentation/examples/anagrams_with_genai.py b/documentation/examples/anagrams_with_genai.py index 5f15c92..91d19fb 100644 --- a/documentation/examples/anagrams_with_genai.py +++ b/documentation/examples/anagrams_with_genai.py @@ -17,17 +17,11 @@ # title: Anagrams Task with Genai/ OpenAI Api # --- # %% -from kaggle_benchmarks import assertions, chats, task -from kaggle_benchmarks.kaggle import model_proxy -llm_with_openai_api = model_proxy.ModelProxy( - model="google/gemini-2.5-flash", - api="openai", -) -llm_with_genai_api = model_proxy.ModelProxy( - model="google/gemini-2.5-pro", - api="genai", -) +from kaggle_benchmarks import assertions, chats, kaggle, task + +llm_with_openai_api = kaggle.load_model("google/gemini-2.5-flash", api="openai") +llm_with_genai_api = kaggle.load_model("google/gemini-2.5-flash", api="genai") def is_anagram(x: str, y: str) -> bool: @@ -54,23 +48,11 @@ def write_anagrams(llm, word: str) -> int: for msg in reversed(non_streaming_result.chat.messages) if msg.sender is llm_with_genai_api ) -metadata = llm_response_message._meta - -assert "input_tokens" in metadata, "Metadata is missing 'input_tokens' key" -assert "output_tokens" in metadata, "Metadata is missing 'output_tokens' key" -# %% -llm_with_genai_api.stream_responses = True +usage = llm_response_message.usage -streaming_result = write_anagrams.run(llm_with_genai_api, "creative") - -llm_response_message_stream = next( - msg - for msg in reversed(streaming_result.chat.messages) - if msg.sender is llm_with_genai_api -) -metadata_stream = llm_response_message_stream._meta -assert "input_tokens" in metadata_stream -assert "output_tokens" in metadata_stream +assert usage, "Metadata is missing 'usage' attribute" +assert usage.input_tokens > 0, "usage is missing 'input_tokens' key" +assert usage.output_tokens > 0, "usage is missing 'output_tokens' key" # %% diff --git a/documentation/examples/guess_the_number.py b/documentation/examples/guess_the_number.py new file mode 
100644 index 0000000..f42a0a7 --- /dev/null +++ b/documentation/examples/guess_the_number.py @@ -0,0 +1,74 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% [markdown] +# --- +# title: Example of a game that requires tool use. +# --- + +# %% +import random + +import kaggle_benchmarks as kbench +from kaggle_benchmarks.kaggle import models + +SECRET_NUMBER = random.randint(1, 10) + + +def guess_number(guess: int) -> str: + """Make a guess in the number guessing game.""" + if guess < SECRET_NUMBER: + return "Higher" + elif guess > SECRET_NUMBER: + return "Lower" + else: + return "Correct!" + + +@kbench.task(name="guess-the-number-game") +def play_game(llm): + prompt = "I'm thinking of a number between 1 and 10. Can you guess it?" + response = llm.prompt(prompt, schema=int, tools=[guess_number]) + + for _ in range(4): + if response == SECRET_NUMBER: + break + response = llm.prompt(response, schema=int, tools=[guess_number]) + + kbench.assertions.assert_equal( + SECRET_NUMBER, + response, + expectation=f"LLM should have guessed the secret number. 
The secret number was {SECRET_NUMBER}", + ) + + +# %% + +llm_with_genai_api = models.load_model( + model_name=kbench.llm.name, + api="genai", +) + +play_game.run(llm=llm_with_genai_api) + +# %% + +llm_with_openai_api = models.load_model( + model_name=kbench.llm.name, + api="openai", +) + +play_game.run(llm_with_openai_api) + +# %% diff --git a/documentation/examples/prompt_with_tools.py b/documentation/examples/prompt_with_tools.py deleted file mode 100644 index 9b6110f..0000000 --- a/documentation/examples/prompt_with_tools.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2025 Kaggle Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# %% [markdown] -# --- -# title: Example of using `prompt` with `tools` parameter. -# --- -# - **Automatic tool calling is currently only supported via the `Gemini`` API** -# - **For manual tool calling with with `OpenAI` API, please refer to the example in `use_calculator_tool.py`.** - -# %% -import kaggle_benchmarks as kbench -from kaggle_benchmarks import tools -from kaggle_benchmarks.kaggle import models - - -def lookup_wikipedia(query: str): - """Searches Wikipedia for a given query and returns the article content. - - This function acts as a tool that can be used by an LLM. It takes a search - query, uses the `SearchEngine` to find a relevant Wikipedia article, and - returns its content. It includes basic error handling. - - Args: - query: The search term to look up on Wikipedia. 
- - Returns: - A dictionary containing a "success" status and, if successful, the - "article" content. Returns `{"success": False}` on failure. - """ - try: - wikipedia = tools.search.SearchEngine("wikipedia") - article = wikipedia.search(query) - return { - "success": True, - "article": article, - } - - except Exception: - return {"success": False} - - -# %% -@kbench.task(name="llm auto call wiki tool") -def llm_auto_call_wiki_tool(llm, question: str, correct_answer: str): - prompt_message = f""" - Answer the following question by the following steps: - 1) Generate a useful query based on the question. - 2) Use the provided tool `lookup_wikipedia` to look up wiki articles. - 3) Answer the question based on the wiki article. Just generate the answer and nothing else. - - - Question: {question} - Answer: - """ - response = llm.prompt(prompt_message, tools=[lookup_wikipedia]) - kbench.assertions.assert_contains_regex( - f"(?i){correct_answer}", - response, - expectation=f"LLM should answer the question correctly. Expected answer: {correct_answer}, Got: {response}", - ) - - -# NOTE: Automatic tool calling requires the `genai` API. -# For `openai` API, tools must be called manually (see `use_calculator_tool.py`). 
-llm_with_genai_api = models.load_model( - model_name=kbench.llm.name, - api="genai", -) - -llm_auto_call_wiki_tool.run( - llm_with_genai_api, "What gymnasium did Stefan Vrtel-Wierczynski attend?", "Stryj" -) - -# %% diff --git a/documentation/examples/use_calculator_tool.py b/documentation/examples/use_calculator_tool.py index 7049521..e7a5396 100644 --- a/documentation/examples/use_calculator_tool.py +++ b/documentation/examples/use_calculator_tool.py @@ -14,17 +14,16 @@ # %% [markdown] # --- -# title: Manual Calculator Tool Calling +# title: Calculator Tool # --- # %% -import json - -from kaggle_benchmarks import actors, assertions, llm, messages, task +from kaggle_benchmarks import actors, assertions, llm, task tool = actors.Actor(name="Tool", role="tool", avatar="🛠️") def run_simple_calculator(a: float, b: float, operator: str) -> float: + """Calculates the result of an arithmetic operation like +, -, *, or /.""" if operator == "+": return a + b if operator == "-": @@ -37,72 +36,23 @@ def run_simple_calculator(a: float, b: float, operator: str) -> float: @task("Calculator Tool Use") -def use_calculator( - llm, problem: str, expected_answer: float, stream_mode: bool = False -) -> None: - calculator_tool = { - "type": "function", - "function": { - "name": "simple_calculator", - "description": "Calculates the result of an arithmetic operation.", - "parameters": { - "type": "object", - "properties": { - "a": {"type": "number", "description": "The first number."}, - "b": {"type": "number", "description": "The second number."}, - "operator": { - "type": "string", - "description": "The operator (+, -, *, /).", - }, - }, - "required": ["a", "b", "operator"], - }, - }, - } - llm.stream_responses = stream_mode - - actors.user.send(problem) - - tool_call_msg = llm.respond(tools=[calculator_tool]) - tool_calls = tool_call_msg.tool_calls - assertions.assert_true( - bool(tool_calls), "LLM was expected to call a tool, but it did not." 
- ) - - tool_call = tool_calls[0] - function_args = json.loads(tool_call["function"]["arguments"]) - # Removes 'signature' parameter in thinking mode. - function_args.pop("signature", None) - tool_result = "" - try: - tool_result = run_simple_calculator(**function_args) - except Exception as e: - tool_result = f"Error executing tool: {type(e).__name__} - {e}" - - tool.send( - messages.Message( - sender=tool, - content=str(tool_result), - _meta={"tool_call_id": tool_call["id"]}, - ) +def use_calculator(llm, problem: str, expected_answer: float) -> None: + final_answer = llm.prompt(problem, tools=[run_simple_calculator]) + assertions.assert_tool_was_invoked( + run_simple_calculator, "LLM was expected to call a tool, but it did not." ) - final_answer_msg = llm.respond() - final_answer = final_answer_msg.content - assertions.assert_true( - str(expected_answer) in final_answer, - f"Expected '{expected_answer}' to be in the final answer, but got '{final_answer}'.", + str(expected_answer) in final_answer, + f"Expected '{expected_answer}' to be in the final answer, but got '{final_answer}'.", ) +# %% + problem = "What is 485 multiplied by 12?" 
expected = 485 * 12 -# %% -use_calculator.run(llm, problem=problem, expected_answer=expected, stream_mode=True) - -# %% -use_calculator.run(llm, problem=problem, expected_answer=expected, stream_mode=False) +use_calculator.run(llm, problem=problem, expected_answer=expected) # %% diff --git a/golden_tests/conftest.py b/golden_tests/conftest.py index a98ab74..c815095 100644 --- a/golden_tests/conftest.py +++ b/golden_tests/conftest.py @@ -54,7 +54,11 @@ def module_report_fixture(request): model_report = report.setdefault( f"{api}://{llm.name}", { - "config": {"structured_output": llm.support_structured_outputs}, + "config": { + "structured_output": llm.support_structured_outputs, + "tools": llm.support_tool_calling, + "vision": llm.support_vision, + }, "tests": {}, }, ) diff --git a/golden_tests/test_api_integration.py b/golden_tests/test_api_integration.py new file mode 100644 index 0000000..bad88da --- /dev/null +++ b/golden_tests/test_api_integration.py @@ -0,0 +1,434 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import Generic, TypeVar + +import openai +import pydantic +import pytest +from google import genai + +from kaggle_benchmarks import chats, llm_messages, prompting, providers, utils +from kaggle_benchmarks import tools as tool_utils +from kaggle_benchmarks.actors import llms +from kaggle_benchmarks.content_types import images + +http_client = utils.build_httpx_client("test_cache") + + +def create_openai_client(cls=providers.openai.OpenAIResponsesAPI, **kwargs): + if "OPENAI_API_KEY" not in os.environ: + pytest.skip("Missing OPENAI_API_KEY environment variable.") + return cls( + client=openai.OpenAI( + api_key=os.environ["OPENAI_API_KEY"], + http_client=http_client, + ), + **kwargs, + ) + + +def create_google_client(cls=providers.genai.GoogleGenAI, **kwargs): + if "GEMINI_API_KEY" not in os.environ: + pytest.skip("Missing GEMINI_API_KEY environment variable.") + return cls( + client=genai.Client(api_key=os.environ["GEMINI_API_KEY"]), + **kwargs, + ) + + +def create_model_proxy_openai_client(**kwargs): + if "MODEL_PROXY_API_KEY" not in os.environ: + pytest.skip("Missing MODEL_PROXY_API_KEY environment variable.") + return providers.openai.ModelProxyOpenAI( + client=openai.OpenAI( + api_key=os.environ["MODEL_PROXY_API_KEY"], + base_url=os.environ["MODEL_PROXY_URL"], + http_client=http_client, + ), + **kwargs, + ) + + +def create_model_proxy_genai_client(**kwargs): + if "MODEL_PROXY_API_KEY" not in os.environ: + pytest.skip("Missing MODEL_PROXY_API_KEY environment variable.") + return providers.genai.ModelProxyGenAI( + client=genai.Client( + api_key=os.environ["MODEL_PROXY_API_KEY"], + http_options={ + "api_version": "v1", + "base_url": os.environ["MODEL_PROXY_URL"].replace("/openapi", "/genai"), + }, + ), + **kwargs, + ) + + +PARAMS = [ + pytest.param( + ( + create_openai_client, + dict( + model="gpt-4o", + support_structured_outputs=True, + support_tool_calling=True, + ), + ), + id="openai[+s+t]", + ), + pytest.param( + ( + 
create_openai_client, + dict( + model="gpt-4o", + support_structured_outputs=True, + support_tool_calling=False, + ), + ), + id="openai[+s-t]", + ), + pytest.param( + ( + create_openai_client, + dict( + model="gpt-4o", + support_structured_outputs=False, + support_tool_calling=False, + ), + ), + id="openai[-s-t]", + ), + pytest.param( + ( + create_openai_client, + dict(model="o4-mini", cls=providers.openai.StreamingOpenAIResponsesAPI), + ), + id="openai-o4-mini-streaming", + ), + pytest.param( + (create_google_client, dict(model="gemini-2.5-flash")), + id="google-gemini-2.5-flash", + ), + pytest.param( + ( + create_google_client, + dict(model="gemini-2.5-flash", cls=providers.genai.StreamingGoogleGenAI), + ), + id="google-gemini-2.5-flash-streaming", + ), +] + +PROXY_MODELS = [ + "google/gemini-2.0-flash", + "google/gemini-2.5-flash", + "google/gemini-2.5-pro", + "google/gemini-3-flash-preview", + "google/gemma-3-12b", + "qwen/qwen3-235b-a22b-instruct-2507", + "qwen/qwen3-next-80b-a3b-instruct", + "anthropic/claude-haiku-4-5@20251001", + "anthropic/claude-opus-4-5@20251101", + "anthropic/claude-sonnet-4-5@20250929", + "deepseek-ai/deepseek-r1-0528", + "deepseek-ai/deepseek-v3.2", + "zai/glm-5", + # "google/gemini-3.1-flash-lite-preview", +] +for name in PROXY_MODELS: + PARAMS.append( + pytest.param( + (create_model_proxy_openai_client, dict(model=name)), + id=f"model-proxy-openai-{name}", + ) + ) + +for name in PROXY_MODELS: + PARAMS.append( + pytest.param( + (create_model_proxy_genai_client, dict(model=name)), + id=f"model-proxy-genai-{name}", + ) + ) + + +@pytest.fixture +def llm(request): + model_factory, params = request.param + return model_factory(**params) + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_text_generation(llm): + """Tests basic text generation to ensure the model responds.""" + with chats.new(): + response = llm.prompt("Say 'hello world' and nothing else.") + assert "hello world" in response.lower() + + +class 
UserInfo(pydantic.BaseModel): + name: str + age: int + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_structured_output(llm): + """Tests the model's ability to generate a simple Pydantic model.""" + with chats.new(): + response = llm.prompt( + "Generate a user named Alice who is 30 years old.", schema=UserInfo + ) + assert isinstance(response, UserInfo) + assert response.name == "Alice" + assert response.age == 30 + + +class UserDetails(pydantic.BaseModel): + user: UserInfo + address: str + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_nested_structured_output(llm): + """Tests the model's ability to generate a nested Pydantic model.""" + with chats.new(): + try: + response = llm.prompt( + "Generate a user named Alice who is 30 years old and lives at 123 Kaggle Street.", + schema=UserDetails, + ) + assert isinstance(response, UserDetails) + assert response.user.name == "Alice" + assert response.user.age == 30 + assert "123 Kaggle Street" in response.address + except prompting.ResponseParsingError as e: + pytest.xfail( + f"Model {llm.model} may not support nested structured output: {e}" + ) + + +T = TypeVar("T") + + +class User(prompting.RenderablePydanticModel, Generic[T]): + # class User(pydantic.BaseModel, Generic[T]): + user: T + address: str + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_generic_structured_output(llm): + """Tests the model's ability to generate a nested Pydantic model.""" + with chats.new(): + try: + response = llm.prompt( + "Generate a user named Alice who is 30 years old and lives at 123 Kaggle Street.", + schema=User[UserInfo], + ) + assert isinstance(response, User) + assert response.user.name == "Alice" + assert response.user.age == 30 + assert "123 Kaggle Street" in response.address + except prompting.ResponseParsingError as e: + pytest.xfail( + f"Model {llm.model} may not support nested structured output: {e}" + ) + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) 
+def test_vision_input(llm): + """Tests the model's ability to process image input.""" + image = images.from_url( + "https://storage.googleapis.com/kaggle-organizations/5154/thumbnail.png" + ) + + with chats.new("Vision Test Chat"): + if llm.support_vision: + response = llm.prompt("What is in this image?", image=image) + assert "goose" in response.lower() or "bird" in response.lower() + else: + with pytest.raises(ValueError, match="Vision not supported"): + llm.prompt("What is in this image?", image=image) + + +class StockPrice(pydantic.BaseModel): + symbol: str + price: float + + model_config = pydantic.ConfigDict( + title="StockPrice", + extra="forbid", + ) + + +def get_stock_price(symbol: str) -> float: + """Gets the current stock price for a given symbol.""" + if symbol == "KGL": + return 120.5 + elif symbol == "BNCH": + return 210.3 + else: + return 0.0 + + +@pytest.mark.parametrize( + "schema", + [ + pytest.param(str, id="output_str"), + pytest.param(float, id="output_primitive"), + pytest.param(StockPrice, id="output_pydantic"), + ], +) +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_tool_calling(llm, schema): + """Tests the full tool-calling loop with various output schemas.""" + with chats.new() as chat: + try: + response = llm.prompt( + "What is the price of KGL?", + schema=schema, + tools=[get_stock_price], + max_tool_calls=2, + ) + except prompting.ResponseParsingError as e: + # Not all models reliably produce structured output, so we fail gracefully. 
+ pytest.xfail(str(e)) + except llms.ToolInvocationLimitExhausted as e: + # TODO: some model will not see tool invocation and will call the same tool over and over + pytest.xfail(str(e)) + + assert isinstance(response, schema) + # models may not respond correctly but should respond in proper format + # if schema is float: + # assert response == 120.5 + # elif schema is StockPrice: + # assert response.price == 120.5 + # assert "KGL" in response.symbol.upper() + + assert len(chat.messages) == 2 + llm_message = chat.messages[1] + assert isinstance(llm_message, llm_messages.LLMMessage) + assert llm_message.tool_calls is not None + # This should be exactly one, but some models may generate more. + assert len(llm_message.tool_calls) >= 1 + + # The sub-chat containing the tool invocation should be preserved. + assert llm_message.chat + assert llm_message.chat.messages + + tool_call = llm_message.tool_calls[0] + assert isinstance( + tool_call, (tool_utils.ToolInvocation, tool_utils.ToolInvocationResult) + ) + assert tool_call.name == "get_stock_price" + assert "symbol" in tool_call.arguments + assert "KGL" in tool_call.arguments["symbol"].upper() + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_parallel_tool_calling(llm): + """Tests the LLM's ability to make parallel tool calls.""" + prompt = "What are the stock prices for 'KGL' and 'BNCH'?" 
+ with chats.new() as chat: + response = llm.prompt(prompt, tools=[get_stock_price]) + llm_message = chat.messages[-1] + + assert isinstance(llm_message, llm_messages.LLMMessage) + assert isinstance(response, str) + assert llm_message.tool_calls is not None + assert len(llm_message.tool_calls) >= 2 + + tool_names = {call.name for call in llm_message.tool_calls} + assert tool_names == {"get_stock_price"} + + symbols = {call.arguments["symbol"] for call in llm_message.tool_calls} + assert symbols == {"KGL", "BNCH"} + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def test_tool_calling_structured_args(llm): + """Tests tool calling where the tool argument is a Pydantic model.""" + + class Point(pydantic.BaseModel): + x: int + y: int + + def draw_point(point: Point) -> str: + """Draws a point on a canvas.""" + if isinstance(point, dict): + point = Point.model_validate(point) + return f"Drawing point at ({point.x}, {point.y})" + + with chats.new() as chat: + try: + response = llm.prompt( + "Draw a point at (10, 20)", + tools=[draw_point], + ) + except prompting.ResponseParsingError as e: + pytest.fail(str(e)) + except llms.ToolInvocationLimitExhausted as e: + pytest.fail(str(e)) + + assert isinstance(response, str) + + llm_message = chat.messages[1] + assert llm_message.tool_calls + tool_call = llm_message.tool_calls[0] + assert tool_call.name == "draw_point" + assert "point" in tool_call.arguments + point = tool_call.arguments["point"] + if isinstance(point, dict): + assert point == {"x": 10, "y": 20} + else: + assert point.x == 10 + assert point.y == 20 + + +def get_user_id(username: str) -> int: + """Gets the user ID for a given username.""" + if username == "test_user": + return 123 + else: + return -1 + + +def get_user_posts(user_id: int) -> list[str]: + """Gets the posts for a given user ID.""" + if user_id == 123: + return ["Post 1", "Post 2"] + else: + return [] + + +@pytest.mark.parametrize("llm", PARAMS, indirect=True) +def 
test_dependent_tool_calling(llm): + """Tests the LLM's ability to make dependent tool calls.""" + + prompt = "What are the posts for user 'test_user'?" + with chats.new() as chat: + response = llm.prompt(prompt, tools=[get_user_id, get_user_posts]) + + assert "Post 1" in response + assert "Post 2" in response + assert len(chat.messages) == 2 + llm_message = chat.messages[-1] + assert isinstance(llm_message, llm_messages.LLMMessage) + assert isinstance(response, str) + assert llm_message.tool_calls is not None + assert len(llm_message.tool_calls) >= 2 + + tool_names = {call.name for call in llm_message.tool_calls} + assert tool_names == {"get_user_id", "get_user_posts"} diff --git a/golden_tests/test_cookbook_examples.py b/golden_tests/test_cookbook_examples.py index a7616c0..170527f 100644 --- a/golden_tests/test_cookbook_examples.py +++ b/golden_tests/test_cookbook_examples.py @@ -173,9 +173,10 @@ def assess_with_judge_task(llm, judge_llm) -> None: # We fix the test LLM to one reliable model to focus on testing the judges. 
@pytest.mark.parametrize("llm_name", ["google/gemini-2.5-flash"]) @pytest.mark.parametrize("judge_llm_name", JUDGE_LLM_NAMES) -def test_assess_with_judge(llm_name, judge_llm_name): - llm = kbench.llms[llm_name] - judge_llm = kbench.llms[judge_llm_name] +@pytest.mark.parametrize("api", ["openai", "genai"]) +def test_assess_with_judge(llm_name, judge_llm_name, api): + llm = kbench.kaggle.load_model(llm_name, api=api) + judge_llm = kbench.kaggle.load_model(judge_llm_name, api=api) run = assess_with_judge_task.run(llm, judge_llm) assert run.passed @@ -381,13 +382,12 @@ def assert_multi_qa_result(run): @benchmark_test(df=df, verify_fn=assert_multi_qa_result) @kbench.task() def test_dataset_eval(llm, df) -> tuple[float, float]: - with kbench.client.enable_cache(): - runs = single_qa_task.evaluate( - llm=[llm], - evaluation_data=df, - n_jobs=2, - remove_run_files=True, - ) + runs = single_qa_task.evaluate( + llm=[llm], + evaluation_data=df, + n_jobs=2, + remove_run_files=True, + ) eval_df = runs.as_dataframe() @@ -401,18 +401,12 @@ def test_dataset_eval(llm, df) -> tuple[float, float]: # --- Test Case: Image inputs (URL) --- -@benchmark_test( - exclude={ - "deepseek-ai/deepseek-r1-0528", - "deepseek-ai/deepseek-v3.2", - "qwen/qwen3-235b-a22b-instruct-2507", - "qwen/qwen3-next-80b-a3b-instruct", - "zai/glm-5", - } -) +@benchmark_test() @kbench.task() def test_image_url(llm): """Sends an image URL directly to the model.""" + if not llm.support_vision: + pytest.skip("Model does not support vision") # Kaggle logo image_url = "https://www.kaggle.com/static/images/site-logo.png" @@ -431,19 +425,12 @@ def test_image_url(llm): # --- Test Case: Image inputs (Base64) --- -@benchmark_test( - exclude={ - "deepseek-ai/deepseek-r1-0528", - "deepseek-ai/deepseek-v3.2", - "qwen/qwen3-235b-a22b-instruct-2507", - "qwen/qwen3-next-80b-a3b-instruct", - "anthropic/claude-sonnet-4-5@20250929", - "zai/glm-5", - } -) +@benchmark_test() @kbench.task() def test_image_base64(llm): """Sends a base64 
encoded image with explicit format specification.""" + if not llm.support_vision: + pytest.skip("Model does not support vision") # Example: A small red dot (PNG) # This is a 1x1 red pixel in PNG format red_dot_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" @@ -463,17 +450,11 @@ def test_image_base64(llm): # --- Test Case: Image inputs (local file) --- -@benchmark_test( - exclude={ - "deepseek-ai/deepseek-r1-0528", - "deepseek-ai/deepseek-v3.2", - "qwen/qwen3-235b-a22b-instruct-2507", - "qwen/qwen3-next-80b-a3b-instruct", - "zai/glm-5", - } -) +@benchmark_test() @kbench.task() def test_image_local_file(llm): + if not llm.support_vision: + pytest.skip("Model does not support vision") # Kaggle logo image_url = "https://www.kaggle.com/static/images/site-logo.png" @@ -592,9 +573,6 @@ def test_audio_url(llm): # %% # --- Test Case: Tool Use --- -# This doesn't work with "genai" API for now -# So test it with `-k "openai"` only. -# TODO: Rewrite this test after tool refactoring. def run_simple_calculator(a: float, b: float, operator: str) -> float: @@ -612,7 +590,7 @@ def run_simple_calculator(a: float, b: float, operator: str) -> float: @benchmark_test() @kbench.task() -def test_simple_tool_use(llm): +def test_tool_use(llm): problem = "What is 50 plus 25?" 
expected_answer = 75.0 diff --git a/golden_tests/test_cookbook_examples_report.yaml b/golden_tests/test_cookbook_examples_report.yaml index fdbc660..07b13c6 100644 --- a/golden_tests/test_cookbook_examples_report.yaml +++ b/golden_tests/test_cookbook_examples_report.yaml @@ -1,6 +1,8 @@ genai://anthropic/claude-haiku-4-5@20251001: config: - structured_output: true + structured_output: false + tools: false + vision: true tests: test_dataset_eval: passed test_extract_bool: passed @@ -12,9 +14,12 @@ genai://anthropic/claude-haiku-4-5@20251001: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed genai://anthropic/claude-opus-4-5@20251101: config: - structured_output: true + structured_output: false + tools: false + vision: true tests: test_dataset_eval: passed test_extract_bool: passed @@ -26,9 +31,12 @@ genai://anthropic/claude-opus-4-5@20251101: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed genai://anthropic/claude-sonnet-4-5@20250929: config: - structured_output: true + structured_output: false + tools: false + vision: true tests: test_dataset_eval: passed test_extract_bool: passed @@ -37,11 +45,15 @@ genai://anthropic/claude-sonnet-4-5@20250929: test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed genai://deepseek-ai/deepseek-r1-0528: config: structured_output: false + tools: false + vision: false tests: test_dataset_eval: passed test_extract_bool: passed @@ -50,9 +62,15 @@ genai://deepseek-ai/deepseek-r1-0528: test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed genai://deepseek-ai/deepseek-v3.2: config: structured_output: false + tools: false + vision: true tests: test_dataset_eval: passed 
test_extract_bool: passed @@ -61,9 +79,15 @@ genai://deepseek-ai/deepseek-v3.2: test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: failed + test_image_local_file: failed + test_image_url: failed + test_tool_use: passed genai://google/gemini-2.0-flash: config: structured_output: true + tools: true + vision: true tests: test_audio_base64: passed test_audio_local_file: passed @@ -78,10 +102,12 @@ genai://google/gemini-2.0-flash: test_image_base64: passed test_image_local_file: passed test_image_url: passed - test_manual_tool_use: failed + test_tool_use: passed genai://google/gemini-2.5-flash: config: structured_output: true + tools: false + vision: true tests: test_audio_base64: passed test_audio_local_file: passed @@ -101,6 +127,8 @@ genai://google/gemini-2.5-flash: genai://google/gemini-2.5-pro: config: structured_output: true + tools: true + vision: true tests: test_audio_base64: passed test_audio_local_file: passed @@ -120,6 +148,8 @@ genai://google/gemini-2.5-pro: genai://google/gemini-3-flash-preview: config: structured_output: true + tools: true + vision: true tests: test_audio_base64: passed test_audio_local_file: passed @@ -139,6 +169,8 @@ genai://google/gemini-3-flash-preview: genai://google/gemini-3.1-flash-lite-preview: config: structured_output: true + tools: true + vision: true tests: test_audio_base64: passed test_audio_local_file: passed @@ -153,17 +185,23 @@ genai://google/gemini-3.1-flash-lite-preview: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: failed genai://google/gemma-3-12b: config: structured_output: true + tools: false + vision: false tests: test_dataset_eval: passed - test_image_base64: passed - test_image_local_file: passed - test_image_url: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: failed genai://qwen/qwen3-235b-a22b-instruct-2507: config: structured_output: false + 
tools: false + vision: false tests: test_dataset_eval: passed test_extract_bool: passed @@ -172,10 +210,15 @@ genai://qwen/qwen3-235b-a22b-instruct-2507: test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed - test_manual_tool_use: failed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed genai://qwen/qwen3-next-80b-a3b-instruct: config: structured_output: false + tools: false + vision: false tests: test_dataset_eval: passed test_extract_bool: passed @@ -184,10 +227,15 @@ genai://qwen/qwen3-next-80b-a3b-instruct: test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed - test_manual_tool_use: failed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed genai://zai/glm-5: config: structured_output: true + tools: false + vision: false tests: test_dataset_eval: passed test_extract_bool: failed @@ -196,12 +244,17 @@ genai://zai/glm-5: test_extract_dict: failed test_extract_int: failed test_extract_pydantic: failed - test_manual_tool_use: failed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: failed openai://anthropic/claude-haiku-4-5@20251001: config: structured_output: true + tools: false + vision: true tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -211,11 +264,14 @@ openai://anthropic/claude-haiku-4-5@20251001: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed openai://anthropic/claude-opus-4-5@20251101: config: structured_output: true + tools: false + vision: true tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -225,49 +281,65 @@ 
openai://anthropic/claude-opus-4-5@20251101: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed openai://anthropic/claude-sonnet-4-5@20250929: config: structured_output: true + tools: false + vision: true tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: failed test_image_local_file: passed test_image_url: passed + test_tool_use: passed openai://deepseek-ai/deepseek-r1-0528: config: - structured_output: false + structured_output: true + tools: false + vision: false tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed openai://deepseek-ai/deepseek-v3.2: config: - structured_output: false + structured_output: true + tools: false + vision: false tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed - test_extract_composite_pydantic: passed + test_extract_composite_pydantic: failed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed openai://google/gemini-2.0-flash: config: structured_output: true + tools: false + vision: true tests: - test_audio_base64: passed - test_audio_local_file: passed - test_audio_url: passed - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -277,17 +349,17 @@ 
openai://google/gemini-2.0-flash: test_image_base64: passed test_image_local_file: passed test_image_url: passed - test_manual_tool_use: passed + test_tool_use: passed openai://google/gemini-2.5-flash: config: structured_output: true + tools: false + vision: true tests: - test_audio_base64: passed - test_audio_local_file: passed - test_audio_url: passed - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed + test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed @@ -300,11 +372,10 @@ openai://google/gemini-2.5-flash: openai://google/gemini-2.5-pro: config: structured_output: true + tools: false + vision: true tests: - test_audio_base64: passed - test_audio_local_file: passed - test_audio_url: passed - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -319,11 +390,10 @@ openai://google/gemini-2.5-pro: openai://google/gemini-3-flash-preview: config: structured_output: true + tools: false + vision: true tests: - test_audio_base64: passed - test_audio_local_file: passed - test_audio_url: passed - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -338,11 +408,10 @@ openai://google/gemini-3-flash-preview: openai://google/gemini-3.1-flash-lite-preview: config: structured_output: true + tools: false + vision: true tests: - test_audio_base64: passed - test_audio_local_file: passed - test_audio_url: passed - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed @@ -352,47 +421,66 @@ openai://google/gemini-3.1-flash-lite-preview: test_image_base64: passed test_image_local_file: passed test_image_url: passed + test_tool_use: passed 
openai://google/gemma-3-12b: config: - structured_output: true + structured_output: false + tools: false + vision: false tests: - test_dataset_eval: passed - test_image_base64: passed - test_image_local_file: passed - test_image_url: passed + test_dataset_eval: failed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed openai://qwen/qwen3-235b-a22b-instruct-2507: config: structured_output: false + tools: false + vision: false tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed - test_manual_tool_use: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed openai://qwen/qwen3-next-80b-a3b-instruct: config: structured_output: false + tools: false + vision: false tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed - test_manual_tool_use: passed + test_image_base64: skipped + test_image_local_file: skipped + test_image_url: skipped + test_tool_use: passed openai://zai/glm-5: config: structured_output: true + tools: false + vision: true tests: - test_dataset_eval: passed + test_dataset_eval: failed test_extract_bool: passed test_extract_composite_pydantic: passed test_extract_dataclass: passed test_extract_dict: passed test_extract_int: passed test_extract_pydantic: passed - test_manual_tool_use: passed + test_image_base64: failed + test_image_local_file: failed + test_image_url: failed + test_tool_use: passed diff --git a/src/kaggle_benchmarks/actors/llms.py b/src/kaggle_benchmarks/actors/llms.py index fd319a1..cc31193 100644 --- a/src/kaggle_benchmarks/actors/llms.py +++ 
b/src/kaggle_benchmarks/actors/llms.py @@ -11,95 +11,62 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Defines a chat agent that interacts with a Large Language Model (LLM). -""" -Defines a chat agent that interacts with a Large Language Model (LLM). +The core class is `LLMChat`, which provides a unified interface for sending +messages, handling structured outputs, managing tool calls, and processing +multimodal inputs. + +The primary entry point for interaction is the `prompt` method, which handles the +conversation loop, including: +1. Sending user input (text and optional images). +2. Invoking the LLM. +3. Executing requested tools and feeding results back to the LLM. +4. Parsing the final response into a requested schema (str, Pydantic model, etc.). Design Note: - LLMChat is stateless. No system instructions or temperature settings are - managed within the class itself. All state, including instructions and - chat history, is maintained within the `chats.Chat` object. This allows - for clean separation of concerns and enables nested threads to encapsulate - inner chat history, preventing it from being visible or sent to the LLM - in outer threads. - - -Example: - -class Goose(LLMChat): - def __init__(self, sound): - super().__init__(name="goose", avatar="🪿") - self.sound = sound - - def invoke(self, messages, system: str = ""): - return LLMResponse(content=self.sound if not system else system) - -goose = Goose('honk') - - -print(goose.send("Hi")) -# 🪿 [goose]: Hi - - -with chats.new(system_instructions="quack") as t: - goose.send("Hi!") - goose.respond() - goose.send("What's up?") - goose.respond() - -print(t) -# 🧵Chat: -# ⚙️ [System]: quack -# 🪿 [goose]: Hi! -# 🪿 [goose]: honk -# 🪿 [goose]: What's up? 
-# 🪿 [goose]: quack quack - -# system message is separately managed by chats module, so goose doesn't use system message of the Chat obj unless explicitly passed in to respond() - -with chats.new(name="Outer") as outer_t: - goose.send("Outer message 1") - goose.respond() - with chats.new(name="Inner", system_instructions="inner") as inner_t: - goose.send("Inner message 1") - goose.respond() - goose.send("Inner message 2") - goose.respond() - goose.send("Outer message 2") - goose.respond() - - -# Inner messages are not part of the outer chat's history. -print(outer_t) -# 🧵Outer: -# 🪿 [goose]: Outer message 1 -# 🪿 [goose]: honk -# 🧵Inner: -# ⚙️ [System]: inner -# 🪿 [goose]: Inner message 1 -# 🪿 [goose]: honk -# 🪿 [goose]: Inner message 2 -# 🪿 [goose]: honk -# 🪿 [goose]: Outer message 2 -# 🪿 [goose]: honk + The `LLMChat` class is designed to be stateless. It does not hold any + conversation history or configuration like temperature settings internally. + Instead, all state, including system instructions and the sequence of + messages, is managed within the current `chats.Chat` context. -""" + Methods like `prompt()` are stateful in their interaction with this context. + They append messages to the current chat history and trigger LLM responses, + effectively advancing the conversational state. This design allows for clean + separation of concerns and enables features like nested conversation threads. + + +Examples: + + # 1. Basic Text Interaction + >>> llm.prompt("What is the capital of France?") + 'Paris' + + # 2. Structured Output + >>> class Sentiment(pydantic.BaseModel): + ... score: float + ... label: str + >>> llm.prompt("I love this library!", schema=Sentiment, system="...") + Sentiment(score=0.9, label='positive') -import dataclasses -import enum -import json -import typing -from typing import TYPE_CHECKING, Any, Iterator, TypeVar + # 3. Tool Calling + >>> def roll_dice(sides: int) -> int: + ... 
return 4 # chosen by fair dice roll + >>> llm.prompt("Roll a dice", tools=[roll_dice]) + 'You rolled a 4.' -import openai -from google import genai -from google.genai import types + # 4. Multimodal Input + >>> image = images.from_url("https://example.com/cat.jpg") + >>> llm.prompt("What animal is this?", image=image) + 'It is a cat.' + +""" + +from typing import Any, TypeVar from kaggle_benchmarks import actors, chats, messages, prompting, utils -from kaggle_benchmarks._config import config from kaggle_benchmarks.content_types import audios, images, videos -from kaggle_benchmarks.serializers import genai as genai_serializer -from kaggle_benchmarks.serializers import openai as openai_serializer +from kaggle_benchmarks.llm_messages import LLMMessage, Usage if TYPE_CHECKING: from kaggle_benchmarks import llm_messages @@ -107,32 +74,27 @@ def invoke(self, messages, system: str = ""): T = TypeVar("T") -# TODO: Figure out a more robust way to handle extra fields. -def _extract_extra_usage_metadata(usage: Any) -> dict[str, Any]: - """Extracts cost metadata from a usage object augmented by Model Proxy.""" - cost = getattr(usage, "cost", None) or {} - return { - "input_tokens_cost_nanodollars": cost.get("input_tokens_cost_nanodollars"), - "output_tokens_cost_nanodollars": cost.get("output_tokens_cost_nanodollars"), - "total_backend_latency_ms": getattr(usage, "total_backend_latency_ms", None), - } +class APIError(Exception): + pass -@dataclasses.dataclass(frozen=True) -class LLMResponse: - content: str - tool_calls: list[Any] | None = None - meta: dict[str, Any] = dataclasses.field(default_factory=dict) +class ToolInvocationLimitExhausted(Exception): + pass class LLMChat(actors.Actor): - """A chat agent that interacts with a Large Language Model (LLM).""" + """Base class for chat actors that interact with a Large Language Model API.""" + + roles_mapping = {} def __init__( self, *, support_structured_outputs: bool = False, support_temperature: bool = False, + 
support_tool_calling: bool = True, + support_vision: bool = True, + postprocessor=lambda x: x, **kwargs, ): kwargs.setdefault("role", "assistant") @@ -140,118 +102,157 @@ def __init__( super().__init__(**kwargs) self.support_structured_outputs = support_structured_outputs self.support_temperature = support_temperature - self.stream_responses = config.interactive_mode - - def invoke( - self, messages: list[messages.Message], system: str | None, **kwargs - ) -> LLMResponse | Iterator[LLMResponse] | "llm_messages.LLMMessage[str]": - """Invokes the LLM with the given messages and system instructions.""" - raise NotImplementedError + self.support_tool_calling = support_tool_calling + self.support_vision = support_vision + self.postprocessor = postprocessor def prompt( self, message: str, schema: type[T] = str, - seed: int = 0, - temperature: float = 0, + seed: int | None = None, + temperature: float | None = 0, tools: list[Any] | None = None, image: images.ImageContent | None = None, video: videos.VideoContent | None = None, audio: audios.AudioContent | None = None, + max_tool_calls: int = 5, ) -> T: - if image is not None: - match image: - case images.ImageURL(): - image_to_send = images.from_image_url(image) - case images.ImageBase64(): - image_to_send = image - case _: - raise ValueError(f"Unsupported image type: {type(image)}") + """Sends a user message to the LLM and returns the structured response. + + This convenience method handles the entire conversation loop, including sending + the initial message, managing tool calls, and parsing the final response into + the desired schema. + + Args: + message: The user's message. + schema: The expected Pydantic model or type of the response. + seed: A random seed for the LLM. + temperature: The sampling temperature for the LLM. + tools: A list of tools available to the LLM. + image: An optional image to include with the message. 
- actors.user.send(image_to_send) + Returns: + The processed and validated response from the LLM, matching the `schema`. + """ + from kaggle_benchmarks import tools as tool_utils - if video is not None: + if image is not None: + if not isinstance(image, images.ImageContent): + raise TypeError(f"Unsupported image type: {type(image)}") + if not self.support_vision: + raise ValueError(f"Vision not supported by {self.name}") + image.caption = message + actors.user.send(image) + + elif video is not None: if not isinstance(video, videos.VideoContent): raise ValueError(f"Unsupported video type: {video!r}") actors.user.send(video) + actors.user.send(message) if audio is not None: if not isinstance(audio, audios.AudioContent): raise ValueError(f"Unsupported audio type: {audio!r}") + audio.caption = message actors.user.send(audio) + else: + actors.user.send(message) - actors.user.send(message) - return self.respond( - schema=schema, - seed=seed, - temperature=temperature if self.support_temperature else None, - tools=tools if tools is not None else [], - ).content + final_response = LLMMessage( + sender=self, content=None, usage=Usage(0, 0), tool_calls=[] + ) + + try: + # Fork the chat to isolate the tool-calling loop from the main + # conversation. This prevents format instructions and tool invocations + # from appearing in the primary chat history. 
+ with chats.fork() as subchat: + final_response.chat = subchat + + for _ in range(max_tool_calls): + response = self.respond( + schema=schema, + seed=seed, + temperature=temperature, + tools=tools if tools is not None else [], + ) + + # final_response.tool_calls.extend(response.tool_calls or []) + final_response.content = response.content + final_response.usage += response.usage + + if tools and response.tool_calls: + for call in response.tool_calls: + result = tool_utils.invoke_tool(call, tools) + final_response.tool_calls.append(result) + actors.Tool(name=call.name).send(result) + else: + break + else: + raise ToolInvocationLimitExhausted() + finally: + chats.send(final_response) + + return final_response.content @chats.emits_message def respond( self, + *, system: str | None = None, schema: type[T] = str, - **kwargs, - ) -> messages.Message[T]: - from kaggle_benchmarks import contexts, llm_messages + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> LLMMessage[T]: + """Generates a response from the LLM, handling schema processing and tool calls.""" + from kaggle_benchmarks import contexts + + if tools and not self.support_tool_calling: + return self._simulate_tool_calling( + tools=tools, + schema=schema, + system=system, + temperature=temperature, + seed=seed, + ) ctx = contexts.get_current() chat = ctx.chat h = prompting.process_schema(schema) - - temp_messages = [] - schema_instructions = next(h) - match schema_instructions: - case [msg, schema]: - if self.support_structured_outputs: - kwargs["response_format"] = schema - else: - temp_messages.append( - messages.Message(sender=actors.system, content=msg) - ) - case None: - pass - case _: - temp_messages.append( - messages.Message(sender=actors.system, content=schema_instructions) - ) + if isinstance(schema_instructions, tuple): + schema_instructions, schema = schema_instructions - response = messages.Message( - sender=self, - content="", - 
_status=utils.Status.RUNNING, + response = self.invoke( + messages=chat.messages, + schema_instructions=schema_instructions, + system=system, + schema=schema, + temperature=temperature, + seed=seed, + tools=tools, ) - raw_messages = [ - msg for msg in chat.messages if msg.is_visible_to_llm - ] + temp_messages - - invoke_response = self.invoke( - raw_messages, - system=system, - **kwargs, + response._meta.update( + chat=chat, + schema=schema, + raw_content=response.content, + temperature=temperature, + seed=seed, + tools=tools, ) - if isinstance(invoke_response, LLMResponse): - # A response can have either content, tool_calls, or both in some cases. - response.content = invoke_response.content or "" - response._meta["tool_calls"] = invoke_response.tool_calls - response._meta.update(invoke_response.meta) - elif isinstance(invoke_response, Iterator): - response.stream(invoke_response) - elif isinstance(invoke_response, llm_messages.LLMMessage): - response = invoke_response - else: - raise TypeError("Unknown response type from LLM.") - answer = response.content - response._meta.update(chat=chat, schema=schema, raw_content=answer, **kwargs) + if not response.content: + # e.g., waiting for tool invocation or an error occurred. + return response try: - h.send(answer) # must raise StopIteration by returning the parsed value + h.send( + response.content + ) # must raise StopIteration by returning the parsed value raise prompting.SchemaError( f"Generator for {schema!r} yielded multiple values, expected only one." ) @@ -263,203 +264,106 @@ def respond( ) ) response.status = utils.Status.FAILED - raise e - + raise except StopIteration as e: - # StopIteration is expected as this is how you get returned value from a generator + # StopIteration is the expected signal for a successful parse. 
response.content = e.value response.status = utils.Status.SUCCESS chat.append(response) - return response - - def __repr__(self): - name = self.name - return f"{type(self).__name__}({name=})" - - -class OpenAI(LLMChat): - def __init__(self, client: openai.OpenAI, model: str, **kwargs): - kwargs.setdefault("name", model) - super().__init__(**kwargs) - self.model = model - self.client = client - self.serializer = openai_serializer.ModelProxyOpenAISerializer( - roles_mapping={"tool": "system"} - ) - - def _get_usage_meta( - self, usage: openai.types.CompletionUsage | None - ) -> dict[str, Any]: - """Extracts token usage metadata from an OpenAI response object.""" - if usage is None: - return {} - return { - "input_tokens": usage.prompt_tokens, - "output_tokens": usage.completion_tokens, - **_extract_extra_usage_metadata(usage), - } - - def _should_remove_seed(self) -> bool: - unsupported_prefixes = ("google/", "openai/gpt-5.4-pro") - return any(self.model.startswith(prefix) for prefix in unsupported_prefixes) + return response # type: ignore def invoke( - self, messages: list[messages.Message], system: str | None, **kwargs - ) -> LLMResponse | Iterator[LLMResponse]: - if system: - from kaggle_benchmarks.messages import Message - - messages = [Message(sender=actors.system, content=system)] + messages - - raw_messages = list(self.serializer.dump_messages(messages)) - - if self._should_remove_seed(): - # TODO(b/430112500): Remove once model proxy supports it for AIS backends. - # Temporarily do not send "seed" parameter for models not supporting it in Model Proxy. 
- kwargs.pop("seed", None) - - return self._call_api(raw_messages, **kwargs) - - def _get_stream_response( - self, response_stream: openai.Stream - ) -> Iterator[LLMResponse]: - """Yields LLMResponse objects from a streaming response.""" - for chunk in response_stream: - if not chunk.choices: - continue - - delta = chunk.choices[0].delta - - # Guard against chunks where 'delta' is None - if not delta: - continue - - yield LLMResponse( - content=delta.content or "", - tool_calls=delta.tool_calls, - meta=self._get_usage_meta(chunk.usage), + self, + messages: list[messages.Message], + *, + schema_instructions: str | None = None, + schema: type[T] = str, + system: str | None = None, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> LLMMessage[str]: + """Invokes the LLM with a given context, handling structured output simulation.""" + if schema is not str and not self.support_structured_outputs: + result = self._simulate_structured_response( + messages=messages, + system=system, + temperature=temperature, + seed=seed, + tools=tools, + schema_instructions=schema_instructions, ) - - def _call_api( - self, messages: list[dict[str, str]], **kwargs - ) -> LLMResponse | Iterator[LLMResponse]: - if self.support_structured_outputs and "response_format" in kwargs: - # quickfix for nested models in ModelProxy API - if utils.has_nested_models(kwargs["response_format"]): - method = self.client.chat.completions.create - response_format = kwargs.pop("response_format") - json_schema = json.dumps(response_format.model_json_schema()) - messages.append( - { - "role": "user", - "content": ( - "The output must be a valid JSON object that strictly adheres to the following JSON schema:\n" - f"{json_schema}" - ), - } - ) - else: - method = self.client.beta.chat.completions.parse else: - if self.stream_responses: - kwargs["stream"] = True - - method = self.client.chat.completions.create - - response = method( - model=self.model, - 
messages=messages, - **kwargs, - ) - - if isinstance(response, openai.Stream): - return self._get_stream_response(response) - else: - message = response.choices[0].message - tool_calls = message.tool_calls - return LLMResponse( - content=message.content or "", - tool_calls=[t.model_dump() for t in tool_calls] if tool_calls else None, - meta=self._get_usage_meta(response.usage), + result = self._invoke( + messages=messages, + system=system, + schema=schema, + temperature=temperature, + seed=seed, + tools=tools, ) + return self.postprocessor(result) + def _invoke( + self, + messages: list[messages.Message], + *, + schema: type[T | str] = str, + system: str | None = None, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> LLMMessage[str]: + """Abstract method for native LLM invocation.""" + raise NotImplementedError -class GoogleGenAI(LLMChat): - def __init__(self, client: genai.Client, model: str, **kwargs): - kwargs.setdefault("name", model) - super().__init__(**kwargs) - self.model = model - self.client = client - self.serializer = genai_serializer.GenAISerializer( - roles_mapping={"assistant": "model", "system": "user", "tool": "user"} + def __repr__(self): + name = self.name + arguments = ", ".join( + f"{k}={v!r}" for k, v in self.__dict__.items() if k.startswith("support") ) + return f"{type(self).__name__}({name=}, {arguments})" - def _get_usage_meta(self, usage: types.UsageMetadata | None) -> dict[str, Any]: - if usage is None: - return {} - return { - "input_tokens": usage.prompt_token_count, - "output_tokens": usage.candidates_token_count, - **_extract_extra_usage_metadata(usage), - } - - def _get_stream_response( - self, response_stream: Iterator[types.GenerateContentResponse] - ) -> Iterator[LLMResponse]: - # We currently only support text outputs - for chunk in response_stream: - yield LLMResponse( - content=chunk.text or "", - meta=self._get_usage_meta(chunk.usage_metadata), - ) + def 
_simulate_tool_calling( + self, + tools: list[Any], + schema: type[T], + system: str | None = None, + temperature: float | None = None, + seed: int | None = None, + ) -> LLMMessage[T]: + """Simulates tool calling for models that do not support it natively.""" + from kaggle_benchmarks.tools import simulate + + return simulate.simulate_respond_with_tools( + self, + tools=tools, + output_schema=schema, + system=system, + temperature=temperature, + seed=seed, + ) - def invoke( - self, messages: list[messages.Message], system: str | None, **kwargs - ) -> LLMResponse | Iterator[LLMResponse]: - raw_messages = list(self.serializer.dump_messages(messages)) - - config_params = {} - if system: - config_params["system_instruction"] = system - if "response_format" in kwargs: - schema = kwargs.pop("response_format") - config_params["response_schema"] = schema - - # Determine the correct MIME type based on the schema's type - is_enum = isinstance(schema, type) and issubclass(schema, enum.Enum) - is_literal = typing.get_origin(schema) is typing.Literal - - if is_enum or is_literal: - config_params["response_mime_type"] = "text/x.enum" - else: - # Assume any other schema (like a Pydantic model) is for JSON - config_params["response_mime_type"] = "application/json" - - config = types.GenerateContentConfig(**kwargs, **config_params) - - return self._call_api(contents=raw_messages, config=config) - - def _call_api( - self, contents: list[types.Content], config: types.GenerateContentConfig - ) -> LLMResponse | Iterator[LLMResponse]: - if self.stream_responses: - response_stream = self.client.models.generate_content_stream( - model=self.model, contents=contents, config=config - ) - return self._get_stream_response(response_stream) - else: - response = self.client.models.generate_content( - model=self.model, contents=contents, config=config - ) - # Handle cases where the model refuses to respond - if not response.candidates: - return LLMResponse( - content="", - 
meta=self._get_usage_meta(response.usage_metadata), - ) + def _simulate_structured_response( + self, + messages: list[messages.Message], + *, + system: str | None = None, + schema_instructions: str | None = None, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> LLMMessage[str]: + """Simulates structured output generation for text-only models.""" + if schema_instructions: + messages.append(actors.system.send(schema_instructions)) - return LLMResponse( - content=response.text, - meta=self._get_usage_meta(response.usage_metadata), - ) + return self.invoke( + messages=messages, + system=system, + schema=str, + temperature=temperature, + seed=seed, + tools=tools, + ) diff --git a/src/kaggle_benchmarks/assertions.py b/src/kaggle_benchmarks/assertions.py index 4eb9b97..e524364 100644 --- a/src/kaggle_benchmarks/assertions.py +++ b/src/kaggle_benchmarks/assertions.py @@ -25,7 +25,7 @@ import panel as pn import pydantic -from kaggle_benchmarks import chats +from kaggle_benchmarks import chats, llm_messages, tools @dataclasses.dataclass @@ -508,3 +508,31 @@ def assess_response_with_judge( assess_report = None return assess_report + + +@assertion_handler() +def assert_tool_was_invoked( + tool: str | Callable, expectation: str | None = None +) -> AssertionResult: + if not isinstance(tool, str): + tool = tool.__name__ + chat = chats.get_current_chat() + passed = False + for msg in chat.messages: + if ( + isinstance(msg.content, tools.ToolInvocationResult) + and msg.content.name == tool + ): + passed = True + break + + elif isinstance(msg, llm_messages.LLMMessage) and any( + t.name == tool for t in msg.tool_calls + ): + passed = True + break + + return AssertionResult( + passed=passed, + expectation=expectation or "Expected to call `{tool}`", + ) diff --git a/src/kaggle_benchmarks/kaggle/model_proxy.py b/src/kaggle_benchmarks/kaggle/model_proxy.py index 1f2155e..fbcf6c0 100644 --- a/src/kaggle_benchmarks/kaggle/model_proxy.py 
+++ b/src/kaggle_benchmarks/kaggle/model_proxy.py @@ -19,8 +19,8 @@ from google import genai from google.genai import types -from kaggle_benchmarks import utils -from kaggle_benchmarks.actors.llms import GoogleGenAI, LLMChat, OpenAI +from kaggle_benchmarks import providers, utils +from kaggle_benchmarks.actors import LLMChat class ModelProxy: @@ -54,7 +54,11 @@ def __new__( base_url=resolved_base_url, ), ) - llm_instance = GoogleGenAI(client, model, **kwargs) + llm_instance = providers.genai.ModelProxyGenAI( + client, + model, + **kwargs, + ) elif api == "openai": client = openai.OpenAI( @@ -65,11 +69,9 @@ def __new__( # TODO (b/439876083): Disable temperature parameter till this is resolved. kwargs.setdefault("support_temperature", False) - llm_instance = OpenAI(client, model, **kwargs) + llm_instance = providers.openai.ModelProxyOpenAI(client, model, **kwargs) else: raise ValueError(f"Unsupported API: '{api}'. Must be 'openai' or 'genai'.") - if llm_instance: - llm_instance.stream_responses = False return llm_instance diff --git a/src/kaggle_benchmarks/messages.py b/src/kaggle_benchmarks/messages.py index 5446dda..f8be39d 100644 --- a/src/kaggle_benchmarks/messages.py +++ b/src/kaggle_benchmarks/messages.py @@ -56,6 +56,10 @@ def text(self): def tool_calls(self): return self._meta.get("tool_calls") + @tool_calls.setter + def tool_calls(self, value): + self._meta["tool_calls"] = value + @property def usage(self): """Token usage and cost metadata for this message.""" diff --git a/src/kaggle_benchmarks/prompting.py b/src/kaggle_benchmarks/prompting.py index 8f9fb79..b052408 100644 --- a/src/kaggle_benchmarks/prompting.py +++ b/src/kaggle_benchmarks/prompting.py @@ -163,7 +163,6 @@ def root_model_handler(cls): __base__=(RenderablePydanticModel, cls), **{field.name: (field.type, ...) 
for field in dataclasses.fields(cls)}, ) - response = yield ( f"Output JSON using this schema: {json.dumps(schema)}", model_cls, diff --git a/src/kaggle_benchmarks/providers/__init__.py b/src/kaggle_benchmarks/providers/__init__.py new file mode 100644 index 0000000..d6cbaae --- /dev/null +++ b/src/kaggle_benchmarks/providers/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kaggle_benchmarks.providers import genai, openai diff --git a/src/kaggle_benchmarks/providers/genai.py b/src/kaggle_benchmarks/providers/genai.py new file mode 100644 index 0000000..b111570 --- /dev/null +++ b/src/kaggle_benchmarks/providers/genai.py @@ -0,0 +1,335 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import enum +import logging +import mimetypes +import typing +from typing import Any, TypeVar + +import pydantic +from google import genai +from google.genai import types + +from kaggle_benchmarks import actors, chats, messages, utils +from kaggle_benchmarks import tools as tool_utils +from kaggle_benchmarks.actors import llms +from kaggle_benchmarks.content_types import images + +# Define a generic type variable for the output schema +T = TypeVar("T", bound=pydantic.BaseModel) + + +def _get_usage_meta( + usage: types.GenerateContentResponseUsageMetadata | None, +) -> llms.Usage | None: + if usage is None: + return None + return llms.Usage( + input_tokens=usage.prompt_token_count, + output_tokens=usage.candidates_token_count, + ) + + +class GoogleGenAI(llms.LLMChat): + """An actor that interacts with the Google GenAI API (e.g., Gemini).""" + + def __init__(self, client: genai.Client, model: str, **kwargs): + kwargs.setdefault("name", model) + + super().__init__(**kwargs) + self.model = model + self.client = client + + def _convert_to_genai_types( + self, messages: list[messages.Message] + ) -> list[types.Content]: + """Converts internal messages to Google GenAI's `Content` format.""" + raw_messages = [] + for message in messages: + role = "model" if message.sender.role == "assistant" else "user" + + parts = [] + if isinstance(message.content, str): + parts.append(types.Part(text=message.content)) + elif isinstance(message.content, images.ImageContent): + image = message.content + if image.caption: + parts.append(types.Part.from_text(text=image.caption)) + parts.append( + types.Part( + inline_data=types.Blob( + # The API expects the raw base64 string, not bytes. + data=image.b64_string, + mime_type=image.mime_type, + ) + ) + ) + + # Note: The Gemini API is smart enough to process image data URLs even when they are passed as part of a plain text string. 
+ elif ( + isinstance(message.content, list) + and message.content + and isinstance(message.content[0], dict) + ): + for item in message.content: + if item.get("type") == "image_url": + url = item["image_url"]["url"] + + image_bytes = None + mime_type = "image/jpeg" + if url.startswith("data:"): + # Handle base64 data URLs + header, b64_string = url.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] + image_bytes = base64.b64decode(b64_string) + else: + # Handle remote http/https URLs + b64_string = images.image_url_to_base64(url) + image_bytes = base64.b64decode(b64_string) + mime_type = mimetypes.guess_type(url)[0] or "image/jpeg" + + if image_bytes: + parts.append( + types.Part.from_bytes( + data=image_bytes, mime_type=mime_type + ) + ) + else: + # Fallback for any other unexpected payload types + parts.append(types.Part(text=message.text)) + + raw_messages.append(types.Content(role=role, parts=parts)) + + return raw_messages + + def respond( + self, + system: str | None = None, + schema: type[T | str] = str, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> llms.LLMMessage[T]: + if tools and schema is not str: + # GenAI doesn't support both tools and response_schema simultaneously. + # As a workaround, we ask model to generate a json and parse it manually. + if not isinstance(schema, pydantic.BaseModel): + schema = pydantic.create_model("Response", value=(schema, ...)) + + # Temporarily disable structured output support to force tool emulation. 
+ flag = self.support_structured_outputs + try: + self.support_structured_outputs = False + response = super().respond( + system=system, + schema=schema, + temperature=temperature, + seed=seed, + tools=tools, + ) + finally: + self.support_structured_outputs = flag + + if response.content: + response.content = response.content.value + return response + + return super().respond( + system=system, + schema=schema, + temperature=temperature, + seed=seed, + tools=tools, + ) + + def _invoke( + self, + messages: list[messages.Message], + *, + schema: type[T | str] = str, + system: str | None = None, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> llms.LLMMessage[str]: + """Prepares and executes the GenAI API call.""" + raw_messages = self._convert_to_genai_types(messages) + + config_params = {} + if tools and schema is not str: + return self.invoke( + messages=messages, + schema_instructions=None, + schema=str, + system=system, + temperature=temperature, + seed=seed, + tools=tools, + ) + + if system: + config_params["system_instruction"] = system + + if schema is not str and self.support_structured_outputs: + config_params["response_json_schema"] = schema.model_json_schema() + + # Determine the correct MIME type based on the schema's type + is_enum = isinstance(schema, type) and issubclass(schema, enum.Enum) + is_literal = typing.get_origin(schema) is typing.Literal + + if is_enum or is_literal: + config_params["response_mime_type"] = "text/x.enum" + else: + # Assume any other schema (like a Pydantic model) is for JSON + config_params["response_mime_type"] = "application/json" + + tools_declaration = None + if tools and self.support_tool_calling: + tools_declaration = tools + + config = types.GenerateContentConfig( + temperature=temperature, + seed=seed, + tools=tools_declaration, + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=False + ), + **config_params, + ) + + return 
self._call_api(contents=raw_messages, config=config) + + def _call_api( + self, contents: list[types.Content], config: types.GenerateContentConfig + ) -> llms.LLMMessage[str]: + response = self.client.models.generate_content( + model=self.model, contents=contents, config=config + ) + # Handle cases where the model refuses to respond + if not response.candidates or not response.candidates[0].content.parts: + logging.warning( + "API failed to produce a response for the following request:\n" + f"model: {self.model}\ncontents: {contents}\nconfig: {config}" + ) + raise llms.APIError( + "API failed to produce a response for the following request:\n" + f"model: {self.model}\ncontents: {contents}\nconfig: {config}" + ) + + tool_calls = self.extract_tool_calls(response) + + for tool_invocation in self._iter_tool_calls(response): + chats.send(messages.Message(sender=actors.Tool(), content=tool_invocation)) + + return llms.LLMMessage( + sender=self, + content=response.text, + tool_calls=tool_calls if tool_calls else None, + usage=_get_usage_meta(response.usage_metadata), + ) + + def extract_tool_calls(self, response): + tool_calls = list(self._iter_tool_calls(response)) + for part in response.candidates[0].content.parts: + if part.function_call: + tool_calls.append( + tool_utils.ToolInvocation( + name=part.function_call.name, + call_id=f"call_{part.function_call.name}", + arguments=part.function_call.args, + ) + ) + return tool_calls + + def _iter_tool_calls(self, response): + # TODO: review this function for potential issues + calls = [] + if response.automatic_function_calling_history: + for item in response.automatic_function_calling_history: + for part in item.parts: + if part.function_call: + calls.append(part.function_call) + if part.function_response: + yield tool_utils.ToolInvocationResult( + name=part.function_response.name, + call_id=f"call_{part.function_response.name}", + arguments=calls.pop(0).args, + output=part.function_response.response["result"], + ) + + if
not part.function_call and not part.function_response: + logging.warning(f"Unknown part {part}") + + +class StreamingGoogleGenAI(GoogleGenAI): + """A `GoogleGenAI` actor that handles streaming responses.""" + + def _call_api( + self, contents: list[types.Content], config: types.GenerateContentConfig + ) -> llms.LLMMessage: + response_stream = self.client.models.generate_content_stream( + model=self.model, contents=contents, config=config + ) + msg = llms.LLMMessage(sender=self, content="") + usage = None + tool_calls = None + for chunk in response_stream: + if isinstance(chunk.text, str): + msg.add_chunk(chunk.text) + usage = _get_usage_meta(chunk.usage_metadata) + + if isinstance(chunk, types.GenerateContentResponse): + tool_calls = self.extract_tool_calls(chunk) + + msg.usage = usage + msg.tool_calls = tool_calls + return msg + + +class ModelProxyGenAI(GoogleGenAI): + """A `GoogleGenAI` actor variant for use with a model proxy. + + This class may include workarounds for specific proxy behaviors. + """ + + def __init__(self, client: genai.Client, model: str, **kwargs): + if "gemini" in model: + kwargs.setdefault("support_structured_outputs", True) + + if "gemini-2.5-flash" in model: + # The proxy returns a 400 error if tools are set with this model. 
+ kwargs["support_tool_calling"] = False + + elif "deepseek" in model: + kwargs["postprocessor"] = utils.extract_thinking_tag + kwargs.setdefault("support_structured_outputs", False) + kwargs.setdefault("support_tool_calling", False) + kwargs.setdefault("support_vision", "r1" not in model) + + elif "anthropic" in model: + # kwargs.setdefault("support_structured_outputs", False) + kwargs["support_structured_outputs"] = False + kwargs.setdefault("support_tool_calling", False) + kwargs["postprocessor"] = utils.extract_json_tag + + elif "gemma" in model: + kwargs.setdefault("support_vision", True) + else: + kwargs.setdefault("support_structured_outputs", False) + kwargs.setdefault("support_tool_calling", False) + kwargs.setdefault("support_vision", False) + + super().__init__(client, model, **kwargs) diff --git a/src/kaggle_benchmarks/providers/openai.py b/src/kaggle_benchmarks/providers/openai.py new file mode 100644 index 0000000..aa110e5 --- /dev/null +++ b/src/kaggle_benchmarks/providers/openai.py @@ -0,0 +1,329 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from typing import Any, TypeVar + +import openai +import pydantic +from openai.types import responses as responses_types + +from kaggle_benchmarks import messages, utils +from kaggle_benchmarks.actors import llms +from kaggle_benchmarks.serializers import openai as openai_serializer + +T = TypeVar("T") + + +def parse_usage( + usage: openai.types.CompletionUsage | responses_types.ResponseUsage, +) -> llms.Usage: + """Converts an OpenAI usage object to the internal `llms.Usage` format.""" + if isinstance(usage, openai.types.CompletionUsage): + return llms.Usage( + input_tokens=usage.prompt_tokens, + output_tokens=usage.completion_tokens, + input_tokens_cost_nanodollars=usage.cost.get( + "input_tokens_cost_nanodollars" + ), + output_tokens_cost_nanodollars=usage.cost.get( + "output_tokens_cost_nanodollars" + ), + total_backend_latency_ms=usage.total_backend_latency_ms, + ) + return llms.Usage( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens, + ) + + +class OpenAIResponsesAPI(llms.LLMChat): + """An actor that interacts with an OpenAI-compatible API.""" + + def __init__(self, client: openai.OpenAI, model: str, **kwargs): + kwargs.setdefault("name", model) + if model in ["gpt-3.5-turbo"]: + kwargs["support_structured_outputs"] = False + kwargs["support_vision"] = False + kwargs["support_temperature"] = False + super().__init__(**kwargs) + self.model = model + self.client = client + self.serializer = openai_serializer.OpenAIResponsesSerializer() + + def _invoke( + self, + messages: list[messages.Message], + *, + schema: type[T | str] = str, + system: str | None = None, + temperature: float | None = 0, + seed: int | None = None, + tools: list[Any] | None = None, + ) -> llms.LLMMessage[str]: + raw_messages = list(self.serializer.dump_messages(messages)) + + api_kwargs = {"tools": tools, "response_format": schema} + if self.support_temperature: + api_kwargs["temperature"] = temperature + # if seed: + # api_kwargs["seed"] = seed + if 
system: + api_kwargs["instructions"] = system + + return self._call_api(raw_messages, **api_kwargs) + + def dump_tools(self, tools: list[Any]) -> list[dict]: + """Converts a list of functions to the OpenAI tool specification.""" + from kaggle_benchmarks import tools as tool_utils + + return [tool_utils.functions.function_to_openai_tool(tool) for tool in tools] + + def _call_api( + self, + messages: list, + tools: list[Any] | None = None, + response_format: Any = str, + **kwargs, + ) -> llms.LLMMessage[str]: + """Makes the API call to the OpenAI-compatible endpoint.""" + + if tools: + kwargs["tools"] = self.dump_tools(tools) + + try: + if response_format is str: + response = self.client.responses.create( + model=self.model, + input=messages, + **kwargs, + ) + else: + response_format.__name__ = "Response" + response = self.client.responses.parse( + model=self.model, + input=messages, + text_format=response_format, + **kwargs, + ) + except openai.BadRequestError as e: + # logging.warning( + # f"encounter {e}. Trying out disabling structured output." + # ) + raise llms.APIError( + f"{self!r} encountered an API invocation error. 
" + f"input: {messages!r}" + f"arguments: {kwargs!r}" + f"error: {e}" + ) + + return self.process_response(response, tools=tools) + + def process_response( + self, response, message: llms.LLMMessage | None = None, tools=() + ) -> llms.LLMMessage: + """Processes the API response to extract content and tool calls.""" + from kaggle_benchmarks import tools as tool_utils + + tool_calls = [] + content = "" + for item in response.output: + if item.type == "function_call": + tool_calls.append( + tool_utils.invoke_tool( + tool_utils.ToolInvocation( + name=item.name, + call_id=item.call_id, + arguments=json.loads(item.arguments), + ), + tools, + ) + ) + elif item.type == "message": + content += "".join(x.text for x in item.content) + + if message is None: + return llms.LLMMessage( + sender=self, + content=content, + tool_calls=tool_calls, + usage=parse_usage(response.usage), + ) + + message.content = content + message.tool_calls = tool_calls + message.usage = parse_usage(response.usage) + return message + + +class StreamingOpenAIResponsesAPI(OpenAIResponsesAPI): + """An actor that handles streaming responses.""" + + def _call_api( + self, + messages: list[dict[str, str]], + tools: list[Any] | None = None, + response_format: Any = str, + **kwargs, + ) -> llms.LLMMessage: + """Makes a streaming API call.""" + from kaggle_benchmarks import tools as tool_utils + + tools_definition = self.dump_tools(tools) if tools else [] + result = llms.LLMMessage(sender=self, content="", tool_calls=[], usage=None) + + stream_kwargs = { + "model": self.model, + "input": messages, + "tools": tools_definition, + **kwargs, + } + if response_format and response_format is not str: + stream_kwargs["text_format"] = response_format + + with self.client.responses.stream(**stream_kwargs) as response: + for chunk in response: + if hasattr(chunk, "usage") and chunk.usage: + result.usage = parse_usage(chunk.usage) + if isinstance(chunk, responses_types.ResponseTextDeltaEvent): + 
result.add_chunk(chunk.delta) + elif isinstance( + chunk, responses_types.ResponseFunctionCallArgumentsDeltaEvent + ): + result.add_chunk(chunk.delta) + elif isinstance(chunk, responses_types.ResponseOutputItemDoneEvent): + if isinstance(chunk.item, responses_types.ResponseFunctionToolCall): + result.tool_calls.append( + tool_utils.ToolInvocation( + name=chunk.item.name, + call_id=chunk.item.call_id, + arguments=json.loads(chunk.item.arguments), + ) + ) + elif isinstance(chunk, responses_types.ResponseCompletedEvent): + return self.process_response(chunk.response, result) + return result + + +class ModelProxyOpenAI(OpenAIResponsesAPI): + """An OpenAI-compatible actor for routing requests through a model proxy. + + This class includes workarounds for inconsistencies observed with various + proxied models (e.g., Gemini, Meta, Gemma, DeepSeek). + """ + + def __init__(self, client: openai.OpenAI, model: str, **kwargs): + if "gemini" in model: + kwargs["support_structured_outputs"] = True + elif "meta" in model: + kwargs["support_structured_outputs"] = False + kwargs["support_tool_calling"] = False + elif "gemma" in model: + kwargs["support_structured_outputs"] = False + kwargs["support_vision"] = False + self.roles_mapping = { + "system": "user", + } + elif "deepseek" in model: + kwargs["support_vision"] = False + kwargs["support_structured_outputs"] = True + kwargs["postprocessor"] = utils.extract_thinking_tag + elif "qwen" in model: + kwargs["support_vision"] = False + # kwargs["support_structured_outputs"] = True + elif "anthropic" in model: + kwargs["postprocessor"] = utils.extract_json_tag + + kwargs.setdefault("support_tool_calling", False) + super().__init__(client, model, **kwargs) + self.serializer = openai_serializer.ModelProxyOpenAISerializer( + roles_mapping=self.roles_mapping, + ) + + def _invoke( + self, + messages: list[messages.Message], + *, + schema: type[T] | type[str] = str, + system: str | None = None, + temperature: float | None = 0, + seed: int | 
None = None, + tools: list[Any] | None = None, + ) -> llms.LLMMessage[str]: + """Invokes the model, with a fallback for complex nested schemas. + + The model proxy can struggle with deeply nested Pydantic models. This + method detects that and falls back to providing the schema via a system + prompt instead of using the native structured output feature. + """ + if issubclass(schema, pydantic.BaseModel) and has_nested_models(schema): + return self._simulate_structured_response( + messages=messages, + schema_instructions=json.dumps(schema.model_json_schema()), + temperature=temperature, + seed=seed, + tools=tools, + ) + return super()._invoke( + messages, + schema=schema, + system=system, + temperature=temperature, + seed=seed, + tools=tools, + ) + + def _call_api( + self, + messages: list[dict[str, str]], + tools: list[Any] | None = None, + response_format: Any = str, + **kwargs, + ) -> llms.LLMMessage: + """Calls the proxy using the `chat.completions` endpoint. + + The proxy does not handle the `responses` API correctly, so this method + uses the deprecated `chat.completions` endpoint instead. + """ + if self.support_structured_outputs and response_format is not str: + method = self.client.chat.completions.parse + kwargs["response_format"] = response_format + else: + method = self.client.chat.completions.create + + try: + response = method( + model=self.model, + messages=messages, + tools=tools or [], + **kwargs, + ) + except TypeError as e: + # This can happen due to API quota or other proxy-side issues. + raise RuntimeError( + "API call failed, possibly due to an exhausted quota." 
+ ) from e + + message = response.choices[0].message + return llms.LLMMessage( + sender=self, + content=message.content or "", + usage=parse_usage(response.usage), + ) + + +def has_nested_models(model: type[pydantic.BaseModel]) -> bool: + """Checks if a Pydantic model's schema contains nested definitions.""" + schema = model.model_json_schema() + return bool(schema.get("$defs")) diff --git a/src/kaggle_benchmarks/serializers/openai.py b/src/kaggle_benchmarks/serializers/openai.py index 5a1d9b7..2ee1808 100644 --- a/src/kaggle_benchmarks/serializers/openai.py +++ b/src/kaggle_benchmarks/serializers/openai.py @@ -107,6 +107,23 @@ def _dump_invocation( } +class OpenAIResponsesSerializer(OpenAICompletionSerializer): + """Serializer mapping generic messages to the OpenAI Responses API format.""" + + def dump_image(self, message: messages.Message[images.ImageContent]): + """Serializes an image content object into the API's input_image format.""" + image = message.content + yield { + "role": self.get_role(message.sender), + "content": [{"type": "input_text", "text": image.caption}] + if image.caption + else [] + + [ + {"type": "input_image", "image_url": image.url}, + ], + } + + class ModelProxyOpenAISerializer(OpenAICompletionSerializer): """Specialized OpenAI serializer that maps constructs like videos and images specifically for the Kaggle Model Proxy format. 
diff --git a/src/kaggle_benchmarks/tools/__init__.py b/src/kaggle_benchmarks/tools/__init__.py index 2a8ba40..2084929 100644 --- a/src/kaggle_benchmarks/tools/__init__.py +++ b/src/kaggle_benchmarks/tools/__init__.py @@ -14,6 +14,7 @@ from kaggle_benchmarks.tools import container, functions, python, search, web from kaggle_benchmarks.tools.base import ( + UNKNOWN, ModelResponse, ToolCallModel, ToolInvocation, diff --git a/src/kaggle_benchmarks/tools/base.py b/src/kaggle_benchmarks/tools/base.py index c883473..04a62b4 100644 --- a/src/kaggle_benchmarks/tools/base.py +++ b/src/kaggle_benchmarks/tools/base.py @@ -19,6 +19,8 @@ import pydantic T = TypeVar("T") +F = TypeVar("F") +Name = TypeVar("Name", bound=str) @dataclasses.dataclass @@ -29,6 +31,13 @@ class ToolInvocation: arguments: dict[str, Any] call_id: str | None = None + # Allows serialization as content message + def get_payload(self): + return dataclasses.asdict(self) + + +UNKNOWN = object() + @dataclasses.dataclass class ToolInvocationResult: @@ -37,25 +46,41 @@ class ToolInvocationResult: name: str arguments: dict[str, Any] call_id: str | None = None - output: Any = None + output: Any = UNKNOWN + error: str | None = None def describe(self): + if self.error: + return f"{self.name}({self.arguments}): Error: {self.error}" return f"{self.name}({self.arguments}) -> {self.output}" + def get_payload(self): + return dataclasses.asdict(self) -class ToolCallModel(pydantic.BaseModel): - """Represents a tool call in a structured response.""" + @property + def text(self): + return str(self.output if self.output is not UNKNOWN else self.error) - name: str - arguments: dict[str, Any] + +class ToolCallModel(pydantic.BaseModel, Generic[Name, T]): + # Represents a tool call in a structured response. 
+ # Generic Name allows for overwriting str with Literal['tool_name'] + name: Name + arguments: T -class ModelResponse(pydantic.BaseModel, Generic[T]): - """A structured response from the LLM that may contain tool calls or a message.""" +class ModelResponse(pydantic.BaseModel, Generic[T, F]): + # A structured response from the LLM that may contain tool calls or a message. - tools: list[ToolCallModel] | None = None + tools: list[F] | None = None message: T | None = None + model_config = pydantic.ConfigDict( + title="Response", + extra="forbid", + arbitrary_types_allowed=False, + ) + def describe_tools(tools: list[Callable]) -> str: """Generates a plain English description of the available tools.""" @@ -107,7 +132,7 @@ def invoke_tool(call: ToolInvocation, tools: list[Callable]) -> ToolInvocationRe call_id=call.call_id, ) try: - output = tool(**call.arguments) + output = tool(**(call.arguments or {})) return ToolInvocationResult( name=call.name, arguments=call.arguments, @@ -121,6 +146,16 @@ def invoke_tool(call: ToolInvocation, tools: list[Callable]) -> ToolInvocationRe return ToolInvocationResult( name=call.name, arguments=call.arguments, - output=error_message, + error=error_message, call_id=call.call_id, ) + + +def iter_invocations(chat): + from kaggle_benchmarks import llm_messages + + for item in chat.messages: + if isinstance(item.content, ToolInvocationResult): + yield item.content + elif isinstance(item, llm_messages.LLMMessage): + yield from item.tool_calls or [] diff --git a/src/kaggle_benchmarks/tools/functions.py b/src/kaggle_benchmarks/tools/functions.py index 13a921f..84aeba8 100644 --- a/src/kaggle_benchmarks/tools/functions.py +++ b/src/kaggle_benchmarks/tools/functions.py @@ -23,8 +23,8 @@ class ToolSchemaError(Exception): """Raised when a function schema cannot be generated.""" -def get_function_schema(func: Callable) -> dict: - """Generates a JSON schema for a function's parameters using Pydantic.""" +def _get_function_arguments(func: Callable) -> 
dict[str, tuple[type, Any]]: + """Generates a Pydantic model for a function's arguments.""" sig = inspect.signature(func) fields = {} @@ -36,12 +36,17 @@ def get_function_schema(func: Callable) -> dict: fields[name] = (annotation, default) + return fields + + +def get_function_schema(func: Callable) -> dict: + """Generates a JSON schema for a function's parameters using Pydantic.""" try: - DynamicModel = pydantic.create_model(f"{func.__name__}", **fields) - return DynamicModel.model_json_schema() + model = pydantic.create_model(func.__name__, **_get_function_arguments(func)) + return model.model_json_schema() except pydantic.PydanticSchemaGenerationError as e: raise ToolSchemaError( - "Unable to generate json schema for function {func.__name__} arguments", e + f"Unable to generate json schema for function {func.__name__} arguments", e ) diff --git a/src/kaggle_benchmarks/tools/simulate.py b/src/kaggle_benchmarks/tools/simulate.py new file mode 100644 index 0000000..de0ce70 --- /dev/null +++ b/src/kaggle_benchmarks/tools/simulate.py @@ -0,0 +1,181 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Callable, Literal, TypeVar, Union + +import pydantic +from typing_extensions import TypedDict + +from kaggle_benchmarks import actors, chats, usage +from kaggle_benchmarks.llm_messages import LLMMessage +from kaggle_benchmarks.tools import base, functions + +T = TypeVar("T") + + +class ToolInvocationLimitExhausted(Exception): + pass + + +def build_response_model(tools: list[Callable], output_schema: type): + """Creates a pydantic model that can be used as response format for LLM that provides option for llm to invoke tools.""" + return base.ModelResponse[ + output_schema, + Union[ + *( + base.ToolCallModel[ + Literal[tool.__name__], + TypedDict( + tool.__name__, + { + field: annotation + for field, ( + annotation, + _, + ) in functions._get_function_arguments(tool).items() + }, + ), + ] + for tool in tools + ) + ], + ] + + +def simulate_respond_with_tools( + llm: actors.LLMChat, + tools: list[Callable], + output_schema: type[T], + system: str | None = None, + temperature: float | None = None, + seed: int | None = None, +) -> LLMMessage[T]: + """Simulates tool calling for models that do not support it natively.""" + if not tools: + return llm.respond( + system=system, + schema=output_schema, + temperature=temperature, + seed=seed, + tools=None, + ) + + chat = chats.get_current_chat() + + previous_invocations = list(base.iter_invocations(chat)) + + if previous_invocations: + invocation_history = "\n".join(i.describe() for i in previous_invocations) + + history_prompt = f"""You have already invocated the following tools: +{invocation_history}""" + else: + history_prompt = "" + + instructions = f"""You can invoke the following tools: +{base.describe_tools(tools)} + +{history_prompt} + +If you decided to invoke a tool, fill the `tools` attribute with the invocation details, like `[{{"name": "function_name", "arguments": {{...}}}}]`. 
+If you have enough information from previous tool calls or you decide not to use any tools, leave the `tools` field blank and fill the `message` field with your response. +Only one of `tools` or `message`, should be filled with a value. +""" + + with chats.fork(orphan=True) as subchat: + actors.user.send(instructions) + + try: + wrapped_schema = build_response_model(tools, output_schema) + except pydantic.PydanticSchemaGenerationError as e: + raise ValueError( + f"Unable to generate JSON schema for response format {output_schema}." + ) from e + + response = llm.respond( + system=system, + schema=wrapped_schema, + temperature=temperature, + seed=seed, + tools=None, + ) + + value = response.content + if value.tools: + response.tool_calls = [ + base.invoke_tool( + base.ToolInvocation( + name=call.name, + arguments=call.arguments, + call_id=f"call_{call.name}", + ), + tools, + ) + for call in value.tools + ] + response.content = value.message + elif value.message is not None: + response.content = value.message + else: + # some models will not produce anything + # so we ask them once more without tools + response = llm.respond( + system=system, + schema=output_schema, + temperature=temperature, + seed=seed, + tools=None, + ) + response.chat = subchat + return response + + +def simulate_agent( + llm: actors.LLMChat, + tools: list[Callable], + output_schema: type[T] = str, + max_iterations: int = 10, + system: str | None = None, + temperature: float | None = None, + seed: int | None = None, +) -> LLMMessage[T | None]: + """Simulates an agent using tools over multiple iterations.""" + final_response = LLMMessage( + sender=llm, content=None, usage=usage.Usage(0, 0), tool_calls=[] + ) + + for _ in range(max_iterations): + response = simulate_respond_with_tools( + llm, + tools, + output_schema=output_schema, + system=system, + temperature=temperature, + seed=seed, + ) + + final_response.content = response.content + final_response.usage += response.usage + + if tools and 
response.tool_calls: + for call in response.tool_calls: + result = base.invoke_tool(call=call, tools=tools) + final_response.tool_calls.append(result) + actors.Tool(name=call.name).send(result) + else: + break + else: + raise ToolInvocationLimitExhausted() + + return final_response diff --git a/src/kaggle_benchmarks/utils.py b/src/kaggle_benchmarks/utils.py index 42adc5f..4e0a637 100644 --- a/src/kaggle_benchmarks/utils.py +++ b/src/kaggle_benchmarks/utils.py @@ -225,3 +225,19 @@ def task_autopilot( generated_code = model.prompt(prompt, **kwargs) return extract_code_block(generated_code, name="python", greedy=False) + + +def extract_thinking_tag(response): + if "" in response.content: + thinking, content = response.content.split("", 1) + response.thinking = thinking[6:] + response.content = content + return response + + +def extract_json_tag(response): + if "" in response.content: + thinking, content = response.content.split("", 1) + response.thinking = thinking + response.content = content.split("")[0] + return response diff --git a/tests/mocks.py b/tests/mocks.py index 06780dc..9c23fc3 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -48,7 +48,7 @@ def from_contents_data(cls, contents: list[dict], cycle=False, **kwargs): **kwargs, ) - def invoke(self, messages, **kwargs): + def _invoke(self, messages, **kwargs): self.invocations.append((messages, kwargs)) try: response = next(self.response) diff --git a/tests/test_assertions.py b/tests/test_assertions.py index 95b3048..d775d80 100644 --- a/tests/test_assertions.py +++ b/tests/test_assertions.py @@ -249,7 +249,7 @@ def test_a_task_with_nested_assertion(duck): }, ) - assert len(run.chat.history) == 4 + assert len(run.chat.messages) == 4 # Check payload of the first (inner) assertion result inner_payload_dict = json.loads(run.chat.history[-2].payload) diff --git a/tests/test_genai_client.py b/tests/test_genai_client.py deleted file mode 100644 index d986acd..0000000 --- a/tests/test_genai_client.py +++ /dev/null 
@@ -1,163 +0,0 @@ -# Copyright 2025 Kaggle Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from google.genai import types -from pydantic import BaseModel - -from kaggle_benchmarks import actors, chats -from kaggle_benchmarks.actors.llms import GoogleGenAI, LLMResponse -from kaggle_benchmarks.content_types.images import ImageBase64 - - -class MockedGoogleGenAI(GoogleGenAI): - """A mock of the GoogleGenAI class that records inputs and returns fixed outputs.""" - - def __init__(self, **kwargs): - super().__init__(client=None, model="mocked-gemini", **kwargs) - self.support_temperature = True - self.support_structured_outputs = True - - def _call_api( - self, contents: list[types.Content], config: types.GenerateContentConfig - ): - self.contents = contents - self.config = config - - if config.response_schema: - mock_json_response = ( - '{"name": "Mock Recipe", "ingredients": ["water", "flour"]}' - ) - return LLMResponse( - content=mock_json_response, - meta={"input_tokens": 10, "output_tokens": 10}, - ) - - if self.stream_responses: - - def stream_generator(): - yield LLMResponse(content="Streaming", meta={"input_tokens": 15}) - yield LLMResponse( - content=" response", - meta={"input_tokens": 15, "output_tokens": 8}, - ) - - return stream_generator() - else: - return LLMResponse( - content="Non-streaming response", - meta={"input_tokens": 10, "output_tokens": 4}, - ) - - -def test_invoke_basic(): - """Tests that a simple user prompt is formatted 
correctly.""" - llm = MockedGoogleGenAI() - llm.prompt("Hello") - - assert len(llm.contents) == 1 - assert llm.contents[0].role == "user" - assert llm.contents[0].parts[0].text == "Hello" - # Ensure no system instruction is passed by default - assert llm.config.system_instruction is None - - -def test_invoke_with_system_instruction(): - """Tests that system instructions are placed correctly in the config.""" - llm = MockedGoogleGenAI() - llm.respond(system="You are a helpful assistant.") - - assert llm.config.system_instruction == "You are a helpful assistant." - - -def test_invoke_with_config_params(): - """Tests that temperature and seed are passed correctly into the config.""" - llm = MockedGoogleGenAI() - llm.prompt("Be creative", temperature=0.9, seed=42) - - assert llm.config.temperature == 0.9 - - -@pytest.mark.parametrize("streaming", [True, False]) -def test_streaming_and_non_streaming_responses(streaming): - """Tests both streaming and non-streaming modes and checks metadata.""" - llm = MockedGoogleGenAI() - llm.stream_responses = streaming - - with chats.new("Test GenAI Tokens") as t: - response_content = llm.prompt("Tell me a story.") - - last_message = t.messages[-1] - assert last_message.sender is llm - - if streaming: - assert response_content == "Streaming response" - assert last_message._meta["input_tokens"] == 15 - assert last_message._meta["output_tokens"] == 8 - else: - assert response_content == "Non-streaming response" - assert last_message._meta["input_tokens"] == 10 - assert last_message._meta["output_tokens"] == 4 - - -def test_invoke_with_tools(): - """Tests that tools are correctly passed into the config.""" - llm = MockedGoogleGenAI() - - def multiply(a: int, b: int) -> int: - return a * b - - llm.respond(tools=[multiply]) - - assert llm.config.tools is not None - assert len(llm.config.tools) == 1 - assert llm.config.tools[0] == multiply - - -def test_invoke_with_structured_output(): - """Tests that a schema correctly configures the 
response format.""" - llm = MockedGoogleGenAI() - - class Recipe(BaseModel): - name: str - ingredients: list[str] - - response = llm.prompt("Give me a recipe.", schema=Recipe) - - assert isinstance(response, Recipe) - assert response.name == "Mock Recipe" - assert llm.config.response_schema == Recipe - - -def test_invoke_with_image_input(): - """Tests that image payloads are correctly formatted as inline data.""" - llm = MockedGoogleGenAI() - - mock_image = ImageBase64( - b64_string="R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7", - mime_type="image/jpeg", - ) - - with chats.new("Image Test Chat"): - actors.user.send(mock_image) - llm.prompt("What is in this image?") - - assert len(llm.contents) == 1 - assert len(llm.contents[0].parts) == 2 - - image_part = llm.contents[0].parts[0] - assert image_part.text is None or image_part.text == "" - assert image_part.inline_data is not None - assert image_part.inline_data.mime_type == "image/jpeg" - assert llm.contents[0].parts[1].text == "What is in this image?" 
diff --git a/tests/test_llm_chats.py b/tests/test_llm_chats.py index a2d9ab1..b6e192d 100644 --- a/tests/test_llm_chats.py +++ b/tests/test_llm_chats.py @@ -14,97 +14,302 @@ import json +import pydantic import pytest -from kaggle_benchmarks import actors, chats, contexts, prompting, utils -from kaggle_benchmarks.actors.llms import LLMResponse +from kaggle_benchmarks import actors, chats, prompting, usage, utils +from kaggle_benchmarks import tools as tool_utils from kaggle_benchmarks.content_types import images, videos from kaggle_benchmarks.llm_messages import LLMMessage -from kaggle_benchmarks.prompting import handler from tests.mocks import MockedChat -class Ferret(actors.LLMChat): - def __init__(self): - super().__init__(name="Ferret") - self.stream_responses = False - - def invoke(self, messages, system=None, **kwargs): - if not self.stream_responses: - return LLMResponse( - content=json.dumps( - dict( - messages=[[m.sender.name.lower(), m.content] for m in messages], - system=system, - ) - ) - ) - - def stream_generator(): - yield LLMResponse(content="stream", meta={"input_tokens": 10}) - yield LLMResponse( - content="ing", meta={"input_tokens": 10, "output_tokens": 1} - ) - yield LLMResponse( - content="...", meta={"input_tokens": 10, "output_tokens": 2} - ) - - return stream_generator() - - def test_prompt_without_context(): - llm = Ferret() + llm = MockedChat.from_contents(["response content"]) + r = llm.prompt("A") - assert {"messages": [["user", "A"]], "system": None} == json.loads(r) + assert r == "response content" + assert len(llm.invocations) == 1 + invoked_messages, kwargs = llm.invocations[0] + assert len(invoked_messages) == 1 + assert invoked_messages[0].content == "A" + assert invoked_messages[0].sender is actors.user + + assert kwargs["system"] is None def test_respond(): - llm = Ferret() + llm = MockedChat.from_contents(["response content"]) - with chats.new("Test") as t: + with chats.new() as t: actors.user.send("A") assert len(t.messages) == 
1 r = llm.respond() assert len(t.messages) == 2 - assert {"messages": [["user", "A"]], "system": None} == json.loads(r.text) + assert r.content == "response content" + assert len(llm.invocations) == 1 + invoked_messages, kwargs = llm.invocations[0] + assert len(invoked_messages) == 1 + assert invoked_messages[0].content == "A" + assert invoked_messages[0].sender is actors.user def test_chat_context(): - llm = Ferret() - llm.prompt("") + llm = MockedChat.from_contents(["response A", "response B"]) + # This message should not be visible in the context of the next chat. + actors.user.send("") with chats.new(system_instructions="S") as t: assert t.status == utils.Status.RUNNING + assert len(t.messages) == 1 + assert t.messages[0].content == "S" + assert t.messages[0].sender is actors.system + r = llm.prompt("A") - assert { - "messages": [["system", "S"], ["user", "A"]], - "system": None, - } == json.loads(r) + assert r == "response A" + + assert len(t.messages) == 3 + assert t.messages[1].content == "A" + assert t.messages[1].sender is actors.user + assert t.messages[2].content == "response A" + assert t.messages[2].sender is llm + + assert len(llm.invocations) == 1 + + invoked_messages, kwargs = llm.invocations[0] + assert len(invoked_messages) == 2 + assert invoked_messages == t.messages[:2] + assert kwargs["system"] is None + assert kwargs["schema"] is str r = llm.prompt("B") - response = json.loads(r) + assert r == "response B" + assert len(t.messages) == 5 + assert len(llm.invocations) == 2 - assert response["system"] is None - assert 4 == len(response["messages"]) - assert ["system", "S"] == response["messages"][0] - assert ["user", "A"] == response["messages"][1] - assert llm.name.lower() == response["messages"][2][0] - assert ["user", "B"] == response["messages"][3] + invoked_messages, kwargs = llm.invocations[1] + assert len(invoked_messages) == 4 + assert invoked_messages == t.messages[:4] + assert kwargs["system"] is None assert t.status == 
utils.Status.SUCCESS -def test_structured(): - llm = Ferret() +@pytest.mark.parametrize( + "support_structured_outputs", + [ + pytest.param(True, id="with_schema_support"), + pytest.param(False, id="without_schema_support"), + ], +) +def test_structured_output(support_structured_outputs): + llm = MockedChat.from_contents( + ['{"field1": 1, "field2": "two"}'], + support_structured_outputs=support_structured_outputs, + ) + + class Response(pydantic.BaseModel): + field1: int + field2: str + + with chats.new("test") as t: + response = llm.prompt("test", schema=Response) + assert response == Response(field1=1, field2="two") + assert len(t.messages) == 2 + + invoked_messages, kwargs = llm.invocations[0] + if support_structured_outputs: + assert len(invoked_messages) == 1 + assert kwargs["schema"] is Response + else: + # extra message for schema instructions + assert len(invoked_messages) == 2 + assert kwargs["schema"] is str + + +def get_weather(location: str) -> str: + """Get current weather""" + if "london" in location.lower(): + return "Rainy" + return "Sunny" + + +class WeatherReport(pydantic.BaseModel): + text: str + temperature: int + + +@pytest.mark.parametrize( + "support_tools", + [ + pytest.param(True, id="with_tool_support"), + pytest.param(False, id="without_tool_support"), + ], +) +def test_tool_calling_with_structured_output(support_tools): + value = WeatherReport(text="Rainy", temperature=15) + + if support_tools: + responses = [ + LLMMessage( + sender=None, + content=None, + tool_calls=[ + tool_utils.ToolInvocation( + name="get_weather", arguments={"location": "London"} + ) + ], + ), + LLMMessage( + sender=None, + content=value.model_dump_json(), + ), + ] + else: + responses = [ + LLMMessage( + sender=None, + content=json.dumps( + { + "tools": [ + dict(name="get_weather", arguments={"location": "London"}) + ], + "message": None, + } + ), + ), + LLMMessage( + sender=None, + content=json.dumps( + { + "tools": None, + "message": value.model_dump(), + } + ), 
+ ), + ] + + llm = MockedChat( + responses=responses, + support_tool_calling=support_tools, + support_structured_outputs=True, + ) + + tools = [get_weather] + + with chats.new() as t: + response = llm.prompt( + "What is the weather in London?", tools=tools, schema=WeatherReport + ) + assert isinstance(response, WeatherReport) + assert response == value + assert len(t.messages) == 2 + + # one with tool invocation + # second one with result + assert len(llm.invocations) == 2 + messages1, kwargs1 = llm.invocations[0] + + if support_tools: + assert len(messages1) == 1 + assert kwargs1["schema"] == WeatherReport + assert kwargs1["tools"] == tools + + else: + # extra message to describe tools + assert len(messages1) == 2 + assert issubclass(kwargs1["schema"], tool_utils.ModelResponse) + + assert not kwargs1["tools"] + + # assert len(messages1) == 1 + assert messages1[0].content == "What is the weather in London?" + assert messages1[0].sender is actors.user + # assert kwargs1["schema"] == WeatherReport + + messages2, kwargs2 = llm.invocations[1] + if support_tools: + assert len(messages2) == 3 + assert kwargs2["tools"] == tools + assert kwargs2["schema"] == WeatherReport + else: + # extra message to describe tools + assert len(messages2) == 4 + assert issubclass(kwargs2["schema"], tool_utils.ModelResponse) + assert not kwargs2["tools"] + + assert messages2[0].sender is actors.user + assert messages2[1].sender is llm + assert isinstance(messages2[2].content, tool_utils.ToolInvocationResult) + assert messages2[2].content.output == "Rainy" + + +def test_tool_calling_with_typed_output_no_tools(): + responses = [ + LLMMessage( + sender=None, + content=json.dumps( + { + "tools": [ + dict(name="get_weather", arguments={"location": "London"}) + ], + "message": None, + } + ), + ), + LLMMessage( + sender=None, + content=json.dumps( + { + "tools": [], + "message": "12", + } + ), + ), + ] + + llm = MockedChat( + responses=responses, + support_tool_calling=False, + 
support_structured_outputs=True, + ) + + tools = [get_weather] + + with chats.new(): + response = llm.prompt("What is the weather in London?", tools=tools, schema=int) + assert isinstance(response, int) + assert response == 12 + assert len(llm.invocations) == 2 + # Check that the schema was wrapped for emulated tool calling + messages1, kwargs1 = llm.invocations[0] + # user message + instructions + assert len(messages1) == 2 + assert issubclass(kwargs1["schema"], tool_utils.ModelResponse) + messages2, kwargs2 = llm.invocations[1] + # user message + first response + call result + instructions + assert len(messages2) == 4 + assert issubclass(kwargs2["schema"], tool_utils.ModelResponse) + + +def test_custom_types(): + llm = MockedChat( + responses=[ + LLMMessage(sender=None, content="any content"), + LLMMessage(sender=None, content="any content"), + LLMMessage(sender=None, content="any content"), + ], + support_structured_outputs=True, + ) class F: pass value = F() - @handler(types=F) + @prompting.handler(types=F) def _(cls): yield "" return value @@ -113,7 +318,7 @@ def _(cls): assert isinstance(response, F) assert value is response - @handler(types=F) + @prompting.handler(types=F) def _(cls): value = yield "" raise prompting.ResponseParsingError( @@ -123,11 +328,16 @@ def _(cls): with chats.new() as t: with pytest.raises(prompting.ResponseParsingError): llm.prompt("test_value", schema=F) - assert "Bad response" in t.messages[-1].text - assert "test_value" in t.messages[-1].text - assert "F" in t.messages[-1].text - @handler(types=F) + assert len(t.messages) == 2 + llm_message = t.messages[-1] + assert isinstance(llm_message, LLMMessage) + # the error goes to the subchat used for helper prompt + error_text = llm_message.chat.messages[-1].text + assert "Bad response" in error_text + assert "F" in error_text + + @prompting.handler(types=F) def _(cls): yield "" yield "nonsense" @@ -137,43 +347,22 @@ def _(cls): llm.prompt("Test", schema=F) -def test_streaming_prompt(): 
- llm = Ferret() - # Explicitly set stream mode. - llm.stream_responses = True - - with chats.new("Test Streaming") as t: - response_content = llm.prompt("stream this") - assert response_content == "streaming..." - - # The last message in the chat is the one from the LLM. - last_message = t.messages[-1] - assert last_message.content == "streaming..." - assert last_message.sender is llm - assert last_message._meta["input_tokens"] == 10 - assert last_message._meta["output_tokens"] == 2 - - -def test_nested_chat_id(): - llm = Ferret() - with chats.new("root") as root: - sub = chats.Chat(name="sub") - chats.get_current_chat().append(sub) - with contexts.enter(chat=sub): - llm.prompt("Hi") - - sub.name += " - analysis" - - assert root.history[0] is sub - assert sub.id.startswith("sub - analysis-") - assert len(sub.history) == 2 - - def test_chat_usage_aggregation(): """Test that chat usage properties aggregate token usage from all assistant messages.""" - llm = Ferret() - llm.stream_responses = True - + llm = MockedChat( + responses=[ + LLMMessage( + sender=None, + content="first", + usage=usage.Usage(input_tokens=5, output_tokens=3), + ), + LLMMessage( + sender=None, + content="second", + usage=usage.Usage(input_tokens=15, output_tokens=1), + ), + ], + ) with chats.new("Test Usage") as t: llm.prompt("first") llm.prompt("second") @@ -263,4 +452,4 @@ def test_invoke_llmmessage(): assert response.sender is mocked_chat assert len(mocked_chat.invocations) == 1 assert mocked_chat.invocations[0][0] == messages - assert mocked_chat.invocations[0][1] == {"temperature": 0.5} + assert mocked_chat.invocations[0][1].get("temperature") == 0.5 diff --git a/tests/test_messages.py b/tests/test_messages.py index 9b6b874..5d1af7c 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -19,7 +19,6 @@ import pytest from kaggle_benchmarks import chats, messages, user -from kaggle_benchmarks.actors.llms import LLMResponse from tests.mocks import MockedChat @@ -29,7 +28,7 @@ def 
test_raw_payload(): with chats.new() as chat: m = p.prompt(float_response, schema=float) assert m == 0.01 - assert chat.messages[-1].payload == float_response + assert "0.01" in chat.messages[-1].payload r = p.prompt('{"value": true}', schema=bool) assert r @@ -64,51 +63,3 @@ def __init__(self, x): payload = msg.payload assert json.loads(payload) == {"x": 1} - - -def test_streaming(): - text = "a b c d" - - def g(): - for chunk in text: - yield chunk - assert msg.content.endswith(chunk) - - msg = messages.Message(content="", sender=None) - msg.stream(g()) - assert msg.content == text - - -def test_streaming_with_token_counts(): - """Tests that streaming correctly updates metadata like token counts.""" - - def chunk_generator(): - yield LLMResponse( - content="Hello ", meta={"input_tokens": 10, "output_tokens": 1} - ) - yield LLMResponse( - content="world", meta={"input_tokens": 10, "output_tokens": 2} - ) - yield LLMResponse(content="!", meta={"input_tokens": 10, "output_tokens": 3}) - - msg = messages.Message(content="", sender=None) - msg.stream(chunk_generator()) - - assert msg.content == "Hello world!" - assert msg._meta["input_tokens"] == 10 - assert msg._meta["output_tokens"] == 3 - - -def test_tool_calls_property(): - """Tests that the tool_calls property correctly retrieves data from _meta.""" - mock_tool_calls = [{"id": "call_abc", "type": "function"}] - - # Message with tool calls - msg_with_tools = messages.Message( - content="", sender=None, _meta={"tool_calls": mock_tool_calls} - ) - - msg_without_tools = messages.Message(content="Hello", sender=None) - - assert msg_with_tools.tool_calls == mock_tool_calls - assert msg_without_tools.tool_calls is None diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py deleted file mode 100644 index 32426a7..0000000 --- a/tests/test_openai_client.py +++ /dev/null @@ -1,319 +0,0 @@ -# Copyright 2025 Kaggle Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from dataclasses import dataclass - -import pytest -from pydantic import BaseModel - -from kaggle_benchmarks import actors, chats -from kaggle_benchmarks.actors.llms import LLMResponse, OpenAI -from kaggle_benchmarks.prompting import handler - - -@dataclass -class MockFunction: - name: str - arguments: str - - -@dataclass -class MockToolCall: - id: str - function: MockFunction - type: str = "function" - - -class MockedOpenAI(OpenAI): - def __init__(self, model: str, **kwargs): - super().__init__(client=None, model=model, **kwargs) - self.support_temperature = False - - def _call_api(self, messages, **kwargs): - self.messages = messages - self.kwargs = kwargs - return LLMResponse(content="{}") - - -@dataclass -class MockFunctionDelta: - name: str | None = None - arguments: str | None = None - - -@dataclass -class MockToolCallDelta: - index: int - id: str | None = None - function: MockFunctionDelta | None = None - type: str | None = "function" - - -class MockedOpenAIWithTokens(OpenAI): - def __init__(self, **kwargs): - # We pass a dummy client, as it's not used in the mocked _call_api - super().__init__(client=None, model="mock_with_tokens", **kwargs) - - def _call_api(self, messages, **kwargs): - self.messages = messages - self.kwargs = kwargs - - if self.stream_responses: - - def stream_generator(): - yield LLMResponse(content="stream", meta={"input_tokens": 10}) - yield LLMResponse( - content="ing", 
meta={"input_tokens": 10, "output_tokens": 5} - ) - - return stream_generator() - return LLMResponse( - content="non-streaming", meta={"input_tokens": 20, "output_tokens": 2} - ) - - -class MockedOpenAIWithToolCall(OpenAI): - def __init__(self, **kwargs): - super().__init__(client=None, model="mock-tool-caller", **kwargs) - - def _call_api(self, messages, **kwargs): - tool_call = MockToolCall( - id="call_123", - function=MockFunction(name="calculator", arguments='{"a": 1, "b": 2}'), - ) - return LLMResponse(content="", tool_calls=[tool_call]) - - -class MockedOpenAIWithStreamingToolCall(OpenAI): - def __init__(self, **kwargs): - super().__init__(client=None, model="mock-streaming-tool-caller", **kwargs) - - def _call_api(self, messages, **kwargs): - def stream_generator(): - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=0, - id="call_123", - function=MockFunctionDelta(name="calculator"), - ) - ], - ) - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=0, function=MockFunctionDelta(arguments='{"a": 5,') - ) - ], - ) - yield LLMResponse( - content="Okay, ", - tool_calls=[ - MockToolCallDelta( - index=0, function=MockFunctionDelta(arguments=' "b": 10}') - ) - ], - ) - yield LLMResponse(content="calculating...") - - return stream_generator() - - -class MockedOpenAIWithMultipleStreamingToolCalls(OpenAI): - def __init__(self, **kwargs): - super().__init__( - client=None, model="mock-multi-streaming-tool-caller", **kwargs - ) - - def _call_api(self, messages, **kwargs): - def stream_generator(): - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=0, - id="call_calc_123", - function=MockFunctionDelta(name="calculator"), - ) - ], - ) - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=1, - id="call_weather_456", - function=MockFunctionDelta(name="get_weather"), - ) - ], - ) - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=0, 
function=MockFunctionDelta(arguments='{"a": 100,') - ) - ], - ) - yield LLMResponse( - content="", - tool_calls=[ - MockToolCallDelta( - index=1, function=MockFunctionDelta(arguments='{"city": "NYC"}') - ) - ], - ) - yield LLMResponse( - content="Okay, ", - tool_calls=[ - MockToolCallDelta( - index=0, function=MockFunctionDelta(arguments=' "b": 200}') - ) - ], - ) - yield LLMResponse(content="processing requests...") - - return stream_generator() - - -def test_invoke(): - llm = MockedOpenAI(model="test-model") - llm.prompt("Hi") - assert llm.messages == [{"role": "user", "content": "Hi"}] - assert llm.kwargs.get("response_format") is None - - -def test_invoke_prompt(): - llm = MockedOpenAI(model="test-model") - llm.support_structured_outputs = False - - class A: - pass - - @handler(types=A) - def type_a(_): - v = yield "A" - return v - - llm.prompt("Hi", schema=A) - - assert llm.messages == [ - {"role": "user", "content": "Hi"}, - {"role": "system", "content": "A"}, - ] - assert llm.kwargs.get("response_format") is None - - -def test_pydantic_models(): - class Model(BaseModel): - a: str = "a" - b: int = 0 - - llm = MockedOpenAI(model="test-model") - llm.support_structured_outputs = True - llm.prompt("Hi", schema=Model) - assert llm.messages == [ - {"role": "user", "content": "Hi"}, - ] - assert llm.kwargs.get("response_format") is Model - - llm.support_structured_outputs = False - llm.prompt("Hi", schema=Model) - assert llm.messages[0] == {"role": "user", "content": "Hi"} - assert json.dumps(Model.model_json_schema()) in llm.messages[1]["content"] - assert llm.kwargs.get("response_format") is None - - -@pytest.mark.parametrize("streaming", [True, False]) -def test_invoke_with_token_counts(streaming): - llm = MockedOpenAIWithTokens() - llm.stream_responses = streaming - - with chats.new("Test Tokens") as t: - response_content = llm.prompt("count my tokens") - - last_message = t.messages[-1] - assert last_message.sender is llm - if streaming: - assert 
response_content == "streaming" - assert last_message._meta["input_tokens"] == 10 - assert last_message._meta["output_tokens"] == 5 - else: - assert response_content == "non-streaming" - assert last_message._meta["input_tokens"] == 20 - assert last_message._meta["output_tokens"] == 2 - - -def test_llm_extracts_tool_calls(): - llm = MockedOpenAIWithToolCall() - - with chats.new("test tools"): - actors.user.send("call a tool") - response_msg = llm.respond() - - assert response_msg.tool_calls is not None - assert len(response_msg.tool_calls) == 1 - assert response_msg.tool_calls[0].function.name == "calculator" - assert response_msg.tool_calls[0].function.arguments == '{"a": 1, "b": 2}' - - -def test_streaming_accumulates_tool_calls(): - llm = MockedOpenAIWithStreamingToolCall() - llm.stream_responses = True - - with chats.new("test streaming tools"): - actors.user.send("What is 5 + 10?") - response_msg = llm.respond() - - assert response_msg.content == "Okay, calculating..." - - final_tool_calls = response_msg.tool_calls - assert final_tool_calls is not None - assert len(final_tool_calls) == 1 - - final_call_obj = MockToolCallDelta(index=0, **final_tool_calls[0]) - final_call_obj.function = MockFunctionDelta(**final_call_obj.function) - - assert final_call_obj.id == "call_123" - assert final_call_obj.function.name == "calculator" - assert final_call_obj.function.arguments == '{"a": 5, "b": 10}' - - -def test_streaming_accumulates_multiple_tool_calls(): - llm = MockedOpenAIWithMultipleStreamingToolCalls() - llm.stream_responses = True - - with chats.new("test multi-streaming tools"): - actors.user.send("What is 100 + 200 and the weather in NYC?") - response_msg = llm.respond() - - assert response_msg.content == "Okay, processing requests..." 
- - final_tool_calls = response_msg.tool_calls - - assert final_tool_calls is not None - assert len(final_tool_calls) == 2 - - calculator_call = final_tool_calls[0] - assert calculator_call["id"] == "call_calc_123" - assert calculator_call["function"]["name"] == "calculator" - assert calculator_call["function"]["arguments"] == '{"a": 100, "b": 200}' - - weather_call = final_tool_calls[1] - assert weather_call["id"] == "call_weather_456" - assert weather_call["function"]["name"] == "get_weather" - assert weather_call["function"]["arguments"] == '{"city": "NYC"}' diff --git a/tests/tools/test_base.py b/tests/tools/test_base.py index b81efa1..1d5ccf4 100644 --- a/tests/tools/test_base.py +++ b/tests/tools/test_base.py @@ -51,10 +51,12 @@ def test_invoke_tool_success(): def test_invoke_tool_not_found(): call = ToolInvocation(name="non_existent_tool", arguments={}) result = invoke_tool(call, [simple_tool]) - assert "Error: Tool 'non_existent_tool' not found." in result.output + assert "Error: Tool 'non_existent_tool' not found." in result.describe() def test_invoke_tool_exception(): call = ToolInvocation(name="tool_that_raises", arguments={}) result = invoke_tool(call, [tool_that_raises]) - assert "Error invoking tool 'tool_that_raises': This tool failed." in result.output + assert ( + "Error invoking tool 'tool_that_raises': This tool failed." in result.describe() + ) diff --git a/tests/tools/test_simulate.py b/tests/tools/test_simulate.py new file mode 100644 index 0000000..b06bf9d --- /dev/null +++ b/tests/tools/test_simulate.py @@ -0,0 +1,219 @@ +# Copyright 2026 Kaggle Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pydantic +import pytest + +from kaggle_benchmarks import chats +from kaggle_benchmarks.tools import simulate +from tests.mocks import MockedChat + + +def dummy_tool(x: int, y: str = "default") -> str: + """A dummy tool.""" + return f"{x}-{y}" + + +def dummy_tool_2(z: float) -> float: + return z + + +def tool_no_arguments(): + pass + + +def error_tool() -> str: + raise ValueError("Simulated tool failure") + + +def test_build_response_model(): + """Tests the creation of a Pydantic response model for tool invocation.""" + model = simulate.build_response_model( + [dummy_tool, dummy_tool_2, tool_no_arguments], str + ) + assert issubclass(model, pydantic.BaseModel) + + +@pytest.mark.parametrize( + "tools_payload", + [ + [{"name": "dummy_tool", "arguments": {"x": 1, "y": "test"}}], + [{"name": "dummy_tool_2", "arguments": {"z": 3.14}}], + [ + { + "name": "dummy_tool", + "arguments": {"x": 1, "y": "test", "extra_arg": "ignored"}, + } + ], + ], +) +def test_build_response_model_valid(tools_payload): + """Tests the creation of a Pydantic response model for tool invocation.""" + model = simulate.build_response_model([dummy_tool, dummy_tool_2], str) + assert issubclass(model, pydantic.BaseModel) + + instance = model(tools=tools_payload, message=None) + assert instance.tools is not None + assert len(instance.tools) == 1 + assert instance.tools[0].name == tools_payload[0]["name"] + for k, v in instance.tools[0].arguments.items(): + assert v == tools_payload[0]["arguments"][k] + + +@pytest.mark.parametrize( + "tools_payload,expected_error", + [ + ( + [{"name": 
"non_existent_tool", "arguments": {"x": 1, "y": "test"}}], + "Input should be", + ), + ( + [{"name": "dummy_tool", "arguments": {"x": "not-an-int", "y": "test"}}], + "Input should be a valid integer", + ), + ( + [{"name": "dummy_tool", "arguments": {}}], + "Field required", + ), + ], +) +def test_build_response_model_invalid(tools_payload, expected_error): + """Tests that the response model rejects invalid tool calls.""" + model = simulate.build_response_model([dummy_tool, dummy_tool_2], str) + + with pytest.raises(pydantic.ValidationError) as exc_info: + model(tools=tools_payload, message=None) + + assert expected_error in str(exc_info.value) + + +def test_simulate_tool_calling_with_tools(): + llm = MockedChat.from_contents_data( + [ + dict( + tools=[ + { + "name": "dummy_tool", + "arguments": {"x": 42, "y": "default"}, + } + ], + message=None, + ) + ] + ) + + response = simulate.simulate_respond_with_tools( + llm=llm, tools=[dummy_tool, dummy_tool_2], output_schema=str + ) + + tool_calls = response.tool_calls or [] + assert len(tool_calls) == 1 + assert tool_calls[0].name == "dummy_tool" + assert tool_calls[0].arguments == {"x": 42, "y": "default"} + assert response.content is None + + +def test_simulate_tool_calling_with_message(): + llm = MockedChat.from_contents_data( + [{"tools": None, "message": "Here is the final answer."}] + ) + + response = simulate.simulate_respond_with_tools( + llm=llm, tools=[dummy_tool, dummy_tool_2], output_schema=str + ) + tool_calls = response.tool_calls or [] + assert not tool_calls + assert response.content == "Here is the final answer." + + +def test_simulate_agent_success(): + llm = MockedChat.from_contents_data( + [ + { + "tools": [ + { + "name": "dummy_tool", + "arguments": {"x": 10, "y": "test"}, + } + ], + "message": None, + }, + {"tools": None, "message": "Done!"}, + ] + ) + + response = simulate.simulate_agent( + llm=llm, tools=[dummy_tool, dummy_tool_2], output_schema=str + ) + + assert response.content == "Done!" 
+ assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "dummy_tool" + assert response.tool_calls[0].output == "10-test" + + +def test_simulate_agent_limit_exhausted(): + llm = MockedChat.from_contents_data( + [ + { + "tools": [ + { + "name": "dummy_tool", + "arguments": {"x": i, "y": "test"}, + } + ], + "message": None, + } + for i in range(5) + ] + ) + + with pytest.raises(simulate.ToolInvocationLimitExhausted): + simulate.simulate_agent( + llm=llm, tools=[dummy_tool], output_schema=str, max_iterations=2 + ) + + +def test_simulate_agent_without_tools(): + llm = MockedChat.from_contents(["No tools used!"]) + with chats.new("test_chat"): + response = simulate.simulate_agent(llm=llm, tools=[], output_schema=str) + assert response.content == "No tools used!" + assert not response.tool_calls + + +def test_simulate_agent_tool_error_recovery(): + # 1st turn: calls the tool that will fail + # 2nd turn: acknowledges the error and provides a final answer + llm = MockedChat.from_contents_data( + [ + { + "tools": [{"name": "error_tool", "arguments": {}}], + "message": None, + }, + {"tools": None, "message": "I recovered from the error!"}, + ] + ) + + with chats.new("test_chat"): + response = simulate.simulate_agent( + llm=llm, tools=[error_tool], output_schema=str + ) + + assert response.content == "I recovered from the error!" + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].name == "error_tool" + assert response.tool_calls[0].error is not None + assert "Simulated tool failure" in response.tool_calls[0].error