14 changes: 6 additions & 8 deletions llama_stack/testing/inference_recorder.py
@@ -7,7 +7,6 @@
 from __future__ import annotations  # for forward references
 
 import hashlib
-import inspect
 import json
 import os
 from collections.abc import Generator
@@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     global _current_mode, _current_storage
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
-        # Normal operation
-        if inspect.iscoroutinefunction(original_method):
-            return await original_method(self, *args, **kwargs)
-        else:
-            return original_method(self, *args, **kwargs)
+        if endpoint == "/v1/models":
+            return original_method(self, *args, **kwargs)
+        else:
+            return await original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
@@ -298,10 +296,10 @@ async def replay_stream():
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        if inspect.iscoroutinefunction(original_method):
-            response = await original_method(self, *args, **kwargs)
-        else:
-            response = original_method(self, *args, **kwargs)
+        if endpoint == "/v1/models":
+            response = original_method(self, *args, **kwargs)
+        else:
+            response = await original_method(self, *args, **kwargs)
 
         # we want to store the result of the iterator, not the iterator itself
         if endpoint == "/v1/models":
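A note on the change above: the recorder previously used `inspect.iscoroutinefunction` to decide whether to await the original method, and now branches on the endpoint instead, treating `/v1/models` as the one call that must not be awaited. A minimal sketch of the resulting dispatch, assuming `/v1/models` is the only non-coroutine endpoint (`_dispatch_original` is a hypothetical helper name, not from the PR):

```python
# Minimal sketch of the endpoint-based dispatch introduced above.
# Assumption: /v1/models is the only patched endpoint whose method is not a
# coroutine function (its result is returned as-is, e.g. a paginator object).
async def _dispatch_original(original_method, resource, endpoint: str, *args, **kwargs):
    if endpoint == "/v1/models":
        # Called, not awaited: the method itself is synchronous even though
        # the object it returns may be consumed asynchronously by the caller.
        return original_method(resource, *args, **kwargs)
    # chat/completions/embeddings methods are coroutines and must be awaited.
    return await original_method(resource, *args, **kwargs)
```

One practical benefit: the branch no longer depends on runtime introspection of a method that a test may have replaced with a mock.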
216 changes: 111 additions & 105 deletions tests/unit/distribution/test_inference_recordings.py
@@ -155,71 +155,61 @@ def test_response_storage(self, temp_storage_dir):

     async def test_recording_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that recording mode captures and stores responses."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_recording_mode"
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
 
-            # Verify the response was returned correctly
-            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+        # Verify the response was returned correctly
+        assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+        client.chat.completions._post.assert_called_once()
 
         # Verify recording was stored
         storage = ResponseStorage(temp_storage_dir)
         assert storage.responses_dir.exists()
 
     async def test_replay_mode(self, temp_storage_dir, real_openai_chat_response):
         """Test that replay mode returns stored responses without making real calls."""
-
-        async def mock_create(*args, **kwargs):
-            return real_openai_chat_response
-
         temp_storage_dir = temp_storage_dir / "test_replay_mode"
         # First, record a response
-        with patch(
-            "openai.resources.chat.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
+        client.chat.completions._post.assert_called_once()
 
         # Now test replay mode - should not call the original method
-        with patch("openai.resources.chat.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.chat.completions.create(
-                    model="llama3.2:3b",
-                    messages=[{"role": "user", "content": "Hello, how are you?"}],
-                    temperature=0.7,
-                    max_tokens=50,
-                )
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.chat.completions._post = AsyncMock(return_value=real_openai_chat_response)
+
+            response = await client.chat.completions.create(
+                model="llama3.2:3b",
+                messages=[{"role": "user", "content": "Hello, how are you?"}],
+                temperature=0.7,
+                max_tokens=50,
+            )
 
-            # Verify we got the recorded response
-            assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
+        # Verify we got the recorded response
+        assert response.choices[0].message.content == "Hello! I'm doing well, thank you for asking."
 
-            # Verify the original method was NOT called
-            mock_create_patch.assert_not_called()
+        # Verify the original method was NOT called
+        client.chat.completions._post.assert_not_called()
 
     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""
@@ -272,43 +262,50 @@ async def test_replay_missing_recording(self, temp_storage_dir):
     async def test_embeddings_recording(self, temp_storage_dir, real_embeddings_response):
         """Test recording and replay of embeddings calls."""
-
-        async def mock_create(*args, **kwargs):
-            return real_embeddings_response
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+        response = await client.embeddings.create(
+            model=real_embeddings_response.model,
+            input=["Hello world", "Test embedding"],
+            encoding_format=NOT_GIVEN,
+        )
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        client.embeddings._post.assert_called_once()
 
         temp_storage_dir = temp_storage_dir / "test_embeddings_recording"
         # Record
-        with patch(
-            "openai.resources.embeddings.AsyncEmbeddings.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                    encoding_format=NOT_GIVEN,
-                    dimensions=NOT_GIVEN,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+                encoding_format=NOT_GIVEN,
+                dimensions=NOT_GIVEN,
+                user=NOT_GIVEN,
+            )
 
-            assert len(response.data) == 2
+        assert len(response.data) == 2
 
         # Replay
-        with patch("openai.resources.embeddings.AsyncEmbeddings.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.embeddings.create(
-                    model=real_embeddings_response.model,
-                    input=["Hello world", "Test embedding"],
-                )
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.embeddings._post = AsyncMock(return_value=real_embeddings_response)
+
+            response = await client.embeddings.create(
+                model=real_embeddings_response.model,
+                input=["Hello world", "Test embedding"],
+            )
 
-            # Verify we got the recorded response
-            assert len(response.data) == 2
-            assert response.data[0].embedding == [0.1, 0.2, 0.3]
+        # Verify we got the recorded response
+        assert len(response.data) == 2
+        assert response.data[0].embedding == [0.1, 0.2, 0.3]
 
-            # Verify original method was not called
-            mock_create_patch.assert_not_called()
+        # Verify original method was not called
+        client.embeddings._post.assert_not_called()
 
     async def test_completions_recording(self, temp_storage_dir):
         real_completions_response = OpenAICompletion(
@@ -326,40 +323,49 @@ async def test_completions_recording(self, temp_storage_dir):
             ],
         )
 
-        async def mock_create(*args, **kwargs):
-            return real_completions_response
-
         temp_storage_dir = temp_storage_dir / "test_completions_recording"
 
+        # baseline - mock works without recording
+        client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+        client.completions._post = AsyncMock(return_value=real_completions_response)
+        response = await client.completions.create(
+            model=real_completions_response.model,
+            prompt="Hello, how are you?",
+            temperature=0.7,
+            max_tokens=50,
+            user=NOT_GIVEN,
+        )
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
+
         # Record
-        with patch(
-            "openai.resources.completions.AsyncCompletions.create", new_callable=AsyncMock, side_effect=mock_create
-        ):
-            with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                    user=NOT_GIVEN,
-                )
+        with inference_recording(mode=InferenceMode.RECORD, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+                user=NOT_GIVEN,
+            )
 
-            assert response.choices[0].text == real_completions_response.choices[0].text
+        assert response.choices[0].text == real_completions_response.choices[0].text
+        client.completions._post.assert_called_once()
 
         # Replay
-        with patch("openai.resources.completions.AsyncCompletions.create") as mock_create_patch:
-            with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
-                client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-                response = await client.completions.create(
-                    model=real_completions_response.model,
-                    prompt="Hello, how are you?",
-                    temperature=0.7,
-                    max_tokens=50,
-                )
-                assert response.choices[0].text == real_completions_response.choices[0].text
-                mock_create_patch.assert_not_called()
+        with inference_recording(mode=InferenceMode.REPLAY, storage_dir=str(temp_storage_dir)):
+            client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
+            client.completions._post = AsyncMock(return_value=real_completions_response)
+            response = await client.completions.create(
+                model=real_completions_response.model,
+                prompt="Hello, how are you?",
+                temperature=0.7,
+                max_tokens=50,
+            )
+            assert response.choices[0].text == real_completions_response.choices[0].text
+            client.completions._post.assert_not_called()
 
     async def test_live_mode(self, real_openai_chat_response):
         """Test that live mode passes through to original methods."""