Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions go/ai/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ func (MiddlewareFunc) Name() string { return "inline" }

func (f MiddlewareFunc) New(ctx context.Context) (*Hooks, error) { return f(ctx) }

// middlewareRefArg is a lazy [Middleware] that carries only a registered
// name and an opaque config payload. It exists so data-driven sources (most
// notably dotprompt `use:` entries) can populate the same `[]Middleware`
// channel that [WithUse] uses, while [configsToRefs] strips the adapter out
// and routes resolution through the registry rather than the local fast
// path. The type is unexported on purpose: users who want this shape should
// build a [GenerateActionOptions] with [MiddlewareRef]s directly.
type middlewareRefArg struct {
name string
config any
}

func (r middlewareRefArg) Name() string { return r.name }

// New is unreachable in normal use because [configsToRefs] swaps the adapter
// for a name-only [MiddlewareRef] before [resolveRefs] sees it; the error
// here surfaces a routing bug instead of returning nil hooks.
func (middlewareRefArg) New(context.Context) (*Hooks, error) {
return nil, core.NewError(core.INTERNAL, "ai: middlewareRefArg must be resolved via the registry")
}

// LookupMiddleware returns the registered middleware descriptor with the
// given name, or nil if no such descriptor exists in the registry or any
// ancestor. Primarily useful for inspection and for the reflection API;
Expand All @@ -199,9 +220,12 @@ type MiddlewarePlugin interface {
}

// configsToRefs converts a user-supplied slice of [Middleware] values into
// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. The Go
// value is stored on each ref so [resolveRefs] can build the hooks bundle
// directly without a registry round trip for local calls.
// the [MiddlewareRef] entries carried on [GenerateActionOptions.Use]. For
// regular middleware the Go value is stored on each ref so [resolveRefs] can
// build the hooks bundle directly without a registry round trip. For
// [middlewareRefArg] only the name and raw config travel; this routes
// [resolveRefs] through the registry-lookup path so dotprompt-supplied
// middleware resolves like any other JSON-dispatched call.
func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) {
if len(configs) == 0 {
return nil, nil
Expand All @@ -211,6 +235,10 @@ func configsToRefs(configs []Middleware) ([]*MiddlewareRef, error) {
if c == nil {
return nil, core.NewError(core.INVALID_ARGUMENT, "ai: nil middleware")
}
if lazy, ok := c.(middlewareRefArg); ok {
refs = append(refs, &MiddlewareRef{Name: lazy.name, Config: lazy.config})
continue
}
refs = append(refs, &MiddlewareRef{Name: c.Name(), Config: c})
}
return refs, nil
Expand Down
36 changes: 36 additions & 0 deletions go/ai/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,39 @@ func TestMiddlewareHookOrderOnToolRestart(t *testing.T) {
}

var testCtx = context.Background()

// --- middlewareRefArg: lazy reference used by the dotprompt loader ---

func TestConfigsToRefs_StripsLazyAdapter(t *testing.T) {
refs, err := configsToRefs([]Middleware{
middlewareRefArg{name: "test/foo"},
middlewareRefArg{name: "test/bar", config: map[string]any{"k": "v"}},
counterConfig{},
})
assertNoError(t, err)
if len(refs) != 3 {
t.Fatalf("got %d refs, want 3", len(refs))
}

if refs[0].Name != "test/foo" || refs[0].Config != nil {
t.Errorf("name-only adapter: got {Name:%q Config:%v}, want {test/foo nil}", refs[0].Name, refs[0].Config)
}

cfg, ok := refs[1].Config.(map[string]any)
if refs[1].Name != "test/bar" || !ok || cfg["k"] != "v" {
t.Errorf("adapter with config: got {Name:%q Config:%v}, want {test/bar map[k:v]}", refs[1].Name, refs[1].Config)
}

if _, ok := refs[2].Config.(Middleware); !ok {
t.Errorf("regular middleware ref: Config should retain Middleware value for fast path, got %T", refs[2].Config)
}
}

func TestMiddlewareRefArg_NewErrors(t *testing.T) {
// Defensive: configsToRefs strips the adapter, so resolveRefs should
// never call New on it. If routing ever regresses, fail loudly instead
// of silently producing nil hooks.
if _, err := (middlewareRefArg{name: "x"}).New(testCtx); err == nil {
t.Fatal("expected middlewareRefArg.New to return an error")
}
}
46 changes: 46 additions & 0 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ func (p *prompt) buildRequest(ctx context.Context, input any) (*GenerateActionOp
return nil, core.NewError(core.INVALID_ARGUMENT, "invalid output schema for prompt %q: %v", p.Name(), err)
}

useRefs, err := configsToRefs(p.Use)
if err != nil {
return nil, fmt.Errorf("prompt %q: %w", p.Name(), err)
}
Comment thread
apascal07 marked this conversation as resolved.

return &GenerateActionOptions{
Model: modelName,
Config: config,
Expand All @@ -455,6 +460,7 @@ func (p *prompt) buildRequest(ctx context.Context, input any) (*GenerateActionOp
ReturnToolRequests: p.ReturnToolRequests != nil && *p.ReturnToolRequests,
Messages: messages,
Tools: tools,
Use: useRefs,
Output: &GenerateActionOutputConfig{
Format: p.OutputFormat,
JsonSchema: outputSchema,
Expand Down Expand Up @@ -781,6 +787,12 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
opts.ReturnToolRequests = &returnToolRequests
}

if uses, err := parseDotpromptUse(metadata.Raw["use"]); err != nil {
return nil, fmt.Errorf("prompt %q: %w", name, err)
} else if len(uses) > 0 {
opts.Use = uses
}

if inputSchema, ok := metadata.Input.Schema.(*jsonschema.Schema); ok {
if inputSchema.Ref != "" {
opts.InputSchema = core.SchemaRef(inputSchema.Ref)
Expand Down Expand Up @@ -830,6 +842,40 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
return prompt, nil
}

// parseDotpromptUse converts the value of the dotprompt `use:` frontmatter
// field into a slice of lazy [Middleware] references. Each entry may be a
// bare string (interpreted as a registered middleware name) or a map with
// `name` and optional `config`, mirroring the TypeScript MiddlewareRef shape.
// Returns nil if the input is nil or an empty slice.
func parseDotpromptUse(raw any) ([]Middleware, error) {
if raw == nil {
return nil, nil
}
entries, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("`use` must be a list, got %T", raw)
}
uses := make([]Middleware, 0, len(entries))
for i, entry := range entries {
switch v := entry.(type) {
case string:
if v == "" {
return nil, fmt.Errorf("`use[%d]` is an empty string", i)
}
uses = append(uses, middlewareRefArg{name: v})
case map[string]any:
name, _ := v["name"].(string)
if name == "" {
return nil, fmt.Errorf("`use[%d]` is missing required `name` field", i)
}
uses = append(uses, middlewareRefArg{name: name, config: v["config"]})
default:
return nil, fmt.Errorf("`use[%d]` must be a string or map, got %T", i, entry)
}
}
return uses, nil
}

// LoadPromptDir loads prompts and partials from a directory on the local filesystem.
func LoadPromptDir(r api.Registry, dir string, namespace string) {
LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace)
Expand Down
120 changes: 120 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3292,3 +3292,123 @@ Hello {{name}}, please help me with {{task}}.
}
})
}

func TestParseDotpromptUse(t *testing.T) {
tests := []struct {
name string
raw any
want []Middleware
wantErr bool
}{
{name: "nil", raw: nil, want: nil},
{name: "empty list", raw: []any{}, want: []Middleware{}},
{
name: "bare names",
raw: []any{"a", "b"},
want: []Middleware{middlewareRefArg{name: "a"}, middlewareRefArg{name: "b"}},
},
{
name: "mixed names and refs with config",
raw: []any{
"a",
map[string]any{"name": "b", "config": map[string]any{"k": 1}},
},
want: []Middleware{
middlewareRefArg{name: "a"},
middlewareRefArg{name: "b", config: map[string]any{"k": 1}},
},
},
{
name: "ref without config",
raw: []any{map[string]any{"name": "x"}},
want: []Middleware{middlewareRefArg{name: "x"}},
},
{name: "not a list", raw: "single", wantErr: true},
{name: "empty string entry", raw: []any{""}, wantErr: true},
{name: "map missing name", raw: []any{map[string]any{"config": "x"}}, wantErr: true},
{name: "unsupported entry type", raw: []any{42}, wantErr: true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := parseDotpromptUse(tc.raw)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error, got %v", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(middlewareRefArg{})); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
}
}

func TestLoadPrompt_WithUseMiddleware(t *testing.T) {
const promptSource = `---
model: test/fakeModel
use:
- test/counter
- name: test/counter
config:
label: from-yaml
---
hello
`
reg := newTestRegistry(t)
defineFakeModel(t, reg, fakeModelConfig{})

var modelCalls int32
NewMiddleware("counter", counterConfig{sharedModelCalls: &modelCalls}).Register(reg)

p, err := LoadPromptFromSource(reg, promptSource, "withuse", "")
if err != nil {
t.Fatalf("LoadPromptFromSource: %v", err)
}

if _, err := p.Execute(context.Background()); err != nil {
t.Fatalf("Execute: %v", err)
}

// Both `use:` entries fire on the single model call: 2 hooks layered.
if got := modelCalls; got != 2 {
t.Errorf("model hook calls = %d, want 2 (one per use entry)", got)
}
}

func TestLoadPrompt_WithUseMiddleware_NotRegistered(t *testing.T) {
const promptSource = `---
model: test/fakeModel
use:
- test/missing
---
hi
`
reg := newTestRegistry(t)
defineFakeModel(t, reg, fakeModelConfig{})

p, err := LoadPromptFromSource(reg, promptSource, "missingmw", "")
if err != nil {
t.Fatalf("LoadPromptFromSource: %v", err)
}

if _, err := p.Execute(context.Background()); err == nil {
t.Fatal("expected NOT_FOUND error for unregistered middleware referenced from `use:`")
}
}

func TestLoadPrompt_WithUseMiddleware_InvalidShape(t *testing.T) {
const promptSource = `---
model: test/fakeModel
use: "not-a-list"
---
hi
`
reg := newTestRegistry(t)
if _, err := LoadPromptFromSource(reg, promptSource, "badshape", ""); err == nil {
t.Fatal("expected error for non-list `use:`")
}
}
Loading