Commit b051ae2

[Feature] Add tool calls to validate() method (#99)
1 parent bb2d31c commit b051ae2

5 files changed: 88 additions, 5 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ classifiers = [
 ]
 dependencies = [
     "cleanlab-tlm~=1.1,>=1.1.14",
-    "codex-sdk==0.1.0a23",
+    "codex-sdk==0.1.0a24",
     "pydantic>=2.0.0, <3",
 ]

src/cleanlab_codex/project.py

Lines changed: 5 additions & 3 deletions
@@ -7,7 +7,7 @@
 from typing import Dict, Optional, Union, cast

 from codex import AuthenticationError
-from codex.types.project_validate_params import Response
+from codex.types.project_validate_params import Response, Tool

 from cleanlab_codex.internal.analytics import _AnalyticsMetadata
 from cleanlab_codex.internal.sdk_client import client_from_access_key
@@ -18,7 +18,7 @@

     from codex import Codex as _Codex
     from codex.types.project_validate_response import ProjectValidateResponse
-    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam, ChatCompletionToolParam


 _ERROR_CREATE_ACCESS_KEY = (
@@ -154,6 +154,7 @@ def validate(
         context: str,
         rewritten_query: Optional[str] = None,
         metadata: Optional[object] = None,
+        tools: Optional[list[ChatCompletionToolParam]] = None,
         eval_scores: Optional[Dict[str, float]] = None,
     ) -> ProjectValidateResponse:
         """Evaluate the quality of an AI-generated `response` based on the same exact inputs that your LLM used to generate the response.
@@ -176,6 +177,7 @@
             context (str): All retrieved context (e.g., from your RAG/retrieval/search system) that was supplied as part of `messages` for generating the LLM `response`. Specifying the `context` (as a part of the full `messages` object) enables Cleanlab to run certain Evals and display the retrieved context in the Web Interface.
             rewritten_query (str, optional): An optional reformulation of `query` (e.g. to form a self-contained question out of a multi-turn conversation history) to improve retrieval quality. If you are using a query-rewriter in your RAG system, you can provide its output here. If not provided, Cleanlab may internally do its own query rewrite when necessary.
             metadata (object, optional): Arbitrary metadata to associate with this LLM `response` for logging/analytics inside the Project.
+            tools (list[ChatCompletionToolParam], optional): Optional definitions of tools that were provided to the LLM in the response-generation call. Should match the `tools` argument in OpenAI's Chat Completions API. When provided to the LLM, its response might be to call one of these tools rather than natural language.
             eval_scores (dict[str, float], optional): Pre-computed evaluation scores to bypass automatic scoring. Providing `eval_scores` for specific evaluations bypasses automated scoring and uses the supplied scores instead. If you already have them pre-computed, this can reduce runtime.

         Returns:
@@ -188,7 +190,6 @@

         When available, consider swapping your AI response with the expert answer before serving the response to your user.
         """
-
         return self._sdk_client.projects.validate(
             self._id,
             messages=messages,
@@ -197,6 +198,7 @@
             query=query,
             rewritten_question=rewritten_query,
             custom_metadata=metadata,
+            tools=[cast(Tool, tool) for tool in tools] if tools else None,
             eval_scores=eval_scores,
         )

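Note (editor's sketch, not part of the commit): a minimal usage example of the new `tools` argument, assuming you have a Codex project access key and an OpenAI API key configured. The access key, model name, tool definition, query, and retrieved context below are placeholder values, not anything taken from this diff.

from openai import OpenAI
from cleanlab_codex.project import Project

openai_client = OpenAI()
project = Project.from_access_key("<your-project-access-key>")  # placeholder access key

# Tool definitions in OpenAI Chat Completions format (placeholder tool).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location.",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

query = "What is the weather in Paris?"
context = "Retrieved docs about Paris weather..."  # from your RAG/retrieval system
messages = [
    {"role": "system", "content": f"Context: {context}"},
    {"role": "user", "content": query},
]

# Generate the LLM response with the tool definitions...
response = openai_client.chat.completions.create(model="gpt-4o-mini", messages=messages, tools=tools)

# ...then validate it, passing the same definitions through the new `tools` argument.
result = project.validate(
    messages=messages,
    response=response,
    query=query,
    context=context,
    tools=tools,
)
print(result.should_guardrail, result.escalated_to_sme, result.expert_answer)

Passing the same `tools` list to both the LLM call and `validate()` lets Cleanlab evaluate tool-call responses against the definitions the model actually saw.
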
tests/conftest.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
     openai_messages_bad_no_user,
     openai_messages_conversational,
     openai_messages_single_turn,
+    openai_tools,
 )

 __all__ = [
@@ -14,4 +15,5 @@
     "openai_messages_conversational",
     "openai_messages_single_turn",
     "openai_messages_bad_no_user",
+    "openai_tools",
 ]

tests/fixtures/validate.py

Lines changed: 20 additions & 0 deletions
@@ -6,6 +6,7 @@
     ChatCompletionAssistantMessageParam,
     ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
+    ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
 )

@@ -38,6 +39,25 @@ def openai_chat_completion() -> ChatCompletion:
     return ChatCompletion.model_validate(raw_response)


+@pytest.fixture
+def openai_tools() -> list[ChatCompletionToolParam]:
+    """Fixture that returns a list containing one static fake OpenAI Tool object."""
+    raw_tool = {
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get the current weather in a given location.",
+            "parameters": {
+                "type": "object",
+                "properties": {"location": {"type": "string", "description": "The location to get the weather for."}},
+                "required": ["location"],
+            },
+        },
+    }
+    openai_tool = cast(ChatCompletionToolParam, raw_tool)
+    return [openai_tool]
+
+
 @pytest.fixture
 def openai_messages_single_turn() -> list[ChatCompletionMessageParam]:
     """Fixture that returns a single-turn message format."""

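Note (editor's sketch, not part of the commit): the `validate()` docstring earlier in this diff mentions that a tool-enabled LLM response might be a tool call rather than natural language. A hypothetical example of such a response for the `get_weather` tool defined by this fixture, built the same way the existing `openai_chat_completion` fixture builds its fake response; the IDs, timestamp, and model name are made-up placeholders.

from openai.types.chat import ChatCompletion

# Hypothetical ChatCompletion where the model chose to call get_weather
# (finish_reason="tool_calls") instead of answering in natural language.
raw_tool_call_response = {
    "id": "chatcmpl-000",
    "object": "chat.completion",
    "created": 1700000000,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "finish_reason": "tool_calls",
            "message": {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {
                        "id": "call_000",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'},
                    }
                ],
            },
        }
    ],
}
tool_call_response = ChatCompletion.model_validate(raw_tool_call_response)

A response like this could be passed as `response=` to `Project.validate()` together with `tools=openai_tools`.
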
tests/test_project.py

Lines changed: 60 additions & 1 deletion
@@ -11,7 +11,7 @@
 )

 if TYPE_CHECKING:
-    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
+    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam, ChatCompletionToolParam

 from cleanlab_codex.project import MissingProjectError, Project

@@ -75,6 +75,7 @@ def test_project_validate_with_dict_response(
         rewritten_question=None,
         custom_metadata=None,
         eval_scores=None,
+        tools=None,
     )

     # conversational
@@ -97,6 +98,7 @@
                 rewritten_question=None,
                 custom_metadata=None,
                 eval_scores=None,
+                tools=None,
             ),
             call(
                 FAKE_PROJECT_ID,
@@ -107,12 +109,69 @@
                 rewritten_question=None,
                 custom_metadata=None,
                 eval_scores=None,
+                tools=None,
             ),
         ]
     )
     assert mock_client_from_api_key.projects.validate.call_count == 2


+def test_project_validate_with_tools(
+    mock_client_from_api_key: MagicMock,
+    openai_chat_completion: "ChatCompletion",
+    openai_messages_single_turn: list["ChatCompletionMessageParam"],
+    openai_tools: list["ChatCompletionToolParam"],
+) -> None:
+    expected_result = ProjectValidateResponse(
+        is_bad_response=True,
+        expert_answer=None,
+        eval_scores={
+            "response_helpfulness": EvalScores(
+                score=0.8,
+                triggered=True,
+                triggered_escalation=False,
+                triggered_guardrail=False,
+            )
+        },
+        escalated_to_sme=True,
+        should_guardrail=False,
+    )
+    mock_client_from_api_key.projects.validate.return_value = expected_result
+    mock_client_from_api_key.projects.create.return_value.id = FAKE_PROJECT_ID
+    mock_client_from_api_key.organization_id = FAKE_ORGANIZATION_ID
+    project = Project.create(
+        mock_client_from_api_key,
+        FAKE_ORGANIZATION_ID,
+        FAKE_PROJECT_NAME,
+        FAKE_PROJECT_DESCRIPTION,
+    )
+
+    context = "Cities in France: Paris, Lyon, Marseille"
+    query = "What is the capital of France?"
+
+    # single turn
+    result = project.validate(
+        messages=openai_messages_single_turn,
+        response=openai_chat_completion,
+        tools=openai_tools,
+        context=context,
+        query=query,
+    )
+
+    assert result == expected_result
+    mock_client_from_api_key.projects.validate.assert_called_once_with(
+        FAKE_PROJECT_ID,
+        messages=openai_messages_single_turn,
+        response=openai_chat_completion,
+        context=context,
+        query=query,
+        tools=openai_tools,
+        rewritten_question=None,
+        custom_metadata=None,
+        eval_scores=None,
+    )
+
+
 def test_from_access_key(mock_client_from_access_key: MagicMock) -> None:
     mock_client_from_access_key.projects.access_keys.retrieve_project_id.return_value = (
         AccessKeyRetrieveProjectIDResponse(
