diff --git a/go/ai/prompt.go b/go/ai/prompt.go
index de62751ea9..ed16fae7bd 100644
--- a/go/ai/prompt.go
+++ b/go/ai/prompt.go
@@ -113,8 +113,18 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
 		"output":       map[string]any{"schema": p.OutputSchema},
 		"defaultInput": p.DefaultInput,
 		"tools":        tools,
+		"toolChoice":   string(pOpts.ToolChoice),
 		"maxTurns":     p.MaxTurns,
 	}
+	var use []any
+	for _, m := range pOpts.commonGenOptions.Use {
+		if md := toMiddlewareMetadata(m); md != nil {
+			use = append(use, md)
+		}
+	}
+	if len(use) > 0 {
+		promptMetadata["use"] = use
+	}
 	if variant != "" {
 		promptMetadata["variant"] = variant
 	}
@@ -129,6 +139,60 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
 	return p
 }
 
+// toMiddlewareMetadata converts one middleware entry into the name/config map
+// stored under the "use" key of a prompt's metadata.
+//
+// Three forms are recognized:
+//   - a [Middleware]: its Name, plus its JSON encoding as "config" when that
+//     encoding is a non-empty object;
+//   - a string: taken as the middleware name;
+//   - a map[string]any: its "name" and, when present, its "config".
+//
+// It returns nil for entries that must be left out: nil or unsupported values,
+// middleware named "inline", and entries whose name is empty.
+func toMiddlewareMetadata(m any) map[string]any {
+	switch v := m.(type) {
+	case nil:
+		return nil
+	case Middleware:
+		name := v.Name()
+		if name == "" || name == "inline" {
+			return nil
+		}
+		md := map[string]any{"name": name}
+		// Config is best effort: a middleware that cannot be encoded, or whose
+		// encoding is not a non-empty JSON object, is reported by name only.
+		data, err := json.Marshal(v)
+		if err != nil {
+			return md
+		}
+		var config map[string]any
+		if err := json.Unmarshal(data, &config); err == nil && len(config) > 0 {
+			md["config"] = config
+		}
+		return md
+	case string:
+		if v == "" {
+			return nil
+		}
+		return map[string]any{"name": v}
+	case map[string]any:
+		name, _ := v["name"].(string)
+		if name == "" {
+			return nil
+		}
+		md := map[string]any{"name": name}
+		if config, ok := v["config"]; ok && config != nil {
+			// An empty config map carries no information; any other value is kept as is.
+			if cm, isMap := config.(map[string]any); !isMap || len(cm) > 0 {
+				md["config"] = config
+			}
+		}
+		return md
+	default:
+		return nil
+	}
+}
+
 // LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
 // It returns nil if the prompt was not defined.
func LookupPrompt(r api.Registry, name string) Prompt { @@ -751,6 +799,31 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp if variant != "" { promptMetadata["variant"] = variant } + promptMetadata["name"] = name + promptMetadata["model"] = metadata.Model + promptMetadata["config"] = metadata.Config + promptMetadata["tools"] = metadata.Tools + toolChoice := metadata.Raw["toolChoice"] + var tc string + if v, ok := toolChoice.(string); ok { + tc = v + } else if v, ok := toolChoice.(ToolChoice); ok { + tc = string(v) + } + promptMetadata["toolChoice"] = tc + + if ur, ok := metadata.Raw["use"].([]any); ok { + var use []any + for _, m := range ur { + if md := toMiddlewareMetadata(m); md != nil { + use = append(use, md) + } + } + if len(use) > 0 { + promptMetadata["use"] = use + } + } + promptOptMetadata["prompt"] = promptMetadata promptOptMetadata["type"] = api.ActionTypeExecutablePrompt @@ -769,8 +842,10 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp Description: metadata.Description, } - if toolChoice, ok := metadata.Raw["toolChoice"].(ToolChoice); ok { - opts.ToolChoice = toolChoice + if v, ok := toolChoice.(string); ok { + opts.ToolChoice = ToolChoice(v) + } else if v, ok := toolChoice.(ToolChoice); ok { + opts.ToolChoice = v } if maxTurns, ok := metadata.Raw["maxTurns"].(uint64); ok { diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index 08aa75b44d..b249c76cea 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -3292,3 +3292,131 @@ Hello {{name}}, please help me with {{task}}. 
} }) } + +type testMetadataMiddleware struct { + Foo string `json:"foo"` +} + +func (m *testMetadataMiddleware) Name() string { return "test/middleware" } +func (m *testMetadataMiddleware) New(ctx context.Context) (*Hooks, error) { return nil, nil } + +func TestPromptMetadata(t *testing.T) { + reg := registry.New() + toolA := testTool(reg, "toolA") + + DefinePrompt(reg, "testPrompt", + WithTools(toolA), + WithToolChoice(ToolChoiceAuto), + WithUse(&testMetadataMiddleware{Foo: "bar"}), + WithPrompt("hello"), + ) + + action := reg.LookupAction("/executable-prompt/testPrompt") + if action == nil { + t.Fatal("Action not found") + } + + metadata, ok := action.Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatal("Metadata not found") + } + + if diff := cmp.Diff([]string{"toolA"}, metadata["tools"]); diff != "" { + t.Errorf("tools mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff("auto", metadata["toolChoice"]); diff != "" { + t.Errorf("toolChoice mismatch (-want +got):\n%s", diff) + } + + use, ok := metadata["use"].([]any) + if !ok { + t.Fatal("use metadata not found") + } + + if len(use) != 1 { + t.Fatalf("use length = %d, want 1", len(use)) + } + + m, ok := use[0].(map[string]any) + if !ok { + t.Fatal("use[0] is not a map") + } + + if m["name"] != "test/middleware" { + t.Errorf("middleware name = %v, want test/middleware", m["name"]) + } + + config, ok := m["config"].(map[string]any) + if !ok { + t.Fatal("middleware config not found") + } + + if config["foo"] != "bar" { + t.Errorf("foo = %v, want bar", config["foo"]) + } +} + +func TestLoadPromptMetadata(t *testing.T) { + reg := registry.New() + source := `--- +model: echoModel +tools: [toolA] +toolChoice: auto +use: + - name: myMiddleware + config: + foo: bar +--- +hello` + + _, err := LoadPromptFromSource(reg, source, "metadataTest", "test-ns") + if err != nil { + t.Fatalf("LoadPromptFromSource failed: %v", err) + } + + action := reg.LookupAction("/executable-prompt/test-ns/metadataTest") + if 
action == nil { + t.Fatal("Action not found") + } + + metadata, ok := action.Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatal("Metadata not found") + } + + if diff := cmp.Diff([]string{"toolA"}, metadata["tools"]); diff != "" { + t.Errorf("tools mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff("auto", metadata["toolChoice"]); diff != "" { + t.Errorf("toolChoice mismatch (-want +got):\n%s", diff) + } + + use, ok := metadata["use"].([]any) + if !ok { + t.Fatal("use metadata not found") + } + + if len(use) != 1 { + t.Fatalf("use length = %d, want 1", len(use)) + } + + m, ok := use[0].(map[string]any) + if !ok { + t.Fatal("use[0] is not a map") + } + + if m["name"] != "myMiddleware" { + t.Errorf("middleware name = %v, want myMiddleware", m["name"]) + } + + config, ok := m["config"].(map[string]any) + if !ok { + t.Fatal("middleware config not found") + } + + if config["foo"] != "bar" { + t.Errorf("foo = %v, want bar", config["foo"]) + } +} diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go index 2f36017932..c0ca2d0083 100644 --- a/go/samples/basic-prompts/main.go +++ b/go/samples/basic-prompts/main.go @@ -37,6 +37,12 @@ // curl -N -X POST http://localhost:8080/recipePromptFlow \ // -H "Content-Type: application/json" \ // -d '{"data": {"dish": "tacos", "cuisine": "Mexican", "servingSize": 4}}' +// +// Test an agent flow (with all middleware: retry, fallback, filesystem, skills): +// +// curl -N -X POST http://localhost:8080/agentPromptFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": "what files are in my current directory?"}' package main import ( @@ -49,6 +55,7 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/genkit" "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/middleware" "github.com/firebase/genkit/go/plugins/server" "google.golang.org/genai" ) @@ -87,6 +94,10 @@ type Recipe struct { Difficulty string 
`json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` } +type AgentRequest struct { + Query string `json:"query" jsonschema:"description=The user's query or request"` +} + func main() { ctx := context.Background() @@ -94,7 +105,7 @@ func main() { // Config parameter, the Google AI plugin will get the API key from the // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended // practice. - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}, &middleware.Middleware{})) // Define schemas for the expected input and output types so that the Dotprompt files can reference them. // Alternatively, you can specify the JSON schema by hand in the Dotprompt metadata. @@ -103,6 +114,7 @@ func main() { genkit.DefineSchemaFor[Joke](g) genkit.DefineSchemaFor[RecipeRequest](g) genkit.DefineSchemaFor[Recipe](g) + genkit.DefineSchemaFor[AgentRequest](g) // TODO: Include partials and helpers. @@ -113,6 +125,8 @@ func main() { DefineStructuredJokeWithDotprompt(g) DefineRecipeWithInlinePrompt(g) DefineRecipeWithDotprompt(g) + DefineAgentWithInlinePrompt(g) + DefineAgentWithDotprompt(g) // Optionally, start a web server to make the flows callable via HTTP. mux := http.NewServeMux() @@ -310,3 +324,76 @@ func newIngredientFilter() func([]*Ingredient) []*Ingredient { return } } + +// DefineAgentWithInlinePrompt demonstrates attaching multiple middlewares +// (Retry, Fallback, Filesystem, and Skills) directly to a prompt definition. +// This creates a highly capable and resilient prompt that is also fully +// transparent in the Dev UI metadata. 
+func DefineAgentWithInlinePrompt(g *genkit.Genkit) { + agentPrompt := genkit.DefinePrompt( + g, "agent.code", + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("{{query}}"), + ai.WithInputSchema(map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The user's query or request", + }, + }, + "required": []string{"query"}, + }), + ai.WithUse( + &middleware.Retry{MaxRetries: 2}, + &middleware.Fallback{ + Models: []ai.ModelRef{ + ai.NewModelRef("googleai/gemini-2.5-flash-lite", nil), + ai.NewModelRef("googleai/gemini-2.5-pro", &genai.GenerateContentConfig{ + Temperature: genai.Ptr[float32](2.0), + }), + }, + }, + &middleware.Filesystem{RootDir: "."}, + &middleware.Skills{SkillPaths: []string{"./skills"}}, + ), + ) + + genkit.DefineStreamingFlow(g, "agentPromptFlow", + func(ctx context.Context, query string, sendChunk core.StreamCallback[string]) (string, error) { + stream := agentPrompt.ExecuteStream(ctx, ai.WithInput(&AgentRequest{Query: query})) + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("agent error: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + return "", nil + }, + ) +} + +// DefineAgentWithDotprompt demonstrates loading a prompt from a .prompt file +// that includes a full suite of middleware configuration in its YAML frontmatter. +func DefineAgentWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "agentDotpromptFlow", + func(ctx context.Context, query string, sendChunk core.StreamCallback[string]) (string, error) { + // The "agent" prompt file includes all middleware in its frontmatter. 
+ agentPrompt := genkit.LookupPrompt(g, "agent") + stream := agentPrompt.ExecuteStream(ctx, ai.WithInput(&AgentRequest{Query: query})) + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("agent error: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + return "", nil + }, + ) +} diff --git a/go/samples/basic-prompts/prompts/agent.prompt b/go/samples/basic-prompts/prompts/agent.prompt new file mode 100644 index 0000000000..4f50597d46 --- /dev/null +++ b/go/samples/basic-prompts/prompts/agent.prompt @@ -0,0 +1,33 @@ +--- +model: googleai/gemini-2.5-flash +config: + temperature: 1.0 +input: + schema: + properties: + query: + type: string + description: The user's query or request + required: + - query +use: + - name: genkit-middleware/retry + config: + maxRetries: 2 + - name: genkit-middleware/fallback + config: + models: + - googleai/gemini-2.5-flash-lite + - name: googleai/gemini-2.5-pro + config: + temperature: 2.0 + - name: genkit-middleware/filesystem + config: + rootDirectory: . + - name: genkit-middleware/skills + config: + skillPaths: + - ./skills +--- + +{{query}}