Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion go/core/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,19 @@ func (a *ActionDef[In, Out, Stream]) RunJSONWithTelemetry(ctx context.Context, i
}

// Desc returns a descriptor of the action with resolved schema references.
// Schema references that cannot be resolved (e.g., the action is not yet registered,
// or the referenced schema has not been defined) are returned as-is.
func (a *ActionDef[In, Out, Stream]) Desc() api.ActionDesc {
return *a.desc
desc := *a.desc
if a.registry != nil {
if resolved, err := ResolveSchema(a.registry, desc.InputSchema); err == nil {
desc.InputSchema = resolved
}
if resolved, err := ResolveSchema(a.registry, desc.OutputSchema); err == nil {
desc.OutputSchema = resolved
}
}
return desc
}

// Register registers the action with the given registry.
Expand Down
58 changes: 58 additions & 0 deletions go/core/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,64 @@ func TestActionDesc(t *testing.T) {
t.Error("OutputSchema is nil")
}
})

t.Run("resolves genkit schema references", func(t *testing.T) {
r := registry.New()

inputSchema := map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
}
outputSchema := map[string]any{
"type": "object",
"properties": map[string]any{
"answer": map[string]any{"type": "string"},
},
}
DefineSchema(r, "AgentRequest", inputSchema)
DefineSchema(r, "AgentResponse", outputSchema)

fn := func(ctx context.Context, input any) (any, error) { return nil, nil }
a := NewAction("test/refs", api.ActionTypeCustom, nil, SchemaRef("AgentRequest"), fn)
// Override the inferred output schema with a $ref so we can test resolution
// of OutputSchema as well.
a.desc.OutputSchema = SchemaRef("AgentResponse")
a.Register(r)

desc := a.Desc()

if _, ok := desc.InputSchema["$ref"]; ok {
t.Errorf("InputSchema still contains $ref: %v", desc.InputSchema)
}
if got, want := desc.InputSchema["type"], "object"; got != want {
t.Errorf("InputSchema type = %v, want %v", got, want)
}
if _, ok := desc.OutputSchema["$ref"]; ok {
t.Errorf("OutputSchema still contains $ref: %v", desc.OutputSchema)
}
if got, want := desc.OutputSchema["type"], "object"; got != want {
t.Errorf("OutputSchema type = %v, want %v", got, want)
}
})

t.Run("preserves $ref when schema is unregistered", func(t *testing.T) {
r := registry.New()

fn := func(ctx context.Context, input any) (any, error) { return nil, nil }
a := DefineAction(r, "test/unresolved", api.ActionTypeCustom, nil, SchemaRef("Missing"), fn)

desc := a.Desc()

ref, ok := desc.InputSchema["$ref"].(string)
if !ok {
t.Fatalf("expected InputSchema to retain $ref, got: %v", desc.InputSchema)
}
if want := "genkit:Missing"; ref != want {
t.Errorf("InputSchema $ref = %q, want %q", ref, want)
}
})
}

func TestActionRegister(t *testing.T) {
Expand Down
20 changes: 19 additions & 1 deletion go/genkit/reflection.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,14 @@ func listActions(g *Genkit) []api.ActionDesc {
}

// listResolvableActions lists all the registered and resolvable actions.
// Schema references in the descriptors are resolved to their concrete schemas
// so that consumers (e.g., the Dev UI) don't have to perform secondary lookups.
func listResolvableActions(ctx context.Context, g *Genkit) []api.ActionDesc {
ads := listActions(g)
keys := make(map[string]struct{})
keys := make(map[string]struct{}, len(ads))
for _, d := range ads {
keys[d.Name] = struct{}{}
}

plugins := g.reg.ListPlugins()
for _, p := range plugins {
Expand All @@ -656,6 +661,7 @@ func listResolvableActions(ctx context.Context, g *Genkit) []api.ActionDesc {

for _, desc := range dp.ListActions(ctx) {
if _, exists := keys[desc.Name]; !exists {
resolveDescSchemas(g.reg, &desc)
Comment thread
apascal07 marked this conversation as resolved.
ads = append(ads, desc)
keys[desc.Name] = struct{}{}
}
Expand All @@ -669,6 +675,18 @@ func listResolvableActions(ctx context.Context, g *Genkit) []api.ActionDesc {
return ads
}

// resolveDescSchemas best-effort resolves any "genkit:" schema references in
// the descriptor's InputSchema and OutputSchema. Unresolvable references are
// left as-is.
func resolveDescSchemas(r api.Registry, desc *api.ActionDesc) {
if resolved, err := core.ResolveSchema(r, desc.InputSchema); err == nil {
desc.InputSchema = resolved
}
if resolved, err := core.ResolveSchema(r, desc.OutputSchema); err == nil {
desc.OutputSchema = resolved
}
}

// TODO: Pull these from common types in genkit-tools.

type runActionResponse struct {
Expand Down
34 changes: 34 additions & 0 deletions go/genkit/reflection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,40 @@ func TestServeMux(t *testing.T) {
}
})

t.Run("list actions resolves schema references", func(t *testing.T) {
core.DefineSchema(g.reg, "AgentRequest", map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{"type": "string"},
},
})
core.DefineAction(g.reg, "test/withRef", api.ActionTypeCustom, nil,
core.SchemaRef("AgentRequest"),
func(ctx context.Context, in any) (any, error) { return nil, nil })

res, err := http.Get(ts.URL + "/api/actions")
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()

var actions map[string]api.ActionDesc
if err := json.NewDecoder(res.Body).Decode(&actions); err != nil {
t.Fatal(err)
}

desc, ok := actions["/custom/test/withRef"]
if !ok {
t.Fatal("action /custom/test/withRef not found in response")
}
if _, hasRef := desc.InputSchema["$ref"]; hasRef {
t.Errorf("InputSchema still contains $ref, expected resolved schema: %v", desc.InputSchema)
}
if got, want := desc.InputSchema["type"], "object"; got != want {
t.Errorf("InputSchema type = %v, want %v", got, want)
}
})

t.Run("run action", func(t *testing.T) {
tests := []struct {
name string
Expand Down
Loading