Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions internal/llminternal/googlellm/variant.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,19 @@ func IsGeminiAPIVariant(llm model.LLM) bool {
return GetGoogleLLMVariant(llm) == genai.BackendGeminiAPI
}

// NeedsOutputSchemaProcessor reports whether the model requires the output
// schema processor because it cannot combine an output schema with tools
// natively.
// Non-Gemini models (e.g., Bedrock/Claude) always need the processor, since
// native output-schema-with-tools support cannot be assumed for them.
// Gemini models need it only on the Gemini API backend (not Vertex AI) at
// version 2.5 or lower.
func NeedsOutputSchemaProcessor(llm model.LLM) bool {
	if llm == nil {
		return false
	}
	name := llm.Name()
	if IsGeminiModel(name) {
		// Gemini: only Gemini API with version <= 2.5 lacks native support.
		return IsGeminiAPIVariant(llm) && IsGemini25OrLower(name)
	}
	// Non-Gemini: conservatively assume no native support.
	return true
}

func extractModelName(model string) string {
Expand Down
55 changes: 39 additions & 16 deletions internal/llminternal/googlellm/variant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package googlellm

import (
"context"
"iter"
"testing"

"google.golang.org/genai"
Expand Down Expand Up @@ -68,28 +70,33 @@ func TestIsGeminiModel(t *testing.T) {

func TestNeedsOutputSchemaProcessor(t *testing.T) {
testCases := []struct {
name string
model string
variant genai.Backend
want bool
name string
llm model.LLM
want bool
}{
{"Gemini2.0_Vertex", "gemini-2.0-flash", genai.BackendVertexAI, false},
{"Gemini2.0_GeminiAPI", "gemini-2.0-flash", genai.BackendGeminiAPI, true},
{"NonGemini_Vertex", "not-a-gemini", genai.BackendVertexAI, false},
{"Gemini3.0_GeminiAPI", "gemini-3.0", genai.BackendGeminiAPI, false},
{"Gemini3.0_Vertex", "gemini-3.0", genai.BackendVertexAI, false},
{"CustomGemini2", "gemini-2.0-hack", genai.BackendUnspecified, false},
{"CustomGemini3", "gemini-3.0-hack", genai.BackendUnspecified, false},
// Gemini models on Vertex AI have native support
{"Gemini2.0_Vertex", &mockGoogleLLM{nameVal: "gemini-2.0-flash", variant: genai.BackendVertexAI}, false},
{"Gemini3.0_Vertex", &mockGoogleLLM{nameVal: "gemini-3.0", variant: genai.BackendVertexAI}, false},
// Gemini <= 2.5 on Gemini API need the processor
{"Gemini2.0_GeminiAPI", &mockGoogleLLM{nameVal: "gemini-2.0-flash", variant: genai.BackendGeminiAPI}, true},
// Gemini >= 3.0 on Gemini API have native support
{"Gemini3.0_GeminiAPI", &mockGoogleLLM{nameVal: "gemini-3.0", variant: genai.BackendGeminiAPI}, false},
// Unspecified backend defaults to no processor (conservative)
{"CustomGemini2", &mockGoogleLLM{nameVal: "gemini-2.0-hack", variant: genai.BackendUnspecified}, false},
{"CustomGemini3", &mockGoogleLLM{nameVal: "gemini-3.0-hack", variant: genai.BackendUnspecified}, false},
// Non-Gemini models (Bedrock/Claude/GPT) always need the processor
{"Claude_Sonnet", &simpleLLM{name: "claude-3-5-sonnet"}, true},
{"Claude_Opus", &simpleLLM{name: "anthropic.claude-v2"}, true},
{"Bedrock_Claude", &simpleLLM{name: "bedrock/anthropic.claude-3-sonnet"}, true},
{"GPT4", &simpleLLM{name: "gpt-4"}, true},
{"CustomModel", &simpleLLM{name: "my-custom-model"}, true},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := NeedsOutputSchemaProcessor(&mockGoogleLLM{
variant: tc.variant,
nameVal: tc.model,
})
got := NeedsOutputSchemaProcessor(tc.llm)
if got != tc.want {
t.Errorf("NeedsOutputSchemaProcessor(%q) = %v, want %v", tc.model, got, tc.want)
t.Errorf("NeedsOutputSchemaProcessor(%q) = %v, want %v", tc.llm.Name(), got, tc.want)
}
})
}
Expand All @@ -110,3 +117,19 @@ func (m *mockGoogleLLM) Name() string {
}

var _ GoogleLLM = (*mockGoogleLLM)(nil)

// simpleLLM is a minimal model.LLM stub for non-Google LLMs (e.g., Bedrock,
// OpenAI) that deliberately does not implement the GoogleLLM interface.
type simpleLLM struct {
	name string
}

// Name returns the configured model name.
func (s *simpleLLM) Name() string { return s.name }

// GenerateContent is a stub; the variant tests never invoke generation.
func (s *simpleLLM) GenerateContent(_ context.Context, _ *model.LLMRequest, _ bool) iter.Seq2[*model.LLMResponse, error] {
	return nil
}

var _ model.LLM = (*simpleLLM)(nil)
59 changes: 59 additions & 0 deletions internal/llminternal/outputschema_processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package llminternal
import (
"context"
"encoding/json"
"iter"
"strings"
"testing"

Expand Down Expand Up @@ -59,6 +60,20 @@ func (m *mockLLM) GetGoogleLLMVariant() genai.Backend {
return genai.BackendGeminiAPI
}

// mockNonGeminiLLM stands in for a non-Gemini model (e.g., Bedrock/Claude);
// it satisfies model.LLM but not the GoogleLLM interface.
type mockNonGeminiLLM struct {
	name string
}

// Name returns the mock's model name.
func (l *mockNonGeminiLLM) Name() string { return l.name }

// GenerateContent is a stub; the processor tests never call it.
func (l *mockNonGeminiLLM) GenerateContent(_ context.Context, _ *model.LLMRequest, _ bool) iter.Seq2[*model.LLMResponse, error] {
	return nil
}

var _ model.LLM = (*mockNonGeminiLLM)(nil)

// mockLLMAgent satisfies both agent.Agent (via embedding) and llminternal.Agent (via internal() implementation)
type mockLLMAgent struct {
agent.Agent
Expand Down Expand Up @@ -203,6 +218,50 @@ func TestOutputSchemaRequestProcessor(t *testing.T) {
t.Error("set_model_response tool should NOT be added when native support is available")
}
})

t.Run("InjectsToolForNonGeminiModels", func(t *testing.T) {
// Non-Gemini models (like Bedrock/Claude) should use the processor
// since they don't support native output schema with tools
llm := &mockNonGeminiLLM{name: "claude-3-5-sonnet"}

// Agent configured with an output schema AND an unrelated tool — the
// combination that forces the processor to inject set_model_response.
baseAgent := utils.Must(agent.New(agent.Config{Name: "ClaudeAgent"}))
mockAgent := &mockLLMAgent{
Agent: baseAgent,
s: &State{
Model: llm,
OutputSchema: schema,
Tools: []tool.Tool{&mockTool{name: "other_tool"}},
},
}

req := &model.LLMRequest{}
ctx := icontext.NewInvocationContext(context.Background(), icontext.InvocationContextParams{
Agent: mockAgent,
})

// Drain the processor sequence.
// NOTE(review): this Fatalf fires on ANY yielded item, even with a nil err —
// assumes the processor yields only on error; confirm against its contract.
events := outputSchemaRequestProcessor(ctx, req, f)
for _, err := range events {
t.Fatalf("outputSchemaRequestProcessor() error = %v", err)
}

// Verify set_model_response tool is present for non-Gemini models
if _, ok := req.Tools["set_model_response"]; !ok {
t.Error("req.Tools['set_model_response'] should be added for non-Gemini models with tools")
}

// Verify the system instruction mentions the injected tool and the
// structured-output requirement.
instructions := utils.TextParts(req.Config.SystemInstruction)
found := false
for _, s := range instructions {
if strings.Contains(s, "set_model_response") && strings.Contains(s, "required structured format") {
found = true
break
}
}
if !found {
t.Errorf("Instruction about set_model_response not found for non-Gemini model. Instructions: %v", instructions)
}
})
}

func TestCreateFinalModelResponseEvent(t *testing.T) {
Expand Down