Commit c0dee34

Implement tool standard for Groq tracing (mlflow#14632)
Signed-off-by: Tomu Hirata <[email protected]>
1 parent d910c84 commit c0dee34

File tree

4 files changed: +215, -27 lines changed

mlflow/groq/_groq_autolog.py (+20)

@@ -2,6 +2,7 @@
 
 import mlflow
 from mlflow.entities import SpanType
+from mlflow.tracing.utils import set_span_chat_messages, set_span_chat_tools
 from mlflow.utils.autologging_utils.config import AutoLoggingConfig
 
 _logger = logging.getLogger(__name__)
@@ -23,6 +24,8 @@ def _get_span_type(resource: type) -> str:
 
 
 def patched_call(original, self, *args, **kwargs):
+    from groq.types.chat.chat_completion import ChatCompletion
+
     config = AutoLoggingConfig.init(flavor_name=mlflow.groq.FLAVOR_NAME)
 
     if config.log_traces:
@@ -31,6 +34,23 @@ def patched_call(original, self, *args, **kwargs):
             span_type=_get_span_type(self.__class__),
         ) as span:
             span.set_inputs(kwargs)
+
+            if tools := kwargs.get("tools"):
+                try:
+                    set_span_chat_tools(span, tools)
+                except Exception:
+                    _logger.debug(f"Failed to set tools for {span}.", exc_info=True)
+
             outputs = original(self, *args, **kwargs)
             span.set_outputs(outputs)
+
+            if isinstance(outputs, ChatCompletion):
+                try:
+                    messages = kwargs.get("messages", [])
+                    set_span_chat_messages(
+                        span, [*messages, outputs.choices[0].message.model_dump()]
+                    )
+                except Exception:
+                    _logger.debug(f"Failed to set chat messages for {span}.", exc_info=True)
+
             return outputs
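
A minimal usage sketch of what the patched autologging above records, mirroring the tests added below. The model name and tool schema are illustrative, a valid GROQ_API_KEY is assumed, and mlflow.get_last_active_trace() is assumed to be available for pulling the resulting trace.

import groq
import mlflow

mlflow.groq.autolog()
client = groq.Groq()  # assumes GROQ_API_KEY is set in the environment

tools = [
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Evaluate a mathematical expression",
            "parameters": {
                "type": "object",
                "properties": {"expression": {"type": "string"}},
                "required": ["expression"],
            },
        },
    }
]

client.chat.completions.create(
    model="llama3-8b-8192",  # illustrative model name
    messages=[{"role": "user", "content": "What is 25 * 4 + 10?"}],
    tools=tools,
)

# The autologged CHAT_MODEL span should now carry the standardized chat attributes.
span = mlflow.get_last_active_trace().data.spans[0]
print(span.get_attribute("mlflow.chat.tools"))     # the tool definitions above
print(span.get_attribute("mlflow.chat.messages"))  # request messages + assistant reply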

mlflow/ml-package-versions.yml (+1, -1)

@@ -967,7 +967,7 @@ groq:
       pip install git+https://github.com/groq/groq-python
   autologging:
     minimum: "0.13.0"
-    maximum: "0.15.0"
+    maximum: "0.18.0"
     requirements:
     run: pytest tests/groq

mlflow/ml_package_versions.py (+1, -1)

@@ -381,7 +381,7 @@
         },
         "autologging": {
             "minimum": "0.13.0",
-            "maximum": "0.15.0"
+            "maximum": "0.18.0"
         }
     },
     "bedrock": {

tests/groq/test_groq_autolog.py (+193, -25)

@@ -1,9 +1,11 @@
-import os
+import json
 from unittest.mock import patch
 
 import groq
+import pytest
 from groq.types.audio.transcription import Transcription
 from groq.types.audio.translation import Translation
+from groq.types.chat import ChatCompletionMessageToolCall
 from groq.types.chat.chat_completion import (
     ChatCompletion,
     ChatCompletionMessage,
@@ -22,6 +24,16 @@
     "messages": [{"role": "user", "content": "test message"}],
 }
 
+DUMMY_COMPLETION_USAGE = CompletionUsage(
+    completion_tokens=648,
+    prompt_tokens=20,
+    total_tokens=668,
+    completion_time=0.54,
+    prompt_time=0.000181289,
+    queue_time=0.012770949,
+    total_time=0.540181289,
+)
+
 DUMMY_CHAT_COMPLETION_RESPONSE = ChatCompletion(
     id="chatcmpl-test-id",
     choices=[
@@ -42,25 +54,24 @@
     model="llama3-8b-8192",
     object="chat.completion",
     system_fingerprint="fp_test",
-    usage=CompletionUsage(
-        completion_tokens=648,
-        prompt_tokens=20,
-        total_tokens=668,
-        completion_time=0.54,
-        prompt_time=0.000181289,
-        queue_time=0.012770949,
-        total_time=0.540181289,
-    ),
+    usage=DUMMY_COMPLETION_USAGE,
     x_groq={"id": "req_test"},
 )
 
 
-@patch.dict(os.environ, {"GROQ_API_KEY": "test_key"})
-@patch("groq._client.Groq.post", return_value=DUMMY_CHAT_COMPLETION_RESPONSE)
-def test_chat_completion_autolog(mock_post):
+@pytest.fixture(autouse=True)
+def init_state(monkeypatch):
+    monkeypatch.setenv("GROQ_API_KEY", "test_key")
+    yield
+    mlflow.groq.autolog(disable=True)
+
+
+def test_chat_completion_autolog():
     mlflow.groq.autolog()
     client = groq.Groq()
-    client.chat.completions.create(**DUMMY_CHAT_COMPLETION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_CHAT_COMPLETION_RESPONSE):
+        client.chat.completions.create(**DUMMY_CHAT_COMPLETION_REQUEST)
 
     traces = get_traces()
     assert len(traces) == 1
@@ -74,13 +85,166 @@ def test_chat_completion_autolog(mock_post):
 
     mlflow.groq.autolog(disable=True)
     client = groq.Groq()
-    client.chat.completions.create(**DUMMY_CHAT_COMPLETION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_CHAT_COMPLETION_RESPONSE):
+        client.chat.completions.create(**DUMMY_CHAT_COMPLETION_REQUEST)
 
     # No new trace should be created
     traces = get_traces()
     assert len(traces) == 1
 
 
+TOOLS = [
+    {
+        "type": "function",
+        "function": {
+            "name": "calculate",
+            "description": "Evaluate a mathematical expression",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "expression": {
+                        "type": "string",
+                        "description": "The mathematical expression to evaluate",
+                    }
+                },
+                "required": ["expression"],
+            },
+        },
+    }
+]
+DUMMY_TOOL_CALL_REQUEST = {
+    "model": "test_model",
+    "max_tokens": 1024,
+    "messages": [{"role": "user", "content": "What is 25 * 4 + 10?"}],
+    "tools": TOOLS,
+}
+DUMMY_TOOL_CALL_RESPONSE = ChatCompletion(
+    id="chatcmpl-test-id",
+    choices=[
+        Choice(
+            finish_reason="stop",
+            index=0,
+            logprobs=None,
+            message=ChatCompletionMessage(
+                content=None,
+                role="assistant",
+                function_call=None,
+                tool_calls=[
+                    ChatCompletionMessageToolCall(
+                        id="tool call id",
+                        function={
+                            "name": "calculate",
+                            "arguments": json.dumps({"expression": "25 * 4 + 10"}),
+                        },
+                        type="function",
+                    )
+                ],
+                reasoning=None,
+            ),
+        )
+    ],
+    created=1733574047,
+    model="llama3-8b-8192",
+    object="chat.completion",
+    system_fingerprint="fp_test",
+    usage=DUMMY_COMPLETION_USAGE,
+    x_groq={"id": "req_test"},
+)
+
+
+def test_tool_calling_autolog():
+    mlflow.groq.autolog()
+    client = groq.Groq()
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_TOOL_CALL_RESPONSE):
+        client.chat.completions.create(**DUMMY_TOOL_CALL_REQUEST)
+
+    traces = get_traces()
+    assert len(traces) == 1
+    assert traces[0].info.status == "OK"
+    assert len(traces[0].data.spans) == 1
+    span = traces[0].data.spans[0]
+    assert span.name == "Completions"
+    assert span.span_type == SpanType.CHAT_MODEL
+    assert span.inputs == DUMMY_TOOL_CALL_REQUEST
+    assert span.outputs == DUMMY_TOOL_CALL_RESPONSE.to_dict()
+    assert span.get_attribute("mlflow.chat.tools") == TOOLS
+    assert span.get_attribute("mlflow.chat.messages") == [
+        *DUMMY_TOOL_CALL_REQUEST["messages"],
+        DUMMY_TOOL_CALL_RESPONSE.choices[0].message.to_dict(),
+    ]
+
+
+DUMMY_TOOL_RESPONSE_REQUEST = {
+    "model": "test_model",
+    "max_tokens": 1024,
+    "messages": [
+        {"role": "user", "content": "What is 25 * 4 + 10?"},
+        {
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "id": "tool call id",
+                    "function": {
+                        "name": "calculate",
+                        "arguments": json.dumps({"expression": "25 * 4 + 10"}),
+                    },
+                    "type": "function",
+                }
+            ],
+        },
+        {"role": "tool", "name": "calculate", "content": json.dumps({"result": 110})},
+    ],
+    "tools": TOOLS,
+}
+DUMMY_TOOL_RESPONSE_RESPONSE = ChatCompletion(
+    id="chatcmpl-test-id",
+    choices=[
+        Choice(
+            finish_reason="stop",
+            index=0,
+            logprobs=None,
+            message=ChatCompletionMessage(
+                content="The result of the calculation is 110",
+                role="assistant",
+                function_call=None,
+                reasoning=None,
+                tool_calls=None,
+            ),
+        )
+    ],
+    created=1733574047,
+    model="llama3-8b-8192",
+    object="chat.completion",
+    system_fingerprint="fp_test",
+    usage=DUMMY_COMPLETION_USAGE,
+    x_groq={"id": "req_test"},
+)
+
+
+def test_tool_response_autolog():
+    mlflow.groq.autolog()
+    client = groq.Groq()
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_TOOL_RESPONSE_RESPONSE):
+        client.chat.completions.create(**DUMMY_TOOL_RESPONSE_REQUEST)
+
+    traces = get_traces()
+    assert len(traces) == 1
+    assert traces[0].info.status == "OK"
+    assert len(traces[0].data.spans) == 1
+    span = traces[0].data.spans[0]
+    assert span.name == "Completions"
+    assert span.span_type == SpanType.CHAT_MODEL
+    assert span.inputs == DUMMY_TOOL_RESPONSE_REQUEST
+    assert span.outputs == DUMMY_TOOL_RESPONSE_RESPONSE.to_dict()
+    assert span.get_attribute("mlflow.chat.messages") == [
+        *DUMMY_TOOL_RESPONSE_REQUEST["messages"],
+        DUMMY_TOOL_RESPONSE_RESPONSE.choices[0].message.to_dict(),
+    ]
+
+
 BINARY_CONTENT = b"\x00\x00\x00\x14ftypM4A \x00\x00\x00\x00mdat\x00\x01\x02\x03"
 
 DUMMY_AUDIO_TRANSCRIPTION_REQUEST = {
@@ -91,12 +255,12 @@ def test_chat_completion_autolog(mock_post):
 DUMMY_AUDIO_TRANSCRIPTION_RESPONSE = Transcription(text="Test audio", x_groq={"id": "req_test"})
 
 
-@patch.dict(os.environ, {"GROQ_API_KEY": "test_key"})
-@patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSCRIPTION_RESPONSE)
-def test_audio_transcription_autolog(mock_post):
+def test_audio_transcription_autolog():
     mlflow.groq.autolog()
     client = groq.Groq()
-    client.audio.transcriptions.create(**DUMMY_AUDIO_TRANSCRIPTION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSCRIPTION_RESPONSE):
+        client.audio.transcriptions.create(**DUMMY_AUDIO_TRANSCRIPTION_REQUEST)
 
     traces = get_traces()
     assert len(traces) == 1
@@ -112,7 +276,9 @@ def test_audio_transcription_autolog(mock_post):
 
     mlflow.groq.autolog(disable=True)
     client = groq.Groq()
-    client.audio.transcriptions.create(**DUMMY_AUDIO_TRANSCRIPTION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSCRIPTION_RESPONSE):
+        client.audio.transcriptions.create(**DUMMY_AUDIO_TRANSCRIPTION_REQUEST)
 
     # No new trace should be created
     traces = get_traces()
@@ -127,12 +293,12 @@ def test_audio_transcription_autolog(mock_post):
 DUMMY_AUDIO_TRANSLATION_RESPONSE = Translation(text="Test audio", x_groq={"id": "req_test"})
 
 
-@patch.dict(os.environ, {"GROQ_API_KEY": "test_key"})
-@patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSLATION_RESPONSE)
-def test_audio_translation_autolog(mock_post):
+def test_audio_translation_autolog():
     mlflow.groq.autolog()
     client = groq.Groq()
-    client.audio.translations.create(**DUMMY_AUDIO_TRANSLATION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSLATION_RESPONSE):
+        client.audio.translations.create(**DUMMY_AUDIO_TRANSLATION_REQUEST)
 
     traces = get_traces()
     assert len(traces) == 1
@@ -148,7 +314,9 @@ def test_audio_translation_autolog(mock_post):
 
     mlflow.groq.autolog(disable=True)
     client = groq.Groq()
-    client.audio.translations.create(**DUMMY_AUDIO_TRANSLATION_REQUEST)
+
+    with patch("groq._client.Groq.post", return_value=DUMMY_AUDIO_TRANSLATION_RESPONSE):
+        client.audio.translations.create(**DUMMY_AUDIO_TRANSLATION_REQUEST)
 
     # No new trace should be created
     traces = get_traces()