1
- import os
1
+ import json
2
2
from unittest .mock import patch
3
3
4
4
import groq
5
+ import pytest
5
6
from groq .types .audio .transcription import Transcription
6
7
from groq .types .audio .translation import Translation
8
+ from groq .types .chat import ChatCompletionMessageToolCall
7
9
from groq .types .chat .chat_completion import (
8
10
ChatCompletion ,
9
11
ChatCompletionMessage ,
22
24
"messages" : [{"role" : "user" , "content" : "test message" }],
23
25
}
24
26
27
# Token/timing usage payload shared by the dummy chat-completion responses in this module.
DUMMY_COMPLETION_USAGE = CompletionUsage(
    completion_tokens=648,
    prompt_tokens=20,
    total_tokens=668,
    completion_time=0.54,
    prompt_time=0.000181289,
    queue_time=0.012770949,
    total_time=0.540181289,
)
36
+
25
37
DUMMY_CHAT_COMPLETION_RESPONSE = ChatCompletion (
26
38
id = "chatcmpl-test-id" ,
27
39
choices = [
42
54
model = "llama3-8b-8192" ,
43
55
object = "chat.completion" ,
44
56
system_fingerprint = "fp_test" ,
45
- usage = CompletionUsage (
46
- completion_tokens = 648 ,
47
- prompt_tokens = 20 ,
48
- total_tokens = 668 ,
49
- completion_time = 0.54 ,
50
- prompt_time = 0.000181289 ,
51
- queue_time = 0.012770949 ,
52
- total_time = 0.540181289 ,
53
- ),
57
+ usage = DUMMY_COMPLETION_USAGE ,
54
58
x_groq = {"id" : "req_test" },
55
59
)
56
60
57
61
58
- @patch .dict (os .environ , {"GROQ_API_KEY" : "test_key" })
59
- @patch ("groq._client.Groq.post" , return_value = DUMMY_CHAT_COMPLETION_RESPONSE )
60
- def test_chat_completion_autolog (mock_post ):
62
@pytest.fixture(autouse=True)
def init_state(monkeypatch):
    """Set a dummy ``GROQ_API_KEY`` for every test and disable Groq autologging afterwards."""
    monkeypatch.setenv("GROQ_API_KEY", "test_key")
    yield
    # Teardown: runs after each test and turns Groq autologging back off.
    mlflow.groq.autolog(disable=True)
67
+
68
+
69
+ def test_chat_completion_autolog ():
61
70
mlflow .groq .autolog ()
62
71
client = groq .Groq ()
63
- client .chat .completions .create (** DUMMY_CHAT_COMPLETION_REQUEST )
72
+
73
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_CHAT_COMPLETION_RESPONSE ):
74
+ client .chat .completions .create (** DUMMY_CHAT_COMPLETION_REQUEST )
64
75
65
76
traces = get_traces ()
66
77
assert len (traces ) == 1
@@ -74,13 +85,166 @@ def test_chat_completion_autolog(mock_post):
74
85
75
86
mlflow .groq .autolog (disable = True )
76
87
client = groq .Groq ()
77
- client .chat .completions .create (** DUMMY_CHAT_COMPLETION_REQUEST )
88
+
89
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_CHAT_COMPLETION_RESPONSE ):
90
+ client .chat .completions .create (** DUMMY_CHAT_COMPLETION_REQUEST )
78
91
79
92
# No new trace should be created
80
93
traces = get_traces ()
81
94
assert len (traces ) == 1
82
95
83
96
97
# Tool schema passed in the request: a single "calculate" function that takes one
# required string argument, "expression".
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Evaluate a mathematical expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {
                        "type": "string",
                        "description": "The mathematical expression to evaluate",
                    }
                },
                "required": ["expression"],
            },
        },
    }
]
# Chat-completion request kwargs: one user question plus the tool schema above.
DUMMY_TOOL_CALL_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "messages": [{"role": "user", "content": "What is 25 * 4 + 10?"}],
    "tools": TOOLS,
}
122
# Dummy completion whose assistant message has no text content and carries a single
# "calculate" tool call with JSON-encoded arguments.
DUMMY_TOOL_CALL_RESPONSE = ChatCompletion(
    id="chatcmpl-test-id",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            logprobs=None,
            message=ChatCompletionMessage(
                content=None,
                role="assistant",
                function_call=None,
                tool_calls=[
                    ChatCompletionMessageToolCall(
                        id="tool call id",
                        function={
                            "name": "calculate",
                            "arguments": json.dumps({"expression": "25 * 4 + 10"}),
                        },
                        type="function",
                    )
                ],
                reasoning=None,
            ),
        )
    ],
    created=1733574047,
    model="llama3-8b-8192",
    object="chat.completion",
    system_fingerprint="fp_test",
    usage=DUMMY_COMPLETION_USAGE,
    x_groq={"id": "req_test"},
)
154
+
155
+
156
def test_tool_calling_autolog():
    """Check the trace recorded for a chat completion whose response contains a tool call."""
    mlflow.groq.autolog()
    client = groq.Groq()

    # Patch the client's ``post`` so the call returns the canned response.
    with patch("groq._client.Groq.post", return_value=DUMMY_TOOL_CALL_RESPONSE):
        client.chat.completions.create(**DUMMY_TOOL_CALL_REQUEST)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 1
    span = traces[0].data.spans[0]
    assert span.name == "Completions"
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == DUMMY_TOOL_CALL_REQUEST
    assert span.outputs == DUMMY_TOOL_CALL_RESPONSE.to_dict()
    assert span.get_attribute("mlflow.chat.tools") == TOOLS
    # Expected chat messages: the request messages followed by the assistant message
    # taken from the response.
    assert span.get_attribute("mlflow.chat.messages") == [
        *DUMMY_TOOL_CALL_REQUEST["messages"],
        DUMMY_TOOL_CALL_RESPONSE.choices[0].message.to_dict(),
    ]
177
+
178
+
179
# Follow-up request kwargs: the user question, the assistant's "calculate" tool call,
# and a "tool" message carrying the JSON-encoded result.
DUMMY_TOOL_RESPONSE_REQUEST = {
    "model": "test_model",
    "max_tokens": 1024,
    "messages": [
        {"role": "user", "content": "What is 25 * 4 + 10?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "tool call id",
                    "function": {
                        "name": "calculate",
                        "arguments": json.dumps({"expression": "25 * 4 + 10"}),
                    },
                    "type": "function",
                }
            ],
        },
        {"role": "tool", "name": "calculate", "content": json.dumps({"result": 110})},
    ],
    "tools": TOOLS,
}
201
# Dummy completion returned for the follow-up request: a plain-text assistant answer
# with no tool calls.
DUMMY_TOOL_RESPONSE_RESPONSE = ChatCompletion(
    id="chatcmpl-test-id",
    choices=[
        Choice(
            finish_reason="stop",
            index=0,
            logprobs=None,
            message=ChatCompletionMessage(
                content="The result of the calculation is 110",
                role="assistant",
                function_call=None,
                reasoning=None,
                tool_calls=None,
            ),
        )
    ],
    created=1733574047,
    model="llama3-8b-8192",
    object="chat.completion",
    system_fingerprint="fp_test",
    usage=DUMMY_COMPLETION_USAGE,
    x_groq={"id": "req_test"},
)
224
+
225
+
226
def test_tool_response_autolog():
    """Check the trace recorded when the request already contains a tool call and its result."""
    mlflow.groq.autolog()
    client = groq.Groq()

    # Patch the client's ``post`` so the call returns the canned response.
    with patch("groq._client.Groq.post", return_value=DUMMY_TOOL_RESPONSE_RESPONSE):
        client.chat.completions.create(**DUMMY_TOOL_RESPONSE_REQUEST)

    traces = get_traces()
    assert len(traces) == 1
    assert traces[0].info.status == "OK"
    assert len(traces[0].data.spans) == 1
    span = traces[0].data.spans[0]
    assert span.name == "Completions"
    assert span.span_type == SpanType.CHAT_MODEL
    assert span.inputs == DUMMY_TOOL_RESPONSE_REQUEST
    assert span.outputs == DUMMY_TOOL_RESPONSE_RESPONSE.to_dict()
    # Expected chat messages: the request messages followed by the assistant message
    # taken from the response.
    assert span.get_attribute("mlflow.chat.messages") == [
        *DUMMY_TOOL_RESPONSE_REQUEST["messages"],
        DUMMY_TOOL_RESPONSE_RESPONSE.choices[0].message.to_dict(),
    ]
246
+
247
+
84
248
BINARY_CONTENT = b"\x00 \x00 \x00 \x14 ftypM4A \x00 \x00 \x00 \x00 mdat\x00 \x01 \x02 \x03 "
85
249
86
250
DUMMY_AUDIO_TRANSCRIPTION_REQUEST = {
@@ -91,12 +255,12 @@ def test_chat_completion_autolog(mock_post):
91
255
DUMMY_AUDIO_TRANSCRIPTION_RESPONSE = Transcription (text = "Test audio" , x_groq = {"id" : "req_test" })
92
256
93
257
94
- @patch .dict (os .environ , {"GROQ_API_KEY" : "test_key" })
95
- @patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSCRIPTION_RESPONSE )
96
- def test_audio_transcription_autolog (mock_post ):
258
+ def test_audio_transcription_autolog ():
97
259
mlflow .groq .autolog ()
98
260
client = groq .Groq ()
99
- client .audio .transcriptions .create (** DUMMY_AUDIO_TRANSCRIPTION_REQUEST )
261
+
262
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSCRIPTION_RESPONSE ):
263
+ client .audio .transcriptions .create (** DUMMY_AUDIO_TRANSCRIPTION_REQUEST )
100
264
101
265
traces = get_traces ()
102
266
assert len (traces ) == 1
@@ -112,7 +276,9 @@ def test_audio_transcription_autolog(mock_post):
112
276
113
277
mlflow .groq .autolog (disable = True )
114
278
client = groq .Groq ()
115
- client .audio .transcriptions .create (** DUMMY_AUDIO_TRANSCRIPTION_REQUEST )
279
+
280
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSCRIPTION_RESPONSE ):
281
+ client .audio .transcriptions .create (** DUMMY_AUDIO_TRANSCRIPTION_REQUEST )
116
282
117
283
# No new trace should be created
118
284
traces = get_traces ()
@@ -127,12 +293,12 @@ def test_audio_transcription_autolog(mock_post):
127
293
DUMMY_AUDIO_TRANSLATION_RESPONSE = Translation (text = "Test audio" , x_groq = {"id" : "req_test" })
128
294
129
295
130
- @patch .dict (os .environ , {"GROQ_API_KEY" : "test_key" })
131
- @patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSLATION_RESPONSE )
132
- def test_audio_translation_autolog (mock_post ):
296
+ def test_audio_translation_autolog ():
133
297
mlflow .groq .autolog ()
134
298
client = groq .Groq ()
135
- client .audio .translations .create (** DUMMY_AUDIO_TRANSLATION_REQUEST )
299
+
300
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSLATION_RESPONSE ):
301
+ client .audio .translations .create (** DUMMY_AUDIO_TRANSLATION_REQUEST )
136
302
137
303
traces = get_traces ()
138
304
assert len (traces ) == 1
@@ -148,7 +314,9 @@ def test_audio_translation_autolog(mock_post):
148
314
149
315
mlflow .groq .autolog (disable = True )
150
316
client = groq .Groq ()
151
- client .audio .translations .create (** DUMMY_AUDIO_TRANSLATION_REQUEST )
317
+
318
+ with patch ("groq._client.Groq.post" , return_value = DUMMY_AUDIO_TRANSLATION_RESPONSE ):
319
+ client .audio .translations .create (** DUMMY_AUDIO_TRANSLATION_REQUEST )
152
320
153
321
# No new trace should be created
154
322
traces = get_traces ()
0 commit comments