Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions py/packages/genkit/src/genkit/_ai/_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import asyncio
import inspect
import json
import os
import signal
import socket
import threading
Expand All @@ -31,8 +32,6 @@

import anyio
import uvicorn
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel

from genkit._ai._embedding import EmbedderFn, EmbedderOptions, EmbedderRef, define_embedder
Expand Down Expand Up @@ -94,6 +93,7 @@
from genkit._core._model import Document
from genkit._core._plugin import Plugin
from genkit._core._reflection import ReflectionServer, ServerSpec, create_reflection_asgi_app
from genkit._core._reflection_v2 import ReflectionServerV2
from genkit._core._registry import Registry
from genkit._core._tracing import run_in_new_span
from genkit._core._typing import (
Expand Down Expand Up @@ -628,9 +628,22 @@ def define_resource(
# -------------------------------------------------------------------------

def _start_reflection_background(self) -> None:
"""Start the Dev UI reflection server in a background daemon thread."""
"""Start the Dev UI reflection server in a background daemon thread.

If GENKIT_REFLECTION_V2_SERVER is set (the CLI launches the runtime in
v2 mode and provides a WebSocket URL), run the v2 JSON-RPC client.
Otherwise start the v1 HTTP server.
"""

async def _run_server() -> None:
v2_url = os.environ.get('GENKIT_REFLECTION_V2_SERVER')
if v2_url:
await logger.ainfo(f'Genkit Dev UI reflection v2 client connecting to {v2_url}')
server_v2 = ReflectionServerV2(self.registry, v2_url)
self._reflection_ready.set()
await server_v2.run_forever()
return

sockets: list[socket.socket] | None = None
spec = self._reflection_server_spec
if spec is None:
Expand Down Expand Up @@ -668,7 +681,7 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None)
"""Initialize the registry with default model and plugins."""
self.registry.default_model = model
if model:
self.registry.register_value('defaultModel', model, model)
self.registry.register_value('defaultModel', 'defaultModel', model)
for fmt in built_in_formats:
self.define_format(fmt)

Expand Down Expand Up @@ -1058,12 +1071,6 @@ def current_context() -> dict[str, Any] | None:
"""Get the current execution context, or None if not in an action."""
return ActionRunContext._current_context() # pyright: ignore[reportPrivateUsage]

async def flush_tracing(self) -> None:
"""Flush all pending trace spans to exporters."""
provider = trace_api.get_tracer_provider()
if isinstance(provider, TracerProvider):
await asyncio.to_thread(provider.force_flush)

async def run(
self,
*,
Expand Down
8 changes: 4 additions & 4 deletions py/packages/genkit/src/genkit/_core/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ async def run(
input: InputT | None = None,
on_chunk: Callable[[ChunkT], None] | None = None,
context: dict[str, object] | None = None,
on_trace_start: Callable[[str, str], None] | None = None,
on_trace_start: Callable[[str, str], Awaitable[None]] | None = None,
telemetry_labels: dict[str, object] | None = None,
) -> ActionResponse[OutputT]:
"""Execute the action with optional input validation.
Expand Down Expand Up @@ -543,7 +543,7 @@ def _make_tracing_wrapper(
object | None,
ActionRunContext,
StreamingCallback | None,
Callable[[str, str], None] | None,
Callable[[str, str], Any] | None,
dict[str, object] | None,
],
Awaitable[ActionResponse[Any]],
Expand All @@ -565,7 +565,7 @@ async def tracing_wrapper(
input: object | None,
ctx: ActionRunContext,
on_chunk: StreamingCallback | None,
on_trace_start: Callable[[str, str], None] | None,
on_trace_start: Callable[[str, str], Awaitable[None]] | None,
telemetry_labels: dict[str, object] | None,
) -> ActionResponse[Any]:
start_time = time.perf_counter()
Expand All @@ -579,7 +579,7 @@ async def tracing_wrapper(
trace_id = format(span.get_span_context().trace_id, '032x')
span_id = format(span.get_span_context().span_id, '016x')
if on_trace_start:
on_trace_start(trace_id, span_id)
await on_trace_start(trace_id, span_id)

# Set telemetry labels as direct span attributes (matches JS/Go behavior)
if telemetry_labels:
Expand Down
2 changes: 1 addition & 1 deletion py/packages/genkit/src/genkit/_core/_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class ActionRunner:
trace_id: str | None = None
span_id: str | None = None

def on_trace_start(self, tid: str, sid: str) -> None:
async def on_trace_start(self, tid: str, sid: str) -> None:
self.trace_id, self.span_id = tid, sid
if task := asyncio.current_task():
self.active_actions[tid] = task
Expand Down
Loading
Loading