diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py
index 3cd4acf58..7fc287397 100644
--- a/verifiers/clients/openai_chat_completions_client.py
+++ b/verifiers/clients/openai_chat_completions_client.py
@@ -33,7 +33,6 @@
 from openai.types.shared_params import FunctionDefinition
 
 from verifiers.clients.client import Client
-from verifiers.clients.routed_experts import parse_routed_experts
 from verifiers.errors import (
     EmptyModelResponseError,
     InvalidModelResponseError,
@@ -58,6 +57,7 @@
     UserMessage,
 )
 from verifiers.utils.client_utils import setup_openai_client
+from verifiers.utils.response_utils import parse_routed_experts
 
 
 def handle_openai_overlong_prompt(func):
diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py
index 34e015b49..a170a5872 100644
--- a/verifiers/clients/openai_completions_client.py
+++ b/verifiers/clients/openai_completions_client.py
@@ -9,7 +9,6 @@
     get_usage_field,
     handle_openai_overlong_prompt,
 )
-from verifiers.clients.routed_experts import parse_routed_experts
 from verifiers.errors import (
     EmptyModelResponseError,
     InvalidModelResponseError,
@@ -26,6 +25,7 @@
     Usage,
 )
 from verifiers.utils.client_utils import setup_openai_client
+from verifiers.utils.response_utils import parse_routed_experts
 
 OpenAITextMessages = str
 OpenAITextResponse = Completion
diff --git a/verifiers/clients/routed_experts.py b/verifiers/clients/routed_experts.py
deleted file mode 100644
index fb92ffa04..000000000
--- a/verifiers/clients/routed_experts.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import base64
-from io import BytesIO
-from typing import Any, cast
-
-import numpy as np
-
-
-def parse_routed_experts(raw: Any) -> str | None:
-    if raw is None:
-        return None
-    return cast(str, raw)
-
-
-def truncate_routed_experts(routed_experts: str | None, seq_len: int) -> str | None:
-    if routed_experts is None:
-        return None
-
-    array = np.load(BytesIO(base64.b64decode(routed_experts)), allow_pickle=False)
-    assert array.ndim == 3
-    assert 0 <= seq_len <= array.shape[0]
-
-    buffer = BytesIO()
-    np.save(buffer, np.ascontiguousarray(array[:seq_len]), allow_pickle=False)
-    return base64.b64encode(buffer.getvalue()).decode("ascii")
diff --git a/verifiers/utils/response_utils.py b/verifiers/utils/response_utils.py
index 2ac00a83d..a6c40f1a0 100644
--- a/verifiers/utils/response_utils.py
+++ b/verifiers/utils/response_utils.py
@@ -1,4 +1,9 @@
-from verifiers.clients.routed_experts import truncate_routed_experts
+import base64
+from io import BytesIO
+from typing import Any, cast
+
+import numpy as np
+
 from verifiers.types import (
     AssistantMessage,
     Messages,
@@ -7,6 +12,44 @@
 )
 
 
+def parse_routed_experts(raw: Any) -> str | None:
+    """Return the routed-experts payload from a response as ``str | None``.
+
+    ``None`` is passed through. Any other value is returned unchanged: ``cast``
+    only informs the type checker and performs no runtime validation.
+    """
+    if raw is None:
+        return None
+    return cast(str, raw)
+
+
+def truncate_routed_experts(routed_experts: str | None, seq_len: int) -> str | None:
+    """Keep only the first ``seq_len`` entries along axis 0 of the payload.
+
+    ``routed_experts`` is a base64-encoded ``.npy`` buffer holding a 3-D array.
+    It is decoded, sliced to ``array[:seq_len]`` and re-encoded the same way.
+    ``None`` is passed through unchanged.
+
+    Raises:
+        ValueError: If the array is not 3-D, or ``seq_len`` is negative or
+            larger than the array's first dimension.
+    """
+    if routed_experts is None:
+        return None
+
+    array = np.load(BytesIO(base64.b64decode(routed_experts)), allow_pickle=False)
+    # Explicit checks rather than ``assert``: asserts are stripped under ``-O``,
+    # which would let a bad ``seq_len`` silently produce a wrong slice.
+    if array.ndim != 3:
+        raise ValueError(f"routed_experts must be a 3-D array, got {array.ndim}-D")
+    if not 0 <= seq_len <= array.shape[0]:
+        raise ValueError(f"seq_len must be in [0, {array.shape[0]}], got {seq_len}")
+
+    buffer = BytesIO()
+    np.save(buffer, np.ascontiguousarray(array[:seq_len]), allow_pickle=False)
+    return base64.b64encode(buffer.getvalue()).decode("ascii")
+
+
 async def parse_response_message(response: Response) -> Messages:
     """Parse a vf.Response into a vf.Messages list (single vf.AssistantMessage)."""
     response_message = response.message