diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 1b28012b65..96c0c8e38c 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -177,6 +177,27 @@ func (MiddlewareFunc) Name() string { return "inline" } func (f MiddlewareFunc) New(ctx context.Context) (*Hooks, error) { return f(ctx) } +// middlewareRefArg is a lazy [Middleware] that carries only a registered +// name and an opaque config payload. It exists so data-driven sources (most +// notably dotprompt `use:` entries) can populate the same `[]Middleware` +// channel that [WithUse] uses, while [configsToRefs] strips the adapter out +// and routes resolution through the registry rather than the local fast +// path. The type is unexported on purpose: users who want this shape should +// build a [GenerateActionOptions] with [MiddlewareRef]s directly. +type middlewareRefArg struct { + name string + config any +} + +func (r middlewareRefArg) Name() string { return r.name } + +// New is unreachable in normal use because [configsToRefs] swaps the adapter +// for a name-only [MiddlewareRef] before [resolveRefs] sees it; the error +// here surfaces a routing bug instead of returning nil hooks. +func (middlewareRefArg) New(context.Context) (*Hooks, error) { + return nil, core.NewError(core.INTERNAL, "ai: middlewareRefArg must be resolved via the registry") +} + // LookupMiddleware returns the registered middleware descriptor with the // given name, or nil if no such descriptor exists in the registry or any // ancestor. Primarily useful for inspection and for the reflection API; @@ -199,9 +220,12 @@ type MiddlewarePlugin interface { } // configsToRefs converts a user-supplied slice of [Middleware] values into -// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. The Go -// value is stored on each ref so [resolveRefs] can build the hooks bundle -// directly without a registry round trip for local calls. +// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. For +// regular middleware the Go value is stored on each ref so [resolveRefs] can +// build the hooks bundle directly without a registry round trip. For +// [middlewareRefArg] only the name and raw config travel; this routes +// [resolveRefs] through the registry-lookup path so dotprompt-supplied +// middleware resolves like any other JSON-dispatched call. func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) { if len(configs) == 0 { return nil, nil @@ -211,6 +235,10 @@ func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) { if c == nil { return nil, core.NewError(core.INVALID_ARGUMENT, "ai: nil middleware") } + if lazy, ok := c.(middlewareRefArg); ok { + refs = append(refs, &MiddlewareRef{Name: lazy.name, Config: lazy.config}) + continue + } refs = append(refs, &MiddlewareRef{Name: c.Name(), Config: c}) } return refs, nil diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index 498d718ab3..944e919ee2 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -702,3 +702,39 @@ func TestMiddlewareHookOrderOnToolRestart(t *testing.T) { } var testCtx = context.Background() + +// --- middlewareRefArg: lazy reference used by the dotprompt loader --- + +func TestConfigsToRefs_StripsLazyAdapter(t *testing.T) { + refs, err := configsToRefs([]Middleware{ + middlewareRefArg{name: "test/foo"}, + middlewareRefArg{name: "test/bar", config: map[string]any{"k": "v"}}, + counterConfig{}, + }) + assertNoError(t, err) + if len(refs) != 3 { + t.Fatalf("got %d refs, want 3", len(refs)) + } + + if refs[0].Name != "test/foo" || refs[0].Config != nil { + t.Errorf("name-only adapter: got {Name:%q Config:%v}, want {test/foo nil}", refs[0].Name, refs[0].Config) + } + + cfg, ok := refs[1].Config.(map[string]any) + if refs[1].Name != "test/bar" || !ok || cfg["k"] != "v" { + t.Errorf("adapter with config: got {Name:%q Config:%v}, want {test/bar map[k:v]}", refs[1].Name, refs[1].Config) + } + + if _, ok := refs[2].Config.(Middleware); !ok { + t.Errorf("regular middleware ref: Config should retain Middleware value for fast path, got %T", refs[2].Config) + } +} + +func TestMiddlewareRefArg_NewErrors(t *testing.T) { + // Defensive: configsToRefs strips the adapter, so resolveRefs should + // never call New on it. If routing ever regresses, fail loudly instead + // of silently producing nil hooks. + if _, err := (middlewareRefArg{name: "x"}).New(testCtx); err == nil { + t.Fatal("expected middlewareRefArg.New to return an error") + } +} diff --git a/go/ai/prompt.go b/go/ai/prompt.go index de62751ea9..27baebc3c7 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -447,6 +447,11 @@ func (p *prompt) buildRequest(ctx context.Context, input any) (*GenerateActionOp return nil, core.NewError(core.INVALID_ARGUMENT, "invalid output schema for prompt %q: %v", p.Name(), err) } + useRefs, err := configsToRefs(p.Use) + if err != nil { + return nil, fmt.Errorf("prompt %q: %w", p.Name(), err) + } + return &GenerateActionOptions{ Model: modelName, Config: config, @@ -455,6 +460,7 @@ func (p *prompt) buildRequest(ctx context.Context, input any) (*GenerateActionOp ReturnToolRequests: p.ReturnToolRequests != nil && *p.ReturnToolRequests, Messages: messages, Tools: tools, + Use: useRefs, Output: &GenerateActionOutputConfig{ Format: p.OutputFormat, JsonSchema: outputSchema, @@ -781,6 +787,12 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp opts.ReturnToolRequests = &returnToolRequests } + if uses, err := parseDotpromptUse(metadata.Raw["use"]); err != nil { + return nil, fmt.Errorf("prompt %q: %w", name, err) + } else if len(uses) > 0 { + opts.Use = uses + } + if inputSchema, ok := metadata.Input.Schema.(*jsonschema.Schema); ok { if inputSchema.Ref != "" { opts.InputSchema = core.SchemaRef(inputSchema.Ref) @@ -830,6 +842,40 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp return prompt, nil } +// parseDotpromptUse converts the value of the dotprompt `use:` frontmatter +// field into a slice of lazy [Middleware] references. Each entry may be a +// bare string (interpreted as a registered middleware name) or a map with +// `name` and optional `config`, mirroring the TypeScript MiddlewareRef shape. +// Returns nil if the input is nil or an empty slice. +func parseDotpromptUse(raw any) ([]Middleware, error) { + if raw == nil { + return nil, nil + } + entries, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("`use` must be a list, got %T", raw) + } + uses := make([]Middleware, 0, len(entries)) + for i, entry := range entries { + switch v := entry.(type) { + case string: + if v == "" { + return nil, fmt.Errorf("`use[%d]` is an empty string", i) + } + uses = append(uses, middlewareRefArg{name: v}) + case map[string]any: + name, _ := v["name"].(string) + if name == "" { + return nil, fmt.Errorf("`use[%d]` is missing required `name` field", i) + } + uses = append(uses, middlewareRefArg{name: name, config: v["config"]}) + default: + return nil, fmt.Errorf("`use[%d]` must be a string or map, got %T", i, entry) + } + } + return uses, nil +} + // LoadPromptDir loads prompts and partials from a directory on the local filesystem. func LoadPromptDir(r api.Registry, dir string, namespace string) { LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace) diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index 08aa75b44d..bb5eecc258 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -3292,3 +3292,123 @@ Hello {{name}}, please help me with {{task}}. } }) } + +func TestParseDotpromptUse(t *testing.T) { + tests := []struct { + name string + raw any + want []Middleware + wantErr bool + }{ + {name: "nil", raw: nil, want: nil}, + {name: "empty list", raw: []any{}, want: []Middleware{}}, + { + name: "bare names", + raw: []any{"a", "b"}, + want: []Middleware{middlewareRefArg{name: "a"}, middlewareRefArg{name: "b"}}, + }, + { + name: "mixed names and refs with config", + raw: []any{ + "a", + map[string]any{"name": "b", "config": map[string]any{"k": 1}}, + }, + want: []Middleware{ + middlewareRefArg{name: "a"}, + middlewareRefArg{name: "b", config: map[string]any{"k": 1}}, + }, + }, + { + name: "ref without config", + raw: []any{map[string]any{"name": "x"}}, + want: []Middleware{middlewareRefArg{name: "x"}}, + }, + {name: "not a list", raw: "single", wantErr: true}, + {name: "empty string entry", raw: []any{""}, wantErr: true}, + {name: "map missing name", raw: []any{map[string]any{"config": "x"}}, wantErr: true}, + {name: "unsupported entry type", raw: []any{42}, wantErr: true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseDotpromptUse(tc.raw) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got %v", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(middlewareRefArg{})); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLoadPrompt_WithUseMiddleware(t *testing.T) { + const promptSource = `--- +model: test/fakeModel +use: + - test/counter + - name: test/counter + config: + label: from-yaml +--- +hello +` + reg := newTestRegistry(t) + defineFakeModel(t, reg, fakeModelConfig{}) + + var modelCalls int32 + NewMiddleware("counter", counterConfig{sharedModelCalls: &modelCalls}).Register(reg) + + p, err := LoadPromptFromSource(reg, promptSource, "withuse", "") + if err != nil { + t.Fatalf("LoadPromptFromSource: %v", err) + } + + if _, err := p.Execute(context.Background()); err != nil { + t.Fatalf("Execute: %v", err) + } + + // Both `use:` entries fire on the single model call: 2 hooks layered. + if got := modelCalls; got != 2 { + t.Errorf("model hook calls = %d, want 2 (one per use entry)", got) + } +} + +func TestLoadPrompt_WithUseMiddleware_NotRegistered(t *testing.T) { + const promptSource = `--- +model: test/fakeModel +use: + - test/missing +--- +hi +` + reg := newTestRegistry(t) + defineFakeModel(t, reg, fakeModelConfig{}) + + p, err := LoadPromptFromSource(reg, promptSource, "missingmw", "") + if err != nil { + t.Fatalf("LoadPromptFromSource: %v", err) + } + + if _, err := p.Execute(context.Background()); err == nil { + t.Fatal("expected NOT_FOUND error for unregistered middleware referenced from `use:`") + } +} + +func TestLoadPrompt_WithUseMiddleware_InvalidShape(t *testing.T) { + const promptSource = `--- +model: test/fakeModel +use: "not-a-list" +--- +hi +` + reg := newTestRegistry(t) + if _, err := LoadPromptFromSource(reg, promptSource, "badshape", ""); err == nil { + t.Fatal("expected error for non-list `use:`") + } +}