Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions py/packages/genkit/src/genkit/_ai/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources
from genkit._ai._tools import Tool, ToolInterruptError
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._action import (
GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR,
Action,
ActionKind,
ActionRunContext,
)
from genkit._core._error import GenkitError
from genkit._core._logger import get_logger
from genkit._core._model import GenerateActionOptions
Expand All @@ -61,6 +66,51 @@
logger = get_logger(__name__)


async def expand_wildcard_tools(registry: Registry, tool_names: list[str]) -> list[str]:
Comment thread
huangjeff5 marked this conversation as resolved.
"""Expand DAP wildcard tool names into individual registry keys.

A wildcard has the form ``<provider>:tool/*`` (or ``<provider>:tool/<prefix>*``).
Each match becomes a full DAP key
``/dynamic-action-provider/<provider>:<actionType>/<toolName>`` so later resolution
stays bound to that provider (no ambiguous bare-name lookup across DAPs).

Non-wildcard names are passed through unchanged.
"""
expanded: list[str] = []
for name in tool_names:
if not name.endswith('*') or ':' not in name:
expanded.append(name)
continue

colon = name.index(':')
provider_name = name[:colon]
rest = name[colon + 1 :] # e.g. "tool/*" or "tool/prefix*"

provider_action = await registry.resolve_action(ActionKind.DYNAMIC_ACTION_PROVIDER, provider_name)
if provider_action is None:
expanded.append(name)
continue

dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is None:
expanded.append(name)
continue

if '/' not in rest:
expanded.append(name)
continue

action_type, action_pattern = rest.split('/', 1)
metas = await dap.list_action_metadata(action_type, action_pattern)
for meta in metas:
tool_name = meta.get('name')
if tool_name:
tn = str(tool_name)
expanded.append(f'/dynamic-action-provider/{provider_name}:{action_type}/{tn}')

return expanded


def tools_to_action_names(
tools: Sequence[str | Tool] | None,
) -> list[str] | None:
Expand Down Expand Up @@ -162,14 +212,19 @@ async def _generate_action(
context: dict[str, Any] | None = None,
) -> ModelResponse:
"""Execute a generation request with tool calling and middleware support."""
tools_in = raw_request.tools
if tools_in:
raw_request = raw_request.model_copy()
raw_request.tools = await expand_wildcard_tools(registry, tools_in)

model, tools, format_def = await resolve_parameters(registry, raw_request)

raw_request, formatter = apply_format(raw_request, format_def)

if raw_request.resources:
raw_request = await apply_resources(registry, raw_request)

assert_valid_tool_names(raw_request)
assert_valid_tool_names(tools)

(
revised_request,
Expand Down Expand Up @@ -586,10 +641,30 @@ async def apply_resources(registry: Registry, raw_request: GenerateActionOptions
return new_request


def assert_valid_tool_names(_raw_request: GenerateActionOptions) -> None:
"""Validate tool names in the request. (TODO: not yet implemented)."""
# TODO(#4338): implement me
pass
def _tool_short_name_for_model(name: str) -> str:
"""Return the last path segment of a tool name."""
if '/' not in name:
return name
return name[name.rfind('/') + 1 :]


def assert_valid_tool_names(tools: list[Action[Any, Any, Any]]) -> None:
"""Reject overlapping model-facing tool names before the model is called.

Two resolved tools that share the same short name (segment after the last ``/``)
cannot both appear in one generate request.
"""
if not tools:
return
seen: dict[str, str] = {}
for tool in tools:
short = _tool_short_name_for_model(tool.name)
if short in seen:
raise GenkitError(
status='INVALID_ARGUMENT',
message=(f"Cannot provide two tools with the same name: '{tool.name}' and '{seen[short]}'"),
)
seen[short] = tool.name


async def resolve_parameters(
Expand All @@ -611,9 +686,10 @@ async def resolve_parameters(
tools: list[Action[Any, Any, Any]] = []
if request.tools:
for tool_name in request.tools:
tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name)
if tool_action is None:
raise Exception(f'Unable to resolve tool {tool_name}')
try:
tool_action = await resolve_tool(registry, tool_name)
except GenkitError as e:
raise Exception(f'Unable to resolve tool {tool_name}') from e
tools.append(tool_action)

format_def: FormatDef | None = None
Expand Down Expand Up @@ -671,7 +747,12 @@ async def resolve_tool_requests(
tool_dict: dict[str, Action] = {}
if request.tools:
for tool_name in request.tools:
tool_dict[tool_name] = await resolve_tool(registry, tool_name)
tool_action = await resolve_tool(registry, tool_name)
tool_dict[tool_name] = tool_action
# Model tool calls use ToolDefinition.name (short); wildcard expansion uses full DAP keys.
short = tool_action.name
if short not in tool_dict:
tool_dict[short] = tool_action

revised_model_message = message.model_copy(deep=True)

Expand Down Expand Up @@ -768,11 +849,19 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart
async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action:
"""Resolve a tool from a registry name or a Tool instance.

Accepts full action keys (``/dynamic-action-provider/...``), DAP-qualified
names (``provider:tool/name``), or plain registered tool names.

Used when building ModelRequest (for example from to_generate_request).
"""
if isinstance(tool_ref, Tool):
return tool_ref.action

if tool_ref.startswith('/'):
tool = await registry.resolve_action_by_key(tool_ref)
if tool is not None:
return tool

tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_ref)
if tool is None:
raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_ref}')
Expand Down
19 changes: 13 additions & 6 deletions py/packages/genkit/src/genkit/_core/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import AsyncIterator, Awaitable, Callable, Generator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generic, cast, get_type_hints
from typing import Any, ClassVar, Generic, NamedTuple, cast, get_type_hints

from opentelemetry import trace as trace_api
from opentelemetry.trace import Span
Expand Down Expand Up @@ -237,7 +237,15 @@ def extract_action_args_and_types(
GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR = '_genkit_dynamic_action_provider'


def parse_dap_qualified_name(name: str) -> tuple[str, str, str] | None:
class DapQualifiedName(NamedTuple):
"""Segments of a DAP-qualified name ``provider:innerKind/innerName``."""

provider: str
inner_kind: str
inner_name: str


def parse_dap_qualified_name(name: str) -> DapQualifiedName | None:
"""Parse DAP-qualified segment ``provider:innerKind/innerName``.

Used when the action key kind is ``dynamic-action-provider`` and the name
Expand All @@ -247,9 +255,8 @@ def parse_dap_qualified_name(name: str) -> tuple[str, str, str] | None:
provider segment (``plugin/foo`` is not a valid provider host).

Returns:
``(provider_name, inner_kind, inner_name)`` if the string matches the
pattern; otherwise ``None`` so callers can treat the name as a plain
dynamic-action-provider id.
A :class:`DapQualifiedName` if the string matches; otherwise ``None`` so
callers can treat the name as a plain dynamic-action-provider id.
"""
# Pattern: [provider]:[inner_kind]/[inner_name]; no '/' or ':' in provider.
match = re.match(r'^([^/:]+):([^/:]+)/(.+)$', name)
Expand All @@ -258,7 +265,7 @@ def parse_dap_qualified_name(name: str) -> tuple[str, str, str] | None:
provider, inner_kind, inner_name = match.groups()
if not provider or not inner_kind or not inner_name:
return None
return (provider, inner_kind, inner_name)
return DapQualifiedName(provider, inner_kind, inner_name)


def parse_action_key(key: str) -> tuple[ActionKind, str]:
Expand Down
133 changes: 62 additions & 71 deletions py/packages/genkit/src/genkit/_core/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,30 +516,77 @@ async def _trigger_lazy_loading(self, action: Action | None) -> Action | None:
self._loading_actions.discard(action_id)
return action

async def _resolve_dap_qualified_action(self, kind: ActionKind, name: str) -> Action | None:
"""Resolve through the one registered DAP for ``provider:innerKind/innerName`` names.

Caller must ensure :func:`parse_dap_qualified_name` accepts ``name``. Does not consult
plugins. Returns ``None`` if the provider is not registered here (caller may delegate
to a parent registry).
"""
qualified = parse_dap_qualified_name(name)
if qualified is None:
return None
dap_host = qualified.provider
with self._lock:
provider = self._entries.get(ActionKind.DYNAMIC_ACTION_PROVIDER, {}).get(dap_host)
if provider is None:
return None
dap_action = await self._trigger_lazy_loading(provider)
if dap_action is None:
raise RuntimeError(
f'Dynamic action provider {dap_host!r} is not registered. '
'DAPs must be registered using define_dynamic_action_provider '
'before referencing qualified action names.'
)
dap = getattr(dap_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is not None:
try:
resolved = await dap.get_action(qualified.inner_kind, qualified.inner_name)
except Exception as e:
raise ValueError(f'Dynamic action provider {dap_host!r} get_action failed for {kind} {name!r}') from e
if resolved is not None and resolved.kind == kind:
return resolved
if resolved is None:
raise ValueError(
f'Dynamic action provider {dap_host!r} has no action '
f'{qualified.inner_kind!r}/{qualified.inner_name!r} for {name!r}'
)
raise ValueError(
f'Dynamic action provider {dap_host!r} returned {resolved.kind!r} for {name!r}, expected {kind!r}'
)
raise RuntimeError(
f'Dynamic action provider {dap_host!r} is missing the Genkit DAP helper. '
'Register it using define_dynamic_action_provider before referencing qualified action names.'
)

async def resolve_action(self, kind: ActionKind, name: str) -> Action | None:
"""Resolve an action by kind and name, supporting both prefixed and unprefixed names.
"""Resolve an action by kind and name.

This method supports:
1. Cache hit: Returns immediately if action is already registered
2. Namespaced request (e.g., "plugin/model"): Resolves via specific plugin
3. Unprefixed request (e.g., "model"): Tries all plugins, errors on ambiguity
4. Dynamic action providers: Last-resort fallback for dynamic action creation
Tries an exact (kind, name) cache hit first. DAP-qualified names
(provider:innerKind/innerName) go through that provider only. If the name contains a
slash, the first segment is treated as a plugin id: that plugin is initialized and
plugin.resolve is used. Falls back to parent registry if nothing found.

Args:
kind: The type of action to resolve.
name: The name of the action (may be prefixed with "plugin/" or unprefixed).
name: Action name, optionally plugin/... for a specific plugin.

Returns:
The Action instance if found, None otherwise.

Raises:
ValueError: If an unprefixed name matches multiple plugins (ambiguous).
"""
# Cache hit
with self._lock:
if kind in self._entries and name in self._entries[kind]:
return await self._trigger_lazy_loading(self._entries[kind][name])

# DAP-qualified names: resolve via that provider only (not plugin slash splitting).
if kind != ActionKind.DYNAMIC_ACTION_PROVIDER and parse_dap_qualified_name(name) is not None:
action = await self._resolve_dap_qualified_action(kind, name)
if action is not None:
return action
if self._parent is not None:
return await self._parent.resolve_action(kind, name)
return None

action: Action | None = None

# Namespaced request
Expand All @@ -563,61 +610,6 @@ async def resolve_action(self, kind: ActionKind, name: str) -> Action | None:
self.register_action_instance(action, namespace=plugin_name)
with self._lock:
return await self._trigger_lazy_loading(self._entries.get(kind, {}).get(target))
else:
# Unprefixed request: try all plugins
successes: list[tuple[str, Action]] = []
with self._lock:
plugins = list(self._plugins.items())
for plugin_name, plugin in plugins:
await self._ensure_plugin_initialized(plugin_name)
target = f'{plugin_name}/{name}'

# Check cache first - init() might have registered this action
with self._lock:
cached_action = self._entries.get(kind, {}).get(target)
if cached_action is not None:
successes.append((plugin_name, cached_action))
continue

action = await plugin.resolve(kind, target)
if action is not None:
successes.append((plugin_name, action))

if len(successes) > 1:
plugin_names = [p for p, _ in successes]
raise ValueError(
f"Ambiguous {kind} action name '{name}'. "
+ f"Matches plugins: {plugin_names}. Use 'plugin/{name}'."
)

if len(successes) == 1:
plugin_name, action = successes[0]
self.register_action_instance(action, namespace=plugin_name)
with self._lock:
return await self._trigger_lazy_loading(self._entries.get(kind, {}).get(f'{plugin_name}/{name}'))

# Fallback: try dynamic action providers (for MCP, dynamic resources, etc.)
# Skip if we're looking up a dynamic action provider itself to avoid recursion
if kind != ActionKind.DYNAMIC_ACTION_PROVIDER:
with self._lock:
if ActionKind.DYNAMIC_ACTION_PROVIDER in self._entries:
providers_dict = self._entries[ActionKind.DYNAMIC_ACTION_PROVIDER]
else:
providers_dict = {}
providers = list(providers_dict.values())
for provider_action in providers:
dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is None:
continue
try:
resolved = await dap.get_action(str(kind), name)
if resolved is not None:
return resolved
except Exception as e:
logger.debug(
f'Dynamic action provider {provider_action.name} failed for {kind}/{name}',
exc_info=e,
)

# Final fallback: delegate to parent registry.
if self._parent is not None:
Expand Down Expand Up @@ -650,22 +642,21 @@ async def resolve_action_by_key(self, key: str) -> Action | None:
if kind == ActionKind.DYNAMIC_ACTION_PROVIDER:
dap_parts = parse_dap_qualified_name(name)
if dap_parts is not None:
provider_name, inner_kind_str, inner_name = dap_parts
provider_action = await self.resolve_action(
ActionKind.DYNAMIC_ACTION_PROVIDER,
provider_name,
dap_parts.provider,
)
if provider_action is None:
return None
dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is None:
return None
try:
resolved = await dap.get_action(inner_kind_str, inner_name)
resolved = await dap.get_action(dap_parts.inner_kind, dap_parts.inner_name)
except Exception as e:
logger.debug(
f'Dynamic action provider {provider_name} failed for '
f'qualified key {inner_kind_str}/{inner_name}',
f'Dynamic action provider {dap_parts.provider} failed for '
f'qualified key {dap_parts.inner_kind}/{dap_parts.inner_name}',
exc_info=e,
)
return None
Expand Down
Loading
Loading