Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion internal/llminternal/outputschema_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"google.golang.org/adk/agent"
"google.golang.org/adk/internal/llminternal/googlellm"
"google.golang.org/adk/internal/schemautil"
"google.golang.org/adk/internal/toolinternal/toolutils"
"google.golang.org/adk/internal/utils"
"google.golang.org/adk/model"
Expand Down Expand Up @@ -134,7 +135,11 @@ func (t *setModelResponseTool) Run(ctx tool.Context, args any) (map[string]any,
if !ok {
return nil, fmt.Errorf("unexpected args type for set_model_response: %T", args)
}
if err := utils.ValidateMapOnSchema(m, t.schema, false); err != nil {
resolved, err := schemautil.GenaiToResolvedJSONSchema(t.schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Resolving the output schema on every Run call is inefficient. Since setModelResponseTool is created per request in the processor, and the schema is derived from the agent's state, you should consider resolving the schema once and reusing it across invocations.

if err != nil {
return nil, fmt.Errorf("failed to resolve output schema: %w", err)
}
if err := resolved.Validate(m); err != nil {
return nil, fmt.Errorf("invalid output schema: %w", err)
}
return m, nil
Expand Down
100 changes: 100 additions & 0 deletions internal/schemautil/convert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package schemautil provides utilities for schema conversion and validation.
package schemautil

import (
"encoding/json"
"strings"

"github.com/google/jsonschema-go/jsonschema"
"google.golang.org/genai"
)

// GenaiToJSONSchema converts a genai.Schema to a jsonschema.Schema.
func GenaiToJSONSchema(gs *genai.Schema) (*jsonschema.Schema, error) {
if gs == nil {
return nil, nil
}

// Marshal to intermediate map
data, err := json.Marshal(gs)
if err != nil {
return nil, err
}

var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return nil, err
}

// Normalize type to lowercase (genai uses "STRING", jsonschema expects "string")
normalizeTypes(m)

// Marshal back and unmarshal to jsonschema.Schema
data, err = json.Marshal(m)
if err != nil {
return nil, err
}

var js jsonschema.Schema
if err := json.Unmarshal(data, &js); err != nil {
return nil, err
}

return &js, nil
}

// normalizeTypes recursively lowercases type fields in the schema map.
func normalizeTypes(m map[string]any) {
if t, ok := m["type"].(string); ok {
m["type"] = strings.ToLower(t)
}

// Recurse into properties
if props, ok := m["properties"].(map[string]any); ok {
for _, v := range props {
if prop, ok := v.(map[string]any); ok {
normalizeTypes(prop)
}
}
}

// Recurse into items
if items, ok := m["items"].(map[string]any); ok {
normalizeTypes(items)
}

// Recurse into anyOf
if anyOf, ok := m["anyOf"].([]any); ok {
for _, v := range anyOf {
if s, ok := v.(map[string]any); ok {
normalizeTypes(s)
}
}
}
}

// GenaiToResolvedJSONSchema converts a genai.Schema to a resolved jsonschema.
func GenaiToResolvedJSONSchema(gs *genai.Schema) (*jsonschema.Resolved, error) {
if gs == nil {
return nil, nil
}
js, err := GenaiToJSONSchema(gs)
if err != nil {
return nil, err
}
return js.Resolve(nil)
}
221 changes: 221 additions & 0 deletions internal/schemautil/convert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package schemautil

import (
"testing"

"google.golang.org/genai"
)

func TestGenaiToJSONSchema_Nil(t *testing.T) {
js, err := GenaiToJSONSchema(nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if js != nil {
t.Errorf("expected nil, got %v", js)
}
}

func TestGenaiToJSONSchema_BasicTypes(t *testing.T) {
tests := []struct {
name string
genaiType genai.Type
wantType string
}{
{"string", genai.TypeString, "string"},
{"integer", genai.TypeInteger, "integer"},
{"number", genai.TypeNumber, "number"},
{"boolean", genai.TypeBoolean, "boolean"},
{"array", genai.TypeArray, "array"},
{"object", genai.TypeObject, "object"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gs := &genai.Schema{Type: tt.genaiType}
js, err := GenaiToJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToJSONSchema error: %v", err)
}
if js.Type != tt.wantType {
t.Errorf("Type = %q, want %q", js.Type, tt.wantType)
}
})
}
}

func TestGenaiToJSONSchema_Enum(t *testing.T) {
gs := &genai.Schema{
Type: genai.TypeString,
Enum: []string{"red", "green", "blue"},
}
js, err := GenaiToJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToJSONSchema error: %v", err)
}

if len(js.Enum) != 3 {
t.Fatalf("Enum length = %d, want 3", len(js.Enum))
}
for i, want := range []string{"red", "green", "blue"} {
if js.Enum[i] != want {
t.Errorf("Enum[%d] = %v, want %q", i, js.Enum[i], want)
}
}
}

func TestGenaiToJSONSchema_EnumValidation(t *testing.T) {
gs := &genai.Schema{
Type: genai.TypeString,
Enum: []string{"red", "green", "blue"},
}
resolved, err := GenaiToResolvedJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToResolvedJSONSchema error: %v", err)
}

if err := resolved.Validate("red"); err != nil {
t.Errorf("Validate('red') error = %v, want nil", err)
}

if err := resolved.Validate("purple"); err == nil {
t.Error("Validate('purple') error = nil, want error")
}
}

func TestGenaiToJSONSchema_Properties(t *testing.T) {
gs := &genai.Schema{
Type: genai.TypeObject,
Properties: map[string]*genai.Schema{
"name": {Type: genai.TypeString, Description: "The name"},
"age": {Type: genai.TypeInteger},
},
Required: []string{"name"},
}
js, err := GenaiToJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToJSONSchema error: %v", err)
}

if len(js.Properties) != 2 {
t.Fatalf("Properties length = %d, want 2", len(js.Properties))
}
if js.Properties["name"].Type != "string" {
t.Errorf("Properties[name].Type = %q, want string", js.Properties["name"].Type)
}
if js.Properties["name"].Description != "The name" {
t.Errorf("Properties[name].Description = %q, want 'The name'", js.Properties["name"].Description)
}
if js.Properties["age"].Type != "integer" {
t.Errorf("Properties[age].Type = %q, want integer", js.Properties["age"].Type)
}
if len(js.Required) != 1 || js.Required[0] != "name" {
t.Errorf("Required = %v, want [name]", js.Required)
}
}

func TestGenaiToJSONSchema_ObjectValidation(t *testing.T) {
gs := &genai.Schema{
Type: genai.TypeObject,
Properties: map[string]*genai.Schema{
"color": {
Type: genai.TypeString,
Enum: []string{"red", "green", "blue"},
},
},
Required: []string{"color"},
}
resolved, err := GenaiToResolvedJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToResolvedJSONSchema error: %v", err)
}

valid := map[string]any{"color": "red"}
if err := resolved.Validate(valid); err != nil {
t.Errorf("Validate(valid) error = %v, want nil", err)
}

invalid := map[string]any{"color": "purple"}
if err := resolved.Validate(invalid); err == nil {
t.Error("Validate(invalid) error = nil, want error for invalid enum")
}

missing := map[string]any{}
if err := resolved.Validate(missing); err == nil {
t.Error("Validate(missing) error = nil, want error for missing required")
}
}

func TestGenaiToJSONSchema_Array(t *testing.T) {
gs := &genai.Schema{
Type: genai.TypeArray,
Items: &genai.Schema{
Type: genai.TypeString,
Enum: []string{"a", "b", "c"},
},
}
resolved, err := GenaiToResolvedJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToResolvedJSONSchema error: %v", err)
}

valid := []any{"a", "b"}
if err := resolved.Validate(valid); err != nil {
t.Errorf("Validate(valid) error = %v, want nil", err)
}

invalid := []any{"a", "d"}
if err := resolved.Validate(invalid); err == nil {
t.Error("Validate(invalid) error = nil, want error for invalid enum in array item")
}
}

func TestGenaiToJSONSchema_NumericConstraints(t *testing.T) {
min := 0.0
max := 100.0
gs := &genai.Schema{
Type: genai.TypeNumber,
Minimum: &min,
Maximum: &max,
}
resolved, err := GenaiToResolvedJSONSchema(gs)
if err != nil {
t.Fatalf("GenaiToResolvedJSONSchema error: %v", err)
}

if err := resolved.Validate(50.0); err != nil {
t.Errorf("Validate(50) error = %v, want nil", err)
}

if err := resolved.Validate(-1.0); err == nil {
t.Error("Validate(-1) error = nil, want error")
}

if err := resolved.Validate(101.0); err == nil {
t.Error("Validate(101) error = nil, want error")
}
}

func TestGenaiToResolvedJSONSchema_Nil(t *testing.T) {
resolved, err := GenaiToResolvedJSONSchema(nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resolved != nil {
t.Error("expected nil resolved schema for nil input")
}
}
Loading