Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,18 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
"output": map[string]any{"schema": p.OutputSchema},
"defaultInput": p.DefaultInput,
"tools": tools,
"toolChoice": string(pOpts.ToolChoice),
"maxTurns": p.MaxTurns,
}
var use []any
for _, m := range pOpts.commonGenOptions.Use {
if md := toMiddlewareMetadata(m); md != nil {
use = append(use, md)
}
}
if len(use) > 0 {
promptMetadata["use"] = use
}
if variant != "" {
promptMetadata["variant"] = variant
}
Expand All @@ -129,6 +139,44 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt {
return p
}

func toMiddlewareMetadata(m any) map[string]any {
if m == nil {
return nil
}
if mw, ok := m.(Middleware); ok {
name := mw.Name()
if name == "inline" {
return nil
}
res := map[string]any{
"name": name,
}
data, _ := json.Marshal(mw)
var config map[string]any
if err := json.Unmarshal(data, &config); err == nil && len(config) > 0 {
res["config"] = config
}
return res
}
switch v := m.(type) {
case string:
return map[string]any{"name": v}
case map[string]any:
name, _ := v["name"].(string)
if name == "" {
return nil
}
res := map[string]any{"name": name}
if config, ok := v["config"]; ok && config != nil {
if m, ok := config.(map[string]any); !ok || len(m) > 0 {
res["config"] = config
}
}
return res
}
return nil
}

// LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
// It returns nil if the prompt was not defined.
func LookupPrompt(r api.Registry, name string) Prompt {
Expand Down Expand Up @@ -751,6 +799,31 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
if variant != "" {
promptMetadata["variant"] = variant
}
promptMetadata["name"] = name
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These mostly get overwritten with maps.Copy() in DefinePrompt, no? Only some survive because they're not set there.

promptMetadata["model"] = metadata.Model
promptMetadata["config"] = metadata.Config
promptMetadata["tools"] = metadata.Tools
toolChoice := metadata.Raw["toolChoice"]
var tc string
if v, ok := toolChoice.(string); ok {
tc = v
} else if v, ok := toolChoice.(ToolChoice); ok {
tc = string(v)
}
promptMetadata["toolChoice"] = tc

if ur, ok := metadata.Raw["use"].([]any); ok {
var use []any
for _, m := range ur {
if md := toMiddlewareMetadata(m); md != nil {
use = append(use, md)
}
}
if len(use) > 0 {
promptMetadata["use"] = use
}
}

promptOptMetadata["prompt"] = promptMetadata
promptOptMetadata["type"] = api.ActionTypeExecutablePrompt

Expand All @@ -769,8 +842,10 @@ func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Promp
Description: metadata.Description,
}

if toolChoice, ok := metadata.Raw["toolChoice"].(ToolChoice); ok {
opts.ToolChoice = toolChoice
if v, ok := toolChoice.(string); ok {
opts.ToolChoice = ToolChoice(v)
} else if v, ok := toolChoice.(ToolChoice); ok {
opts.ToolChoice = v
}

if maxTurns, ok := metadata.Raw["maxTurns"].(uint64); ok {
Expand Down
128 changes: 128 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3292,3 +3292,131 @@ Hello {{name}}, please help me with {{task}}.
}
})
}

type testMetadataMiddleware struct {
Foo string `json:"foo"`
}

func (m *testMetadataMiddleware) Name() string { return "test/middleware" }
func (m *testMetadataMiddleware) New(ctx context.Context) (*Hooks, error) { return nil, nil }

func TestPromptMetadata(t *testing.T) {
reg := registry.New()
toolA := testTool(reg, "toolA")

DefinePrompt(reg, "testPrompt",
WithTools(toolA),
WithToolChoice(ToolChoiceAuto),
WithUse(&testMetadataMiddleware{Foo: "bar"}),
WithPrompt("hello"),
)

action := reg.LookupAction("/executable-prompt/testPrompt")
if action == nil {
t.Fatal("Action not found")
}

metadata, ok := action.Desc().Metadata["prompt"].(map[string]any)
if !ok {
t.Fatal("Metadata not found")
}

if diff := cmp.Diff([]string{"toolA"}, metadata["tools"]); diff != "" {
t.Errorf("tools mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff("auto", metadata["toolChoice"]); diff != "" {
t.Errorf("toolChoice mismatch (-want +got):\n%s", diff)
}

use, ok := metadata["use"].([]any)
if !ok {
t.Fatal("use metadata not found")
}

if len(use) != 1 {
t.Fatalf("use length = %d, want 1", len(use))
}

m, ok := use[0].(map[string]any)
if !ok {
t.Fatal("use[0] is not a map")
}

if m["name"] != "test/middleware" {
t.Errorf("middleware name = %v, want test/middleware", m["name"])
}

config, ok := m["config"].(map[string]any)
if !ok {
t.Fatal("middleware config not found")
}

if config["foo"] != "bar" {
t.Errorf("foo = %v, want bar", config["foo"])
}
}

func TestLoadPromptMetadata(t *testing.T) {
reg := registry.New()
source := `---
model: echoModel
tools: [toolA]
toolChoice: auto
use:
- name: myMiddleware
config:
foo: bar
---
hello`

_, err := LoadPromptFromSource(reg, source, "metadataTest", "test-ns")
if err != nil {
t.Fatalf("LoadPromptFromSource failed: %v", err)
}

action := reg.LookupAction("/executable-prompt/test-ns/metadataTest")
if action == nil {
t.Fatal("Action not found")
}

metadata, ok := action.Desc().Metadata["prompt"].(map[string]any)
if !ok {
t.Fatal("Metadata not found")
}

if diff := cmp.Diff([]string{"toolA"}, metadata["tools"]); diff != "" {
t.Errorf("tools mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff("auto", metadata["toolChoice"]); diff != "" {
t.Errorf("toolChoice mismatch (-want +got):\n%s", diff)
}

use, ok := metadata["use"].([]any)
if !ok {
t.Fatal("use metadata not found")
}

if len(use) != 1 {
t.Fatalf("use length = %d, want 1", len(use))
}

m, ok := use[0].(map[string]any)
if !ok {
t.Fatal("use[0] is not a map")
}

if m["name"] != "myMiddleware" {
t.Errorf("middleware name = %v, want myMiddleware", m["name"])
}

config, ok := m["config"].(map[string]any)
if !ok {
t.Fatal("middleware config not found")
}

if config["foo"] != "bar" {
t.Errorf("foo = %v, want bar", config["foo"])
}
}
89 changes: 88 additions & 1 deletion go/samples/basic-prompts/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
// curl -N -X POST http://localhost:8080/recipePromptFlow \
// -H "Content-Type: application/json" \
// -d '{"data": {"dish": "tacos", "cuisine": "Mexican", "servingSize": 4}}'
//
// Test an agent flow (with all middleware: retry, fallback, filesystem, skills):
//
// curl -N -X POST http://localhost:8080/agentPromptFlow \
// -H "Content-Type: application/json" \
// -d '{"data": "what files are in my current directory?"}'
package main

import (
Expand All @@ -49,6 +55,7 @@ import (
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/plugins/googlegenai"
"github.com/firebase/genkit/go/plugins/middleware"
"github.com/firebase/genkit/go/plugins/server"
"google.golang.org/genai"
)
Expand Down Expand Up @@ -87,14 +94,18 @@ type Recipe struct {
Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"`
}

type AgentRequest struct {
Query string `json:"query" jsonschema:"description=The user's query or request"`
}

func main() {
ctx := context.Background()

// Initialize Genkit with the Google AI plugin. When you pass nil for the
// Config parameter, the Google AI plugin will get the API key from the
// GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended
// practice.
g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}))
g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{}, &middleware.Middleware{}))

// Define schemas for the expected input and output types so that the Dotprompt files can reference them.
// Alternatively, you can specify the JSON schema by hand in the Dotprompt metadata.
Expand All @@ -103,6 +114,7 @@ func main() {
genkit.DefineSchemaFor[Joke](g)
genkit.DefineSchemaFor[RecipeRequest](g)
genkit.DefineSchemaFor[Recipe](g)
genkit.DefineSchemaFor[AgentRequest](g)

// TODO: Include partials and helpers.

Expand All @@ -113,6 +125,8 @@ func main() {
DefineStructuredJokeWithDotprompt(g)
DefineRecipeWithInlinePrompt(g)
DefineRecipeWithDotprompt(g)
DefineAgentWithInlinePrompt(g)
DefineAgentWithDotprompt(g)

// Optionally, start a web server to make the flows callable via HTTP.
mux := http.NewServeMux()
Expand Down Expand Up @@ -310,3 +324,76 @@ func newIngredientFilter() func([]*Ingredient) []*Ingredient {
return
}
}

// DefineAgentWithInlinePrompt demonstrates attaching multiple middlewares
// (Retry, Fallback, Filesystem, and Skills) directly to a prompt definition.
// This creates a highly capable and resilient prompt that is also fully
// transparent in the Dev UI metadata.
func DefineAgentWithInlinePrompt(g *genkit.Genkit) {
agentPrompt := genkit.DefinePrompt(
g, "agent.code",
ai.WithModelName("googleai/gemini-2.5-flash"),
ai.WithPrompt("{{query}}"),
ai.WithInputSchema(map[string]any{
"type": "object",
"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "The user's query or request",
},
},
"required": []string{"query"},
}),
ai.WithUse(
&middleware.Retry{MaxRetries: 2},
&middleware.Fallback{
Models: []ai.ModelRef{
ai.NewModelRef("googleai/gemini-2.5-flash-lite", nil),
ai.NewModelRef("googleai/gemini-2.5-pro", &genai.GenerateContentConfig{
Temperature: genai.Ptr[float32](2.0),
}),
},
},
&middleware.Filesystem{RootDir: "."},
&middleware.Skills{SkillPaths: []string{"./skills"}},
),
)

genkit.DefineStreamingFlow(g, "agentPromptFlow",
func(ctx context.Context, query string, sendChunk core.StreamCallback[string]) (string, error) {
stream := agentPrompt.ExecuteStream(ctx, ai.WithInput(&AgentRequest{Query: query}))
for result, err := range stream {
if err != nil {
return "", fmt.Errorf("agent error: %w", err)
}
if result.Done {
return result.Response.Text(), nil
}
sendChunk(ctx, result.Chunk.Text())
}
return "", nil
},
)
}

// DefineAgentWithDotprompt demonstrates loading a prompt from a .prompt file
// that includes a full suite of middleware configuration in its YAML frontmatter.
func DefineAgentWithDotprompt(g *genkit.Genkit) {
genkit.DefineStreamingFlow(g, "agentDotpromptFlow",
func(ctx context.Context, query string, sendChunk core.StreamCallback[string]) (string, error) {
// The "agent" prompt file includes all middleware in its frontmatter.
agentPrompt := genkit.LookupPrompt(g, "agent")
stream := agentPrompt.ExecuteStream(ctx, ai.WithInput(&AgentRequest{Query: query}))
for result, err := range stream {
if err != nil {
return "", fmt.Errorf("agent error: %w", err)
}
if result.Done {
return result.Response.Text(), nil
}
sendChunk(ctx, result.Chunk.Text())
}
return "", nil
},
)
}
Loading
Loading