diff --git a/agent/workflowagents/parallelagent/agent.go b/agent/workflowagents/parallelagent/agent.go index d8afeffd5..fff0afdb5 100644 --- a/agent/workflowagents/parallelagent/agent.go +++ b/agent/workflowagents/parallelagent/agent.go @@ -107,7 +107,14 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { defer close(doneChan) for res := range resultsChan { - if !yield(res.event, res.err) { + shouldContinue := yield(res.event, res.err) + + // Signal sub-agent that event processing (including session append) is complete + if res.ackChan != nil { + close(res.ackChan) + } + + if !shouldContinue { break } } @@ -116,30 +123,37 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error { for event, err := range agent.Run(ctx) { + ackChan := make(chan struct{}) + select { case <-done: return nil case <-ctx.Done(): - select { - case <-done: - case results <- result{ - err: ctx.Err(), - }: - } return ctx.Err() case results <- result{ - event: event, - err: err, + event: event, + err: err, + ackChan: ackChan, }: if err != nil { return err } + + // Wait for runner to finish processing before continuing to next iteration + select { + case <-ackChan: + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } } } return nil } type result struct { - event *session.Event - err error + event *session.Event + err error + ackChan chan struct{} } diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index e41ead433..dc8ecd210 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -19,7 +19,10 @@ import ( "fmt" "iter" rand "math/rand/v2" + "net/http" + "path/filepath" "slices" + "strings" "testing" "time" @@ -27,13 +30,21 @@ import ( "google.golang.org/genai" "google.golang.org/adk/agent" + "google.golang.org/adk/agent/llmagent" "google.golang.org/adk/agent/workflowagents/loopagent" "google.golang.org/adk/agent/workflowagents/parallelagent" + "google.golang.org/adk/internal/httprr" + "google.golang.org/adk/internal/testutil" "google.golang.org/adk/model" + "google.golang.org/adk/model/gemini" "google.golang.org/adk/runner" "google.golang.org/adk/session" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/functiontool" ) +const modelName = "gemini-2.0-flash-exp" + func TestNewParallelAgent(t *testing.T) { tests := []struct { name string @@ -224,3 +235,134 @@ func customRun(id int, agentErr error) func(agent.InvocationContext) iter.Seq2[* } } } + +func TestParallelAgentWithTools(t *testing.T) { + agent1 := createAgentWithGemini(t, "agent1") + agent2 := createAgentWithGemini(t, "agent2") + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "parallel_test", + SubAgents: []agent.Agent{agent1, agent2}, + }, + }) + if err != nil { + t.Fatalf("Failed to create parallel agent: %v", err) + } + + runner := testutil.NewTestAgentRunner(t, parallelAgent) + stream := runner.Run(t, "test_session", "Search for AI news") + + events, err := testutil.CollectEvents(stream) + if err != nil { + t.Fatalf("Agent run failed: %v", err) + } + + if len(events) < 2 { + t.Errorf("Expected at least 2 events from parallel agents, got %d", len(events)) + } + + // Count FunctionCall and FunctionResponse events per branch + branchCalls := make(map[string]int) + branchResponses := make(map[string]int) + + for _, ev := range events { + branch := ev.Branch + if ev.LLMResponse.Content != nil { + for _, part := range ev.LLMResponse.Content.Parts { + if part.FunctionCall != nil { + branchCalls[branch]++ + } + if part.FunctionResponse != nil { + branchResponses[branch]++ + } + } + } + } + + for branch, calls := range branchCalls { + responses := branchResponses[branch] + if calls > responses { + t.Errorf("Branch %s: session has %d FunctionCalls but only %d FunctionResponses. "+ + "This indicates race condition: agent read session before FunctionResponse was appended.", + branch, calls, responses) + } + } +} + +func createAgentWithGemini(t *testing.T, name string) agent.Agent { + t.Helper() + + searchTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("search_tool_%s", name), + Description: "Search for information on the web", + }, + func(ctx tool.Context, args struct{ Query string }) (string, error) { + return fmt.Sprintf("search result for '%s' from %s", args.Query, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create search tool: %v", err) + } + + analyzeTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("analyze_tool_%s", name), + Description: "Analyze data and return insights", + }, + func(ctx tool.Context, args struct{ Data string }) (string, error) { + return fmt.Sprintf("analysis result for '%s' from %s", args.Data, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create analyze tool: %v", err) + } + + model := newGeminiModelForTest(t, modelName, name) + + a, err := llmagent.New(llmagent.Config{ + Name: name, + Description: fmt.Sprintf("Test agent %s that searches for information", name), + Model: model, + Tools: []tool.Tool{searchTool, analyzeTool}, + Instruction: "Use the search tool to find information, then provide a brief response.", + }) + if err != nil { + t.Fatalf("Failed to create agent %s: %v", name, err) + } + + return a +} + +func newGeminiModelForTest(t *testing.T, modelName string, agentName string) model.LLM { + t.Helper() + + trace := filepath.Join("testdata", fmt.Sprintf("%s_%s.httprr", + strings.ReplaceAll(t.Name(), "/", "_"), agentName)) + + apiKey := "fakeKey" + transport, recording := newGeminiTestTransport(t, trace) + if recording { + apiKey = "" + } + + model, err := gemini.NewModel(t.Context(), modelName, &genai.ClientConfig{ + HTTPClient: &http.Client{Transport: transport}, + APIKey: apiKey, + }) + if err != nil { + t.Fatalf("Failed to create Gemini model: %v", err) + } + return model +} + +func newGeminiTestTransport(t *testing.T, rrfile string) (http.RoundTripper, bool) { + t.Helper() + rr, err := testutil.NewGeminiTransport(rrfile) + if err != nil { + t.Fatal(err) + } + recording, _ := httprr.Recording(rrfile) + return rr, recording +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr new file mode 100644 index 000000000..adb1d6a02 --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=661 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent1", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.7506775394495991e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Ymtkad2GMvz_2roPv6WLmQs" +} +1237 1033 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent1"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent1","response":{"result":"search result for 'AI news' from agent1"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=656 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.03738286665507725 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Y2tkaaH-NJKd0-kP9puPwAc" +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr new file mode 100644 index 000000000..3ee793c9b --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=922 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent2", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.1056118334333101e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "YmtkafzJPLKx2roPsNnkiA4" +} +1237 1034 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent2"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent2","response":{"result":"search result for 'AI news' from agent2"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=740 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.010005326143332891 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "ZGtkadaWAtnG2roP5uuq4Qc" +} diff --git a/internal/sessioninternal/mutablesession.go b/internal/sessioninternal/mutablesession.go index 1c68f8ef7..c9d276771 100644 --- a/internal/sessioninternal/mutablesession.go +++ b/internal/sessioninternal/mutablesession.go @@ -15,8 +15,11 @@ package sessioninternal import ( + "context" "fmt" "iter" + "log" + "sync" "time" "google.golang.org/adk/session" @@ -26,6 +29,7 @@ import ( type MutableSession struct { service session.Service storedSession session.Session + mu sync.RWMutex } // NewMutableSession creates and returns session.Session implementation. @@ -41,26 +45,52 @@ func (s *MutableSession) State() session.State { } func (s *MutableSession) AppName() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.AppName() } func (s *MutableSession) UserID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.UserID() } func (s *MutableSession) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.ID() } func (s *MutableSession) Events() session.Events { + s.mu.Lock() + defer s.mu.Unlock() + + ctx := context.Background() + resp, err := s.service.Get(ctx, &session.GetRequest{ + AppName: s.storedSession.AppName(), + UserID: s.storedSession.UserID(), + SessionID: s.storedSession.ID(), + }) + if err != nil { + log.Printf("MutableSession: failed to fetch latest session (app=%s, user=%s, session=%s), using cached version: %v", + s.storedSession.AppName(), s.storedSession.UserID(), s.storedSession.ID(), err) + return s.storedSession.Events() + } + + s.storedSession = resp.Session return s.storedSession.Events() } func (s *MutableSession) LastUpdateTime() time.Time { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.LastUpdateTime() } func (s *MutableSession) Get(key string) (any, error) { + s.mu.RLock() + defer s.mu.RUnlock() value, err := s.storedSession.State().Get(key) if err != nil { return nil, fmt.Errorf("failed to get key %q from state: %w", key, err) @@ -69,10 +99,14 @@ func (s *MutableSession) Get(key string) (any, error) { } func (s *MutableSession) All() iter.Seq2[string, any] { + s.mu.RLock() + defer s.mu.RUnlock() return s.storedSession.State().All() } func (s *MutableSession) Set(key string, value any) error { + s.mu.Lock() + defer s.mu.Unlock() mutableState, ok := s.storedSession.State().(MutableState) if !ok { return fmt.Errorf("this session state is not mutable") diff --git a/internal/telemetry/telemetry.go b/internal/telemetry/telemetry.go index 640ab775c..ab75c04fc 100644 --- a/internal/telemetry/telemetry.go +++ b/internal/telemetry/telemetry.go @@ -47,6 +47,7 @@ type tracerProviderConfig struct { var ( once sync.Once localTracer tracerProviderHolder + localTracerMu sync.RWMutex localTracerConfig = tracerProviderConfig{ spanProcessors: []sdktrace.SpanProcessor{}, mu: &sync.RWMutex{}, @@ -101,7 +102,10 @@ func RegisterTelemetry() { for _, processor := range spanProcessors { traceProvider.RegisterSpanProcessor(processor) } + + localTracerMu.Lock() localTracer = tracerProviderHolder{tp: traceProvider} + localTracerMu.Unlock() }) } @@ -109,11 +113,14 @@ func RegisterTelemetry() { // That means that the spans are NOT recording/exporting // If the local tracer is not set, we'll set up tracer with all registered span processors. func getTracers() []trace.Tracer { - if localTracer.tp == nil { - RegisterTelemetry() - } + RegisterTelemetry() + + localTracerMu.RLock() + tp := localTracer.tp + localTracerMu.RUnlock() + return []trace.Tracer{ - localTracer.tp.Tracer(systemName), + tp.Tracer(systemName), otel.GetTracerProvider().Tracer(systemName), } }