diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index cba93b17e8..c30a82523c 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -21,6 +21,7 @@ import asyncio import inspect import json +import os import signal import socket import threading @@ -31,8 +32,6 @@ import anyio import uvicorn -from opentelemetry import trace as trace_api -from opentelemetry.sdk.trace import TracerProvider from pydantic import BaseModel from genkit._ai._embedding import EmbedderFn, EmbedderOptions, EmbedderRef, define_embedder @@ -94,6 +93,7 @@ from genkit._core._model import Document from genkit._core._plugin import Plugin from genkit._core._reflection import ReflectionServer, ServerSpec, create_reflection_asgi_app +from genkit._core._reflection_v2 import ReflectionServerV2 from genkit._core._registry import Registry from genkit._core._tracing import run_in_new_span from genkit._core._typing import ( @@ -628,9 +628,22 @@ def define_resource( # ------------------------------------------------------------------------- def _start_reflection_background(self) -> None: - """Start the Dev UI reflection server in a background daemon thread.""" + """Start the Dev UI reflection server in a background daemon thread. + + If GENKIT_REFLECTION_V2_SERVER is set (the CLI launches the runtime in + v2 mode and provides a WebSocket URL), run the v2 JSON-RPC client. + Otherwise start the v1 HTTP server. + """ async def _run_server() -> None: + v2_url = os.environ.get('GENKIT_REFLECTION_V2_SERVER') + if v2_url: + await logger.ainfo(f'Genkit Dev UI reflection v2 client connecting to {v2_url}') + server_v2 = ReflectionServerV2(self.registry, v2_url) + self._reflection_ready.set() + await server_v2.run_forever() + return + sockets: list[socket.socket] | None = None spec = self._reflection_server_spec if spec is None: @@ -668,7 +681,7 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) """Initialize the registry with default model and plugins.""" self.registry.default_model = model if model: - self.registry.register_value('defaultModel', model, model) + self.registry.register_value('defaultModel', 'defaultModel', model) for fmt in built_in_formats: self.define_format(fmt) @@ -1058,12 +1071,6 @@ def current_context() -> dict[str, Any] | None: """Get the current execution context, or None if not in an action.""" return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage] - async def flush_tracing(self) -> None: - """Flush all pending trace spans to exporters.""" - provider = trace_api.get_tracer_provider() - if isinstance(provider, TracerProvider): - await asyncio.to_thread(provider.force_flush) - async def run( self, *, diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 23057c2f01..9cb8524db3 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -409,7 +409,7 @@ async def run( input: InputT | None = None, on_chunk: Callable[[ChunkT], None] | None = None, context: dict[str, object] | None = None, - on_trace_start: Callable[[str, str], None] | None = None, + on_trace_start: Callable[[str, str], Awaitable[None]] | None = None, telemetry_labels: dict[str, object] | None = None, ) -> ActionResponse[OutputT]: """Execute the action with optional input validation. @@ -543,7 +543,7 @@ def _make_tracing_wrapper( object | None, ActionRunContext, StreamingCallback | None, - Callable[[str, str], None] | None, + Callable[[str, str], Any] | None, dict[str, object] | None, ], Awaitable[ActionResponse[Any]], @@ -565,7 +565,7 @@ async def tracing_wrapper( input: object | None, ctx: ActionRunContext, on_chunk: StreamingCallback | None, - on_trace_start: Callable[[str, str], None] | None, + on_trace_start: Callable[[str, str], Awaitable[None]] | None, telemetry_labels: dict[str, object] | None, ) -> ActionResponse[Any]: start_time = time.perf_counter() @@ -579,7 +579,7 @@ async def tracing_wrapper( trace_id = format(span.get_span_context().trace_id, '032x') span_id = format(span.get_span_context().span_id, '016x') if on_trace_start: - on_trace_start(trace_id, span_id) + await on_trace_start(trace_id, span_id) # Set telemetry labels as direct span attributes (matches JS/Go behavior) if telemetry_labels: diff --git a/py/packages/genkit/src/genkit/_core/_reflection.py b/py/packages/genkit/src/genkit/_core/_reflection.py index 130f3ad8f8..76b68aefd7 100644 --- a/py/packages/genkit/src/genkit/_core/_reflection.py +++ b/py/packages/genkit/src/genkit/_core/_reflection.py @@ -72,7 +72,7 @@ class ActionRunner: trace_id: str | None = None span_id: str | None = None - def on_trace_start(self, tid: str, sid: str) -> None: + async def on_trace_start(self, tid: str, sid: str) -> None: self.trace_id, self.span_id = tid, sid if task := asyncio.current_task(): self.active_actions[tid] = task diff --git a/py/packages/genkit/src/genkit/_core/_reflection_v2.py b/py/packages/genkit/src/genkit/_core/_reflection_v2.py new file mode 100644 index 0000000000..f6310c341f --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_reflection_v2.py @@ -0,0 +1,543 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Reflection API v2 (WebSocket JSON-RPC client) for Genkit Dev UI / CLI. + +``runAction`` with ``stream: true`` emits ``streamChunk`` notifications (output streaming). +Bidirectional input streaming (``sendInputStreamChunk`` / ``endInputStream``) is not +implemented yet. Requests with an ``id`` receive JSON-RPC ``-32000`` with message +``Not implemented`` and ``error.data.stack`` (same pattern as JS ``throw`` in the handler). +Notifications without ``id`` are ignored except for a debug log. +""" + +from __future__ import annotations + +import asyncio +import json +import os +import traceback +from typing import Any + +import websockets +from opentelemetry import trace as trace_api +from opentelemetry.sdk.trace import TracerProvider +from pydantic import BaseModel, JsonValue, ValidationError +from websockets.exceptions import ConnectionClosed + +from genkit._core._constants import GENKIT_VERSION +from genkit._core._error import ReflectionError, ReflectionErrorDetails, StatusCodes, get_reflection_json +from genkit._core._logger import get_logger +from genkit._core._registry import Registry +from genkit._core._trace._default_exporter import TraceServerExporter +from genkit._core._tracing import add_custom_exporter +from genkit._core._typing import ( + ReflectionCancelActionParams, + ReflectionCancelActionResponse, + ReflectionConfigureParams, + ReflectionListValuesParams, + ReflectionRegisterParams, + ReflectionRunActionParams, + ReflectionRunActionStateParams, + ReflectionStreamChunkParams, + State, +) + +logger = get_logger(__name__) + +GENKIT_REFLECTION_API_SPEC_VERSION = 1 + +JSON_RPC_METHOD_NOT_FOUND = -32601 +JSON_RPC_INVALID_PARAMS = -32602 +JSON_RPC_SERVER_ERROR = -32000 + +RECONNECT_BASE_DELAY_S = 0.5 +RECONNECT_MAX_DELAY_S = 5.0 + +WRITE_TIMEOUT_S = 5.0 + + +def _coerce_json_rpc_message(message: object) -> str: + """JSON-RPC and RuntimeManagerV2 require ``error.message`` to be a string.""" + if isinstance(message, str): + return message + if message is None: + return 'Unknown error' + try: + return json.dumps(message, default=str) + except TypeError: + return str(message) + + +class JsonRpcCallError(Exception): + """Error returned in a JSON-RPC response for a request we originated.""" + + def __init__(self, code: int, message: str, data: object | None = None) -> None: + self.code = code + self.message = message + self.data = data + super().__init__(f'JSON-RPC error {code}: {message}') + + +def _chunk_for_json(chunk: object) -> object: + if isinstance(chunk, BaseModel): + return json.loads(chunk.model_dump_json()) + return chunk + + +def _omit_none(payload: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in payload.items() if v is not None} + + +class ReflectionServerV2: + """WebSocket client that connects to the CLI reflection manager (RuntimeManagerV2). + + See module docstring for streaming support scope. + """ + + def __init__( + self, + registry: Registry, + ws_url: str, + *, + app_name: str | None = None, + ) -> None: + self._registry = registry + self._ws_url = ws_url + self._app_name = app_name + self._ws: Any = None + self._write_lock = asyncio.Lock() + self._pending: dict[str, asyncio.Future[JsonValue]] = {} + self._request_seq = 0 + self._active_actions: dict[str, asyncio.Task[Any]] = {} + self._stop = False + self._reflection_handshake_telemetry_applied = False + + def _apply_handshake_telemetry(self, url: str | None) -> None: + """Use the Dev UI trace server URL from the reflection handshake. + + The CLI manager returns ``telemetryServerUrl`` on ``register`` and may send it + again on ``configure``. We need that base URL so OpenTelemetry spans can be + POSTed to ``{url}/api/traces`` (see ``TraceServerExporter``). + """ + if not url or os.environ.get('GENKIT_TELEMETRY_SERVER'): + return + if self._reflection_handshake_telemetry_applied: + return + self._reflection_handshake_telemetry_applied = True + # Register HTTP export to this URL on the global OTel provider. + add_custom_exporter(TraceServerExporter(telemetry_server_url=url), 'reflection_v2_telemetry') + logger.debug('reflection V2: connected to telemetry server', url=url) + + async def run_forever(self) -> None: + """Connect, handle requests, reconnect with backoff until stop() or process exit.""" + attempt = 0 + while not self._stop: + try: + async with websockets.connect( + self._ws_url, + ping_interval=20, + ping_timeout=20, + ) as ws: + self._ws = ws + attempt = 0 + _ = asyncio.create_task(self._register()) + await self._read_loop() + except ConnectionClosed as e: + logger.debug('reflection V2: connection closed', code=e.code, reason=e.reason) + except OSError as e: + logger.debug('reflection V2: connection error', err=e) + finally: + self._ws = None + self._drain_pending(ConnectionError('connection closed')) + + if self._stop: + return + + delay = min(RECONNECT_BASE_DELAY_S * (2**attempt), RECONNECT_MAX_DELAY_S) + attempt += 1 + logger.debug('reflection V2: reconnect scheduled', delay_s=delay, attempt=attempt) + await asyncio.sleep(delay) + + def stop(self) -> None: + self._stop = True + + def _drain_pending(self, exc: BaseException) -> None: + for _rid, fut in list(self._pending.items()): + if not fut.done(): + fut.set_exception(exc) + self._pending.clear() + + async def _send_message(self, message: dict[str, Any]) -> None: + if self._ws is None: + raise ConnectionError('websocket not connected') + raw = json.dumps(message, default=str) + async with self._write_lock: + await asyncio.wait_for(self._ws.send(raw), timeout=WRITE_TIMEOUT_S) + + async def _send_response(self, req_id: str, result: object) -> None: + await self._send_message({'jsonrpc': '2.0', 'result': result, 'id': req_id}) + + async def _send_error( + self, + req_id: str, + code: int, + message: object, + data: object | None = None, + ) -> None: + """Emit a JSON-RPC error.""" + err: dict[str, Any] = {'code': code, 'message': _coerce_json_rpc_message(message)} + if data is not None: + err['data'] = data + await self._send_message({'jsonrpc': '2.0', 'error': err, 'id': req_id}) + + async def _send_notification(self, method: str, params: object) -> None: + await self._send_message({'jsonrpc': '2.0', 'method': method, 'params': params}) + + async def _send_request(self, method: str, params: object) -> JsonValue: + self._request_seq += 1 + req_id = str(self._request_seq) + loop = asyncio.get_running_loop() + fut: asyncio.Future[JsonValue] = loop.create_future() + self._pending[req_id] = fut + try: + await self._send_message({'jsonrpc': '2.0', 'id': req_id, 'method': method, 'params': params}) + return await fut + finally: + self._pending.pop(req_id, None) + + async def _register(self) -> None: + runtime_id = os.environ.get('GENKIT_RUNTIME_ID') or str(os.getpid()) + name = self._app_name or runtime_id + params = ReflectionRegisterParams( + id=runtime_id, + pid=float(os.getpid()), + name=name, + genkit_version='py/' + GENKIT_VERSION, + reflection_api_spec_version=float(GENKIT_REFLECTION_API_SPEC_VERSION), + envs=['dev'], + ).model_dump(by_alias=True, exclude_none=True) + try: + result = await self._send_request('register', params) + if isinstance(result, dict) and (telemetry_url := result.get('telemetryServerUrl')): + self._apply_handshake_telemetry(str(telemetry_url)) + except JsonRpcCallError as e: + logger.error('reflection V2: register failed', code=e.code, message=e.message) + except Exception as e: + logger.error('reflection V2: register failed', err=e) + + async def _read_loop(self) -> None: + assert self._ws is not None + async for raw in self._ws: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.debug('reflection V2: invalid JSON from manager') + continue + if not isinstance(msg, dict): + logger.debug('reflection V2: ignoring JSON value that is not an object', type=type(msg).__name__) + continue + if msg.get('jsonrpc') != '2.0': + logger.debug( + 'reflection V2: ignoring frame without jsonrpc 2.0', + jsonrpc=msg.get('jsonrpc'), + ) + continue + if 'method' in msg: + _ = asyncio.create_task(self._dispatch_incoming(msg)) + elif msg.get('id') is not None: + self._deliver_response(msg) + else: + logger.debug( + 'reflection V2: ignoring JSON-RPC 2.0 object without method or id', + keys=list(msg.keys()), + ) + + def _deliver_response(self, msg: dict[str, Any]) -> None: + req_id = msg.get('id') + if req_id is None: + return + sid = str(req_id) + fut = self._pending.pop(sid, None) + if fut is None: + logger.debug('reflection V2: response for unknown id', id=sid) + return + if err := msg.get('error'): + fut.set_exception( + JsonRpcCallError( + int(err.get('code', JSON_RPC_SERVER_ERROR)), + str(err.get('message', '')), + err.get('data'), + ) + ) + else: + fut.set_result(msg.get('result')) + + async def _dispatch_incoming(self, msg: dict[str, Any]) -> None: + method = msg.get('method') + req_id = msg.get('id') + params = msg.get('params') or {} + if not isinstance(params, dict): + if req_id is not None: + await self._send_error( + str(req_id), + JSON_RPC_INVALID_PARAMS, + 'params must be a JSON object', + ) + return + try: + if method == 'listActions': + await self._handle_list_actions(req_id, params) + elif method == 'listValues': + await self._handle_list_values(req_id, params) + elif method == 'runAction': + await self._handle_run_action(req_id, params) + elif method == 'cancelAction': + await self._handle_cancel_action(req_id, params) + elif method == 'configure': + self._handle_configure(params) + elif method in ('sendInputStreamChunk', 'endInputStream'): + await self._handle_input_stream_unimplemented(req_id, method) + else: + if req_id is not None: + await self._send_error( + str(req_id), + JSON_RPC_METHOD_NOT_FOUND, + f'method not found: {method}', + ) + else: + logger.debug('reflection V2: unknown notification', method=method) + except Exception: + logger.exception('reflection V2: handler error', method=method) + if req_id is not None: + await self._send_error(str(req_id), JSON_RPC_SERVER_ERROR, 'internal error') + + async def _handle_input_stream_unimplemented(self, req_id: str | int | None, method: str) -> None: + if req_id is None: + logger.debug('reflection V2: input stream method not implemented (notification)', method=method) + return + try: + raise NotImplementedError('Not implemented') + except NotImplementedError as e: + stack = ''.join(traceback.format_exception(type(e), e, e.__traceback__)) + await self._send_error( + str(req_id), + JSON_RPC_SERVER_ERROR, + str(e) or 'Not implemented', + {'stack': stack}, + ) + + async def _handle_list_actions(self, req_id: str | int | None, _: dict[str, Any]) -> None: + if req_id is None: + return + sid = str(req_id) + catalog = await self._registry.list_actions() + actions = { + key: _omit_none({ + 'key': key, + 'name': meta.name, + 'actionType': meta.action_type, + 'description': meta.description, + 'metadata': meta.metadata, + 'inputSchema': meta.input_schema or meta.input_json_schema, + 'outputSchema': meta.output_schema or meta.output_json_schema, + }) + for key, meta in catalog.items() + } + await self._send_response(sid, {'actions': actions}) + + async def _handle_list_values(self, req_id: str | int | None, params: dict[str, Any]) -> None: + if req_id is None: + return + sid = str(req_id) + try: + p = ReflectionListValuesParams.model_validate(params) + except ValidationError as e: + await self._send_error(sid, JSON_RPC_INVALID_PARAMS, f'invalid params: {e}') + return + if p.type not in ('defaultModel', 'middleware'): + await self._send_error( + sid, + JSON_RPC_INVALID_PARAMS, + f"'type' {p.type} is not supported. Only 'defaultModel' and 'middleware' are supported", + ) + return + mapped: dict[str, Any] = {} + for name in self._registry.list_values(p.type): + value = self._registry.lookup_value(p.type, name) + to_json_fn = getattr(value, 'to_json', None) if value is not None else None + if callable(to_json_fn): + mapped[name] = to_json_fn() + else: + mapped[name] = value + await self._send_response(sid, {'values': mapped}) + + def _handle_configure(self, params: dict[str, Any]) -> None: + try: + p = ReflectionConfigureParams.model_validate(params) + except ValidationError as e: + logger.error('reflection V2: invalid configure params', err=e) + return + if p.telemetry_server_url: + self._apply_handshake_telemetry(p.telemetry_server_url) + + async def _handle_cancel_action(self, req_id: str | int | None, params: dict[str, Any]) -> None: + if req_id is None: + return + sid = str(req_id) + try: + p = ReflectionCancelActionParams.model_validate(params) + except ValidationError as e: + await self._send_error(sid, JSON_RPC_INVALID_PARAMS, f'invalid params: {e}') + return + if not p.trace_id: + await self._send_error(sid, JSON_RPC_INVALID_PARAMS, 'traceId is required') + return + task = self._active_actions.get(p.trace_id) + if task: + task.cancel() + self._active_actions.pop(p.trace_id, None) + body = ReflectionCancelActionResponse(message='Action cancelled').model_dump(by_alias=True) + await self._send_response(sid, body) + else: + await self._send_error( + sid, + JSON_RPC_INVALID_PARAMS, + 'Action not found or already completed', + ) + + async def _flush_tracing(self) -> None: + provider = trace_api.get_tracer_provider() + if isinstance(provider, TracerProvider): + await asyncio.to_thread(provider.force_flush) + + async def _handle_run_action(self, req_id: str | int | None, params: dict[str, Any]) -> None: + if req_id is None: + return + sid = str(req_id) + try: + p = ReflectionRunActionParams.model_validate(params) + except ValidationError as e: + await self._send_error(sid, JSON_RPC_INVALID_PARAMS, f'invalid params: {e}') + return + + action = await self._registry.resolve_action_by_key(p.key) + if not action: + await self._send_error(sid, JSON_RPC_INVALID_PARAMS, f'action {p.key} not found') + return + + if p.context is not None and not isinstance(p.context, dict): + await self._send_error( + sid, + JSON_RPC_INVALID_PARAMS, + 'context must be a JSON object when provided', + ) + return + + stream = bool(p.stream) + trace_holder: list[str | None] = [None] + stream_chunk_tasks: list[asyncio.Task[Any]] = [] + + async def on_trace_start(tid: str, span_id: str) -> None: + trace_holder[0] = tid + if t := asyncio.current_task(): + self._active_actions[tid] = t + st = ReflectionRunActionStateParams( + request_id=sid, + state=State(trace_id=tid), + ).model_dump(by_alias=True, exclude_none=True) + await self._send_notification('runActionState', st) + + on_chunk = None + if stream: + + def on_chunk_fn(chunk: object) -> None: + chunk_payload = ReflectionStreamChunkParams( + request_id=sid, + chunk=_chunk_for_json(chunk), + ).model_dump(by_alias=True, exclude_none=True) + stream_chunk_tasks.append(asyncio.create_task(self._send_notification('streamChunk', chunk_payload))) + + on_chunk = on_chunk_fn + + ctx: dict[str, object] = {} if p.context is None else {str(k): v for k, v in p.context.items()} + + labels: dict[str, object] | None = None + if p.telemetry_labels is not None: + labels = {str(k): v for k, v in p.telemetry_labels.items()} + + try: + output = await action.run( + input=p.input, + on_chunk=on_chunk, + context=ctx or None, + on_trace_start=on_trace_start, + telemetry_labels=labels, + ) + if stream_chunk_tasks: + await asyncio.gather(*stream_chunk_tasks) + await self._flush_tracing() + result_body: object + if isinstance(output.response, BaseModel): + result_body = output.response.model_dump(by_alias=True, exclude_none=True) + else: + result_body = output.response + # Omit telemetry or traceId when absent — Dev UI parses with Zod; null traceId fails + # z.string().optional() and would surface as HTTP 500 with an empty error body. + success_body: dict[str, Any] = {'result': result_body} + if output.trace_id: + success_body['telemetry'] = {'traceId': output.trace_id} + await self._send_response(sid, success_body) + except asyncio.CancelledError: + err_details: dict[str, Any] = {} + if trace_holder[0]: + err_details['traceId'] = trace_holder[0] + err_data: dict[str, Any] = { + 'code': StatusCodes.CANCELLED.value, + 'message': 'Action was cancelled', + } + if err_details: + err_data['details'] = err_details + await self._send_error(sid, JSON_RPC_SERVER_ERROR, 'Action was cancelled', err_data) + return + except Exception as e: + logger.exception('reflection V2: runAction error') + # Wire contract requires ``details`` to carry only ``stack`` and ``traceId`` + # (see ``GenkitErrorSchema.data.genkitErrorDetails`` in genkit-tools); anything + # else in ``GenkitError.details`` is runtime-internal and gets dropped. + # + # ``stack``: prefer the value the error already carries (set by ``GenkitError`` + # and copied through by ``get_reflection_json``); fall back to formatting the + # live traceback so plain Python exceptions still surface a useful frame. + ref = get_reflection_json(e) + stack = ref.details.stack if ref.details else None + if not stack and e.__traceback__: + stack = ''.join(traceback.format_exception(type(e), e, e.__traceback__)) + tid = trace_holder[0] or (ref.details.trace_id if ref.details else None) + status = ReflectionError( + code=ref.code, + message=_coerce_json_rpc_message(ref.message), + details=ReflectionErrorDetails(stack=stack, trace_id=tid) if (stack or tid) else None, + ) + await self._send_error( + sid, + JSON_RPC_SERVER_ERROR, + status.message, + status.model_dump(by_alias=True, exclude_none=True), + ) + finally: + tid = trace_holder[0] + if tid: + self._active_actions.pop(tid, None) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 16cafd092e..8195652f71 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -851,10 +851,7 @@ class Values(GenkitModel): model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) -class TelemetryLabels(GenkitModel): - """Model for telemetrylabels data.""" - - model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) +TelemetryLabels = dict[str, str] # type alias for telemetrylabels (typed string map) class State(GenkitModel): diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py index 6b6cbe0474..ee3ae2d601 100644 --- a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -9,8 +9,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from opentelemetry import trace as trace_api -from opentelemetry.sdk.trace import TracerProvider from genkit import Genkit from genkit._core._action import _action_context @@ -100,16 +98,3 @@ async def test_current_context() -> None: _action_context.reset(token) assert Genkit.current_context() is None - - -@pytest.mark.asyncio -async def test_flush_tracing() -> None: - """Test Genkit.flush_tracing method.""" - ai = Genkit() - - mock_provider = MagicMock(spec=TracerProvider) - mock_provider.force_flush = MagicMock() - - with mock.patch.object(trace_api, 'get_tracer_provider', return_value=mock_provider): - await ai.flush_tracing() - mock_provider.force_flush.assert_called_once() diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index 060d05f13f..116fec1b21 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -145,11 +145,11 @@ async def side_effect( input: object = None, on_chunk: object | None = None, context: object | None = None, - on_trace_start: Callable[[str, str], None] | None = None, + on_trace_start: Callable[[str, str], Awaitable[None]] | None = None, **kwargs: Any, # noqa: ANN401 ) -> MagicMock: if on_trace_start: - on_trace_start('test_trace_id', 'test_span_id') + await on_trace_start('test_trace_id', 'test_span_id') return mock_output mock_action.run.side_effect = side_effect @@ -224,11 +224,11 @@ async def mock_streaming( input: object = None, on_chunk: object | None = None, context: object | None = None, - on_trace_start: Callable[[str, str], None] | None = None, + on_trace_start: Callable[[str, str], Awaitable[None]] | None = None, **kwargs: Any, # noqa: ANN401 ) -> MagicMock: if on_trace_start: - on_trace_start('stream_trace_id', 'stream_span_id') + await on_trace_start('stream_trace_id', 'stream_span_id') if on_chunk: on_chunk_fn = cast(Callable[[object], Awaitable[None]], on_chunk) await on_chunk_fn({'chunk': 1}) @@ -277,11 +277,11 @@ async def mock_streaming( input: object = None, on_chunk: object | None = None, context: object | None = None, - on_trace_start: Callable[[str, str], None] | None = None, + on_trace_start: Callable[[str, str], Awaitable[None]] | None = None, **kwargs: Any, # noqa: ANN401 ) -> MagicMock: if on_trace_start: - on_trace_start('stream_trace_id', 'stream_span_id') + await on_trace_start('stream_trace_id', 'stream_span_id') if on_chunk: on_chunk_fn = cast(Callable[[object], None], on_chunk) for chunk in chunks: diff --git a/py/packages/genkit/tests/genkit/core/reflection_v2_test.py b/py/packages/genkit/tests/genkit/core/reflection_v2_test.py new file mode 100644 index 0000000000..68e2cf4184 --- /dev/null +++ b/py/packages/genkit/tests/genkit/core/reflection_v2_test.py @@ -0,0 +1,532 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for Reflection API v2 (WebSocket JSON-RPC client). + +Design notes: + +- **fakeManager pattern**: A minimal in-process WebSocket *server* stands in for + the CLI ``RuntimeManagerV2``. The runtime under test is the *client*. This + isolates protocol handling without the full tools server or Dev UI. +- **Explicit JSON-RPC sequencing**: Tests ``read`` the next frame, assert + ``method`` / ``id`` / ``params``, then ``write`` responses. This catches + wrong ordering (e.g. ``register`` vs first ``listActions``) deterministically. +- **ackRegister helper**: The runtime sends ``register`` and awaits a result; + most tests must reply with a minimal ``result`` so the client does not stall. +- **Draining notifications**: ``runAction`` may emit ``runActionState`` frames + before the final ``result`` or ``error``; tests loop until they see the + response shape they need rather than asserting on the very next frame. +- **Parallel failure modes**: ``cancelAction`` tests assert on *two* correlated + replies (cancel ack + runAction error) without assuming order. +""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any + +import pytest +import pytest_asyncio +from websockets.asyncio.server import serve + +from genkit._core._action import Action, ActionKind, ActionRunContext +from genkit._core._reflection_v2 import ( + JSON_RPC_INVALID_PARAMS, + JSON_RPC_METHOD_NOT_FOUND, + JSON_RPC_SERVER_ERROR, + ReflectionServerV2, +) +from genkit._core._registry import Registry + + +class FakeReflectionManager: + """Minimal WebSocket server that accepts one runtime client (CLI stand-in).""" + + def __init__(self) -> None: + self._stop = asyncio.Event() + self._client_ws: Any = None + self._server: Any = None + self._serve_ctx: Any = None + self._host = '127.0.0.1' + self._port = 0 + self._ready: asyncio.Future[None] | None = None + + @property + def url(self) -> str: + return f'ws://{self._host}:{self._port}' + + async def _handler(self, ws: Any) -> None: + self._client_ws = ws + if self._ready is not None and not self._ready.done(): + self._ready.set_result(None) + await self._stop.wait() + + async def start(self) -> None: + self._ready = asyncio.get_running_loop().create_future() + self._serve_ctx = serve(self._handler, self._host, 0) + self._server = await self._serve_ctx.__aenter__() + first_socket = next(iter(self._server.sockets)) + self._port = first_socket.getsockname()[1] + + async def aclose(self) -> None: + self._stop.set() + if self._client_ws is not None: + await self._client_ws.close() + if self._serve_ctx is not None: + await self._serve_ctx.__aexit__(None, None, None) + + async def wait_connected(self, timeout: float = 2.0) -> None: + assert self._ready is not None + await asyncio.wait_for(self._ready, timeout=timeout) + + async def read_rpc(self, timeout: float = 2.0) -> dict[str, Any]: + assert self._client_ws is not None + raw = await asyncio.wait_for(self._client_ws.recv(), timeout=timeout) + return json.loads(raw) + + async def write_rpc(self, msg: dict[str, Any]) -> None: + assert self._client_ws is not None + await self._client_ws.send(json.dumps(msg)) + + +async def ack_register(fm: FakeReflectionManager) -> dict[str, Any]: + msg = await fm.read_rpc() + assert msg.get('method') == 'register' + req_id = msg['id'] + assert isinstance(req_id, str) and req_id != '' + await fm.write_rpc({'jsonrpc': '2.0', 'result': {}, 'id': req_id}) + return msg + + +@pytest_asyncio.fixture(loop_scope='function') +async def fake_manager() -> Any: + fm = FakeReflectionManager() + await fm.start() + try: + yield fm + finally: + await fm.aclose() + + +async def _run_client_lifecycle( + registry: Registry, + fm: FakeReflectionManager, + *, + app_name: str = 'test-app', +) -> tuple[ReflectionServerV2, asyncio.Task[None]]: + client = ReflectionServerV2(registry, fm.url, app_name=app_name) + task = asyncio.create_task(client.run_forever()) + await fm.wait_connected() + await asyncio.sleep(0) # let register task schedule + return client, task + + +async def _stop_client(client: ReflectionServerV2, task: asyncio.Task[None]) -> None: + client.stop() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_reflection_server_v2_register(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + msg = await fake_manager.read_rpc() + assert msg.get('method') == 'register' + assert isinstance(msg.get('id'), str) + params = msg.get('params') + assert isinstance(params, dict) + assert params.get('name') == 'test-app' + assert params.get('id') + assert isinstance(params.get('pid'), (int, float)) + assert str(params.get('genkitVersion', '')).startswith('py/') + assert isinstance(params.get('reflectionApiSpecVersion'), (int, float)) + envs = params.get('envs') + assert isinstance(envs, list) and envs == ['dev'] + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_register_handshake_telemetry(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + msg = await fake_manager.read_rpc() + assert msg.get('method') == 'register' + req_id = msg['id'] + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'result': {'telemetryServerUrl': 'http://127.0.0.1:9999'}, + 'id': req_id, + }) + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_actions(fake_manager: FakeReflectionManager) -> None: + """listActions returns the same action map as HTTP reflection (:func:`_get_actions_payload`).""" + registry = Registry() + + async def inc(x: int) -> int: + return x + 1 + + registry.register_action_from_instance(Action(ActionKind.CUSTOM, 'test/inc', inc)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listActions', + 'id': '1', + }) + resp = await fake_manager.read_rpc() + assert resp.get('id') == '1' + result = resp.get('result') + assert isinstance(result, dict) + actions = result.get('actions') + assert isinstance(actions, dict) + assert actions == { + '/custom/test/inc': { + 'key': '/custom/test/inc', + 'name': 'test/inc', + 'actionType': 'custom', + 'inputSchema': {'type': 'integer'}, + 'outputSchema': {'type': 'integer'}, + 'metadata': { + 'inputSchema': {'type': 'integer'}, + 'outputSchema': {'type': 'integer'}, + }, + } + } + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + registry.register_value('defaultModel', 'defaultModel', 'my-model') + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'defaultModel'}, + 'id': '2', + }) + resp = await fake_manager.read_rpc() + assert resp.get('id') == '2' + result = resp.get('result') + assert isinstance(result, dict) + values = result.get('values') + assert isinstance(values, dict) + assert values.get('defaultModel') == 'my-model' + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values_rejects_unsupported_type( + fake_manager: FakeReflectionManager, +) -> None: + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'prompt'}, + 'id': '2a', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') + assert isinstance(err, dict) + assert err.get('code') == JSON_RPC_INVALID_PARAMS + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_run_action(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + + async def inc(x: int) -> int: + return x + 1 + + registry.register_action_from_instance(Action(ActionKind.CUSTOM, 'test/inc', inc)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'runAction', + 'params': {'key': '/custom/test/inc', 'input': 3}, + 'id': '3', + }) + resp: dict[str, Any] | None = None + while resp is None: + msg = await fake_manager.read_rpc() + if msg.get('method') == 'runActionState': + continue + resp = msg + assert resp.get('id') == '3' + assert resp.get('error') is None + result = resp.get('result') + assert isinstance(result, dict) + assert result.get('result') == 4 + telemetry = result.get('telemetry') + assert isinstance(telemetry, dict) + assert telemetry.get('traceId') + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_streaming_run_action(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + + async def stream_inc(x: int, ctx: ActionRunContext) -> int: + for i in range(x): + ctx.send_chunk(i) + return x + + registry.register_action_from_instance(Action(ActionKind.CUSTOM, 'test/streaming', stream_inc)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'runAction', + 'params': {'key': '/custom/test/streaming', 'input': 3, 'stream': True}, + 'id': '4', + }) + chunks: list[Any] = [] + final: dict[str, Any] | None = None + while final is None: + msg = await fake_manager.read_rpc() + if msg.get('method') == 'streamChunk': + params = msg.get('params') + assert isinstance(params, dict) + assert params.get('requestId') == '4' + chunks.append(params.get('chunk')) + continue + if msg.get('method') == 'runActionState': + continue + final = msg + assert len(chunks) == 3 + for i, c in enumerate(chunks): + assert c == i + assert final is not None + result = final.get('result') + assert isinstance(result, dict) + assert result.get('result') == 3 + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_run_action_not_found(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'runAction', + 'params': {'key': '/custom/does-not-exist', 'input': None}, + 'id': '5', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') + assert isinstance(err, dict) + assert err.get('code') == JSON_RPC_INVALID_PARAMS + assert 'not found' in str(err.get('message', '')).lower() + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_cancel_action(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + started = asyncio.Event() + + async def slow(_: Any = None) -> Any: + started.set() + await asyncio.sleep(10**6) + + registry.register_action_from_instance(Action(ActionKind.CUSTOM, 'test/slow', slow)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'runAction', + 'params': {'key': '/custom/test/slow', 'input': None}, + 'id': '6', + }) + await asyncio.wait_for(started.wait(), timeout=2.0) + trace_id = '' + while not trace_id: + msg = await fake_manager.read_rpc() + if msg.get('method') == 'runActionState': + params = msg.get('params') + assert isinstance(params, dict) + state = params.get('state') + assert isinstance(state, dict) + tid = state.get('traceId') + if tid: + trace_id = str(tid) + + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'cancelAction', + 'params': {'traceId': trace_id}, + 'id': '7', + }) + + saw_cancel = False + saw_run_err = False + while not saw_cancel or not saw_run_err: + msg = await fake_manager.read_rpc() + mid = msg.get('id') + if mid == '7': + result = msg.get('result') + assert isinstance(result, dict) + assert result.get('message') == 'Action cancelled' + saw_cancel = True + elif mid == '6': + err = msg.get('error') + assert isinstance(err, dict) + assert 'cancel' in str(err.get('message', '')).lower() + saw_run_err = True + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('stream_method', ('sendInputStreamChunk', 'endInputStream')) +async def test_reflection_server_v2_input_stream_not_implemented_js_style( + fake_manager: FakeReflectionManager, + stream_method: str, +) -> None: + """Unimplemented input-stream methods return -32000 + data.stack when id is set (JS parity).""" + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': stream_method, + 'params': {}, + 'id': 'stream-1', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') + assert isinstance(err, dict) + assert err.get('code') == JSON_RPC_SERVER_ERROR + assert 'not implemented' in str(err.get('message', '')).lower() + data = err.get('data') + assert isinstance(data, dict) + assert 'stack' in data and str(data.get('stack', '')).strip() + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_method_not_found(fake_manager: FakeReflectionManager) -> None: + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'unknownMethod', + 'id': '8', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') + assert isinstance(err, dict) + assert err.get('code') == JSON_RPC_METHOD_NOT_FOUND + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_omits_data_for_simple_errors( + fake_manager: FakeReflectionManager, +) -> None: + """Plain validation errors omit ``error.data`` to match JS / Go reflection-v2. + + JS's ``JSON.stringify`` drops ``undefined`` props and Go's struct uses + ``json:",omitempty"`` on ``Data``, so ``sendError(id, code, message)`` with + no extra payload produces a frame without a ``data`` key at all. Only + handlers that assemble a Status-shaped payload (runAction errors) emit one. + """ + + registry = Registry() + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'unknownMethod', + 'id': 'e1', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') or {} + assert err.get('code') == JSON_RPC_METHOD_NOT_FOUND + assert 'data' not in err, 'error.data must be omitted for plain JSON-RPC errors' + + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'runAction', + 'params': {'key': '/model/missing', 'input': {}}, + 'id': 'e2', + }) + resp = await fake_manager.read_rpc() + err = resp.get('error') or {} + assert err.get('code') == JSON_RPC_INVALID_PARAMS + assert 'not found' in str(err.get('message', '')).lower() + assert 'data' not in err, 'error.data must be omitted when no Status payload is built' + finally: + await _stop_client(client, task) + + +def test_reflection_run_action_params_accepts_dev_ui_telemetry_labels() -> None: + """Dev UI sends telemetryLabels as a string record (e.g. genkitx:ignore-trace).""" + + from genkit._core._typing import ReflectionRunActionParams + + p = ReflectionRunActionParams.model_validate({ + 'key': '/executable-prompt/story', + 'telemetryLabels': {'genkitx:ignore-trace': 'true'}, + }) + assert p.telemetry_labels == {'genkitx:ignore-trace': 'true'} diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py index 498a22b63a..c7a6c34e7e 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/embedder.py @@ -27,7 +27,8 @@ from google import genai from google.genai import types as genai_types -from genkit import Embedding, EmbedRequest, EmbedResponse +from genkit import DocumentPart, Embedding, EmbedRequest, EmbedResponse +from genkit._core._typing import DocumentData from genkit.plugins.google_genai.models.utils import PartConverter @@ -101,6 +102,12 @@ async def generate(self, request: EmbedRequest) -> EmbedResponse: Returns: EmbedResponse """ + request = EmbedRequest.model_validate(request) + if not request.input: + raise ValueError( + 'Embed request input is empty: provide at least one document with content ' + '(for example input: [{"content": [{"text": "your text here"}]}]).' + ) contents = await self._build_contents(request) config = self._genkit_to_googleai_cfg(request) response = await self._client.aio.models.embed_content( @@ -123,9 +130,12 @@ async def _build_contents(self, request: EmbedRequest) -> list[genai.types.Conte """ request_contents: list[genai.types.Content] = [] for doc in request.input: + if not isinstance(doc, DocumentData): + doc = DocumentData.model_validate(doc) content_parts: list[genai.types.Part] = [] for p in doc.content: - converted = await PartConverter.to_gemini(p) + part = p if isinstance(p, DocumentPart) else DocumentPart.model_validate(p) + converted = await PartConverter.to_gemini(part) if isinstance(converted, list): content_parts.extend(converted) else: diff --git a/py/plugins/google-genai/test/models/googlegenai_embedder_test.py b/py/plugins/google-genai/test/models/googlegenai_embedder_test.py index 37f4d11f64..ed033f636a 100644 --- a/py/plugins/google-genai/test/models/googlegenai_embedder_test.py +++ b/py/plugins/google-genai/test/models/googlegenai_embedder_test.py @@ -20,11 +20,7 @@ from google import genai from pytest_mock import MockerFixture -from genkit import ( - Document, - EmbedRequest, - EmbedResponse, -) +from genkit import Document, EmbedRequest, EmbedResponse from genkit.plugins.google_genai.models.embedder import ( Embedder, GeminiEmbeddingModels, @@ -57,3 +53,13 @@ async def test_embedding(mocker: MockerFixture, version: GeminiEmbeddingModels) assert isinstance(response, EmbedResponse) assert len(response.embeddings) == 1 assert response.embeddings[0].embedding == embedding_values + + +@pytest.mark.asyncio +async def test_embedding_rejects_empty_input(mocker: MockerFixture) -> None: + """Empty input must not call the API (avoids opaque BatchEmbedContents errors).""" + googleai_client_mock = mocker.AsyncMock() + embedder = Embedder(GeminiEmbeddingModels.GEMINI_EMBEDDING_001, googleai_client_mock) + with pytest.raises(ValueError, match='Embed request input is empty'): + await embedder.generate(EmbedRequest(input=[])) + googleai_client_mock.aio.models.embed_content.assert_not_called() diff --git a/py/tools/schema_to_typing/schema_to_typing.py b/py/tools/schema_to_typing/schema_to_typing.py index 8a4e33eafd..554f6e6c30 100644 --- a/py/tools/schema_to_typing/schema_to_typing.py +++ b/py/tools/schema_to_typing/schema_to_typing.py @@ -130,6 +130,31 @@ def _models_allowing_extra(schema: dict) -> set[str]: return result +def _typed_map_aliases(defs: dict) -> dict[str, str]: + """Inline object schemas with typed scalar ``additionalProperties`` -> Python dict alias. + + e.g. ``ReflectionRunActionParams.telemetryLabels``: + ``{type: object, additionalProperties: {type: string}}`` -> ``dict[str, str]``. + + Emitting these as type aliases (mirroring ``Metadata`` / ``Custom``) keeps the + symbol exported and importable while letting callers pass plain Python dicts — + a class with no fields and ``extra='forbid'`` would reject every key on the + Dev UI's ``{'genkitx:ignore-trace': 'true'}`` payload. + """ + + result: dict[str, str] = {} + for name, defn in defs.items(): + if not isinstance(defn, dict) or defn.get('type') != 'object': + continue + ap = defn.get('additionalProperties') + if not isinstance(ap, dict): + continue + ap_type = ap.get('type') + if isinstance(ap_type, str) and ap_type in PRIM: + result[name] = f'dict[str, {PRIM[ap_type]}]' + return result + + def _extract_inline_classes(schema: dict) -> dict[str, dict]: """Extract inline object schemas to named classes (e.g. Score.details -> Details).""" result = {} @@ -288,6 +313,7 @@ def generate(schema_path: Path, _out: Path) -> str: defs = dict(schema.get('$defs', {})) defs.update({k: v for k, v in _extract_inline_classes(schema).items() if k not in defs}) allow_extra = _models_allowing_extra(schema) + typed_map_aliases = _typed_map_aliases(defs) out = [HEADER.format(year=datetime.now().year, schema_name=schema_path.name)] emitted = set() @@ -318,6 +344,13 @@ def generate(schema_path: Path, _out: Path) -> str: 'Custom = dict[str, Any] # type alias for flexible custom data', '', ]) + elif name in typed_map_aliases: + # Typed string-keyed maps (e.g. TelemetryLabels: dict[str, str]). Emitting as a + # type alias keeps the symbol exported and lets callers pass plain dicts. + out.extend([ + f'{class_name} = {typed_map_aliases[name]} # type alias for {name.lower()} (typed string map)', + '', + ]) elif name in TRANSFORMATIONS and (cfg := TRANSFORMATIONS[name]).get('omit'): omit_set = set(cfg.get('omit', [])) out.extend(_emit_model(class_name, defn, schema, defs, allow_extra, omit=omit_set))