diff --git a/agent/workflowagents/parallelagent/agent.go b/agent/workflowagents/parallelagent/agent.go index efd25309e..78156ee67 100644 --- a/agent/workflowagents/parallelagent/agent.go +++ b/agent/workflowagents/parallelagent/agent.go @@ -100,7 +100,12 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { } go func() { - _ = errGroup.Wait() // this error is already sent to the user via iterator + if err := errGroup.Wait(); err != nil { + select { + case resultsChan <- result{err: err}: + case <-doneChan: + } + } close(resultsChan) }() @@ -108,7 +113,14 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { defer close(doneChan) for res := range resultsChan { - if !yield(res.event, res.err) { + shouldContinue := yield(res.event, res.err) + + // Signal sub-agent that event processing (including session append) is complete + if res.ackChan != nil { + close(res.ackChan) + } + + if !shouldContinue { break } } @@ -117,23 +129,28 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error { for event, err := range agent.Run(ctx) { + if err != nil { + return err + } + + ackChan := make(chan struct{}) + select { case <-done: return nil case <-ctx.Done(): - select { - case <-done: - case results <- result{ - err: ctx.Err(), - }: - } return ctx.Err() case results <- result{ - event: event, - err: err, + event: event, + ackChan: ackChan, }: - if err != nil { - return err + // Wait for runner to finish processing before continuing to next iteration + select { + case <-ackChan: + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() } } } @@ -141,6 +158,7 @@ func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- } type result struct { - event *session.Event - err error + event *session.Event + err error + ackChan chan struct{} } diff --git a/agent/workflowagents/parallelagent/agent_test.go b/agent/workflowagents/parallelagent/agent_test.go index e41ead433..90466c385 100644 --- a/agent/workflowagents/parallelagent/agent_test.go +++ b/agent/workflowagents/parallelagent/agent_test.go @@ -19,7 +19,10 @@ import ( "fmt" "iter" rand "math/rand/v2" + "net/http" + "path/filepath" "slices" + "strings" "testing" "time" @@ -27,13 +30,21 @@ import ( "google.golang.org/genai" "google.golang.org/adk/agent" + "google.golang.org/adk/agent/llmagent" "google.golang.org/adk/agent/workflowagents/loopagent" "google.golang.org/adk/agent/workflowagents/parallelagent" + "google.golang.org/adk/internal/httprr" + "google.golang.org/adk/internal/testutil" "google.golang.org/adk/model" + "google.golang.org/adk/model/gemini" "google.golang.org/adk/runner" "google.golang.org/adk/session" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/functiontool" ) +const modelName = "gemini-2.0-flash-exp" + func TestNewParallelAgent(t *testing.T) { tests := []struct { name string @@ -224,3 +235,303 @@ func customRun(id int, agentErr error) func(agent.InvocationContext) iter.Seq2[* } } } + +func TestParallelAgentWithTools(t *testing.T) { + agent1 := createAgentWithGemini(t, "agent1") + agent2 := createAgentWithGemini(t, "agent2") + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "parallel_test", + SubAgents: []agent.Agent{agent1, agent2}, + }, + }) + if err != nil { + t.Fatalf("Failed to create parallel agent: %v", err) + } + + runner := testutil.NewTestAgentRunner(t, parallelAgent) + stream := runner.Run(t, "test_session", "Search for AI news") + + events, err := testutil.CollectEvents(stream) + if err != nil { + t.Fatalf("Agent run failed: %v", err) + } + + if len(events) < 2 { + t.Errorf("Expected at least 2 events from parallel agents, got %d", len(events)) + } + + // Count FunctionCall and FunctionResponse events per branch + branchCalls := make(map[string]int) + branchResponses := make(map[string]int) + + for _, ev := range events { + branch := ev.Branch + if ev.LLMResponse.Content != nil { + for _, part := range ev.LLMResponse.Content.Parts { + if part.FunctionCall != nil { + branchCalls[branch]++ + } + if part.FunctionResponse != nil { + branchResponses[branch]++ + } + } + } + } + + for branch, calls := range branchCalls { + responses := branchResponses[branch] + if calls > responses { + t.Errorf("Branch %s: session has %d FunctionCalls but only %d FunctionResponses. "+ + "This indicates race condition: agent read session before FunctionResponse was appended.", + branch, calls, responses) + } + } +} + +func createAgentWithGemini(t *testing.T, name string) agent.Agent { + t.Helper() + + searchTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("search_tool_%s", name), + Description: "Search for information on the web", + }, + func(ctx tool.Context, args struct{ Query string }) (string, error) { + return fmt.Sprintf("search result for '%s' from %s", args.Query, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create search tool: %v", err) + } + + analyzeTool, err := functiontool.New( + functiontool.Config{ + Name: fmt.Sprintf("analyze_tool_%s", name), + Description: "Analyze data and return insights", + }, + func(ctx tool.Context, args struct{ Data string }) (string, error) { + return fmt.Sprintf("analysis result for '%s' from %s", args.Data, name), nil + }, + ) + if err != nil { + t.Fatalf("Failed to create analyze tool: %v", err) + } + + model := newGeminiModelForTest(t, modelName, name) + + a, err := llmagent.New(llmagent.Config{ + Name: name, + Description: fmt.Sprintf("Test agent %s that searches for information", name), + Model: model, + Tools: []tool.Tool{searchTool, analyzeTool}, + Instruction: "Use the search tool to find information, then provide a brief response.", + }) + if err != nil { + t.Fatalf("Failed to create agent %s: %v", name, err) + } + + return a +} + +func newGeminiModelForTest(t *testing.T, modelName, agentName string) model.LLM { + t.Helper() + + trace := filepath.Join("testdata", fmt.Sprintf("%s_%s.httprr", + strings.ReplaceAll(t.Name(), "/", "_"), agentName)) + + apiKey := "fakeKey" + transport, recording := newGeminiTestTransport(t, trace) + if recording { + apiKey = "" + } + + model, err := gemini.NewModel(t.Context(), modelName, &genai.ClientConfig{ + HTTPClient: &http.Client{Transport: transport}, + APIKey: apiKey, + }) + if err != nil { + t.Fatalf("Failed to create Gemini model: %v", err) + } + return model +} + +func newGeminiTestTransport(t *testing.T, rrfile string) (http.RoundTripper, bool) { + t.Helper() + rr, err := testutil.NewGeminiTransport(rrfile) + if err != nil { + t.Fatal(err) + } + recording, _ := httprr.Recording(rrfile) + return rr, recording +} + +// TestParallelAgent_PropagatesContextError verifies that if the context is canceled, +// the iterator yields the error from errgroup.Wait(). +func TestParallelAgent_PropagatesContextError(t *testing.T) { + t.Parallel() + + // Create a sub-agent that yields an event and then waits. + // We want to trigger runSubAgent returning ctx.Err(). + subAgent := must(agent.New(agent.Config{ + Name: "yielder", + Run: func(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + // Yield one event so we engage runSubAgent logic + if !yield(&session.Event{ + LLMResponse: model.LLMResponse{ + Content: genai.NewContentFromText("hello", genai.RoleModel), + }, + }, nil) { + return + } + + // Wait for context cancellation + <-ctx.Done() + } + }, + })) + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "parallel_agent", + SubAgents: []agent.Agent{subAgent}, + }, + }) + if err != nil { + t.Fatal(err) + } + + spy := &spyAgent{Agent: parallelAgent} + + ctx, cancel := context.WithCancel(t.Context()) + + sessionService := session.InMemoryService() + _, _ = sessionService.Create(ctx, &session.CreateRequest{ + AppName: "test_app", + UserID: "user_id", + SessionID: "session_id", + }) + + r, err := runner.New(runner.Config{ + AppName: "test_app", + Agent: spy, + SessionService: sessionService, + }) + if err != nil { + t.Fatal(err) + } + + go func() { + // Wait a tiny bit to ensure we started + time.Sleep(10 * time.Millisecond) + cancel() + }() + + for range r.Run(ctx, "user_id", "session_id", genai.NewContentFromText("hi", genai.RoleUser), agent.RunConfig{}) { + // Simulate processing delay so that ackChan takes time, + // increasing chance runSubAgent is blocked on ackChan when cancel happens? + time.Sleep(100 * time.Millisecond) + } + + if spy.yieldedError == nil { + t.Fatal("Expected parallelAgent to yield an error (e.g. context canceled), but it yielded nil") + } + + t.Logf("Yielded error: %v", spy.yieldedError) +} + +type spyAgent struct { + agent.Agent + yieldedError error +} + +func (s *spyAgent) Run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] { + next := s.Agent.Run(ctx) + return func(yield func(*session.Event, error) bool) { + for event, err := range next { + if err != nil { + s.yieldedError = err + } + if !yield(event, err) { + return + } + } + } +} + +func TestParallelAgent_StateSync(t *testing.T) { + ctx := t.Context() + + var gotValue any + var gotErr error + + subAgent, err := agent.New(agent.Config{ + Name: "test_subagent", + Run: func(agent.InvocationContext) iter.Seq2[*session.Event, error] { + return func(yield func(*session.Event, error) bool) { + event := &session.Event{ + LLMResponse: model.LLMResponse{ + Content: genai.NewContentFromText("hello", genai.RoleModel), + }, + Actions: session.EventActions{ + StateDelta: map[string]any{"test_key": "test_value"}, + }, + } + yield(event, nil) + } + }, + AfterAgentCallbacks: []agent.AfterAgentCallback{ + func(c agent.CallbackContext) (*genai.Content, error) { + gotValue, gotErr = c.State().Get("test_key") + return nil, nil + }, + }, + }) + if err != nil { + t.Fatal(err) + } + + parallelAgent, err := parallelagent.New(parallelagent.Config{ + AgentConfig: agent.Config{ + Name: "test_parallel_agent", + SubAgents: []agent.Agent{subAgent}, + }, + }) + if err != nil { + t.Fatal(err) + } + + sessionService := session.InMemoryService() + agentRunner, err := runner.New(runner.Config{ + AppName: "test_app", + Agent: parallelAgent, + SessionService: sessionService, + }) + if err != nil { + t.Fatal(err) + } + + _, err = sessionService.Create(ctx, &session.CreateRequest{ + AppName: "test_app", + UserID: "user_id", + SessionID: "session_id", + }) + if err != nil { + t.Fatal(err) + } + + for _, err := range agentRunner.Run(ctx, "user_id", "session_id", genai.NewContentFromText("user input", genai.RoleUser), agent.RunConfig{}) { + if err != nil { + t.Fatal(err) + } + } + + if gotErr != nil { + t.Fatalf("expected to get value from state, got error: %v", gotErr) + } + if gotValue != "test_value" { + t.Fatalf("expected state value 'test_value', got %v", gotValue) + } +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr new file mode 100644 index 000000000..adb1d6a02 --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent1.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=661 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent1", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.7506775394495991e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Ymtkad2GMvz_2roPv6WLmQs" +} +1237 1033 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent1"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent1","response":{"result":"search result for 'AI news' from agent1"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent1","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=656 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.03738286665507725 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "Y2tkaaH-NJKd0-kP9puPwAc" +} diff --git a/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr new file mode 100644 index 000000000..3ee793c9b --- /dev/null +++ b/agent/workflowagents/parallelagent/testdata/TestParallelAgentWithTools_agent2.httprr @@ -0,0 +1,116 @@ +httprr trace v1 +994 1148 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 758 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:51 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=922 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "search_tool_agent2", + "args": { + "Query": "AI news" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -1.1056118334333101e-05 + } + ], + "usageMetadata": { + "promptTokenCount": 41, + "candidatesTokenCount": 9, + "totalTokenCount": 50, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 41 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 9 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "YmtkafzJPLKx2roPsNnkiA4" +} +1237 1034 +POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent HTTP/1.1 +Host: generativelanguage.googleapis.com +User-Agent: Go-http-client/1.1 +Content-Length: 1000 +Content-Type: application/json + +{"contents":[{"parts":[{"text":"Search for AI news"}],"role":"user"},{"parts":[{"functionCall":{"args":{"Query":"AI news"},"name":"search_tool_agent2"}}],"role":"model"},{"parts":[{"functionResponse":{"name":"search_tool_agent2","response":{"result":"search result for 'AI news' from agent2"}}}],"role":"user"}],"generationConfig":{},"systemInstruction":{"parts":[{"text":"Use the search tool to find information, then provide a brief response."}],"role":"user"},"tools":[{"functionDeclarations":[{"description":"Search for information on the web","name":"search_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Query":{"type":"string"}},"required":["Query"],"type":"object"},"responseJsonSchema":{"type":"string"}},{"description":"Analyze data and return insights","name":"analyze_tool_agent2","parametersJsonSchema":{"additionalProperties":false,"properties":{"Data":{"type":"string"}},"required":["Data"],"type":"object"},"responseJsonSchema":{"type":"string"}}]}]}HTTP/2.0 200 OK +Alt-Svc: h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 +Content-Type: application/json; charset=UTF-8 +Date: Mon, 12 Jan 2026 03:32:52 GMT +Server: scaffolding on HTTPServer2 +Server-Timing: gfet4t7; dur=740 +Vary: Origin +Vary: X-Origin +Vary: Referer +X-Content-Type-Options: nosniff +X-Frame-Options: SAMEORIGIN +X-Xss-Protection: 0 + +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "I have searched for AI news." + } + ], + "role": "model" + }, + "finishReason": "STOP", + "avgLogprobs": -0.010005326143332891 + } + ], + "usageMetadata": { + "promptTokenCount": 67, + "candidatesTokenCount": 7, + "totalTokenCount": 74, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 67 + } + ], + "candidatesTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 7 + } + ] + }, + "modelVersion": "gemini-2.0-flash-exp", + "responseId": "ZGtkadaWAtnG2roP5uuq4Qc" +} diff --git a/session/database/session.go b/session/database/session.go index d444ec4af..e925844bd 100644 --- a/session/database/session.go +++ b/session/database/session.go @@ -126,18 +126,17 @@ func (s *state) Get(key string) (any, error) { } func (s *state) All() iter.Seq2[string, any] { - return func(yield func(key string, val any) bool) { - s.mu.RLock() + s.mu.RLock() + // Create a copy of the state to iterate over it without holding the lock. + stateCopy := maps.Clone(s.state) + s.mu.RUnlock() - for k, v := range s.state { - s.mu.RUnlock() + return func(yield func(key string, val any) bool) { + for k, v := range stateCopy { if !yield(k, v) { return } - s.mu.RLock() } - - s.mu.RUnlock() } } diff --git a/session/inmemory.go b/session/inmemory.go index 8f48896fc..de4082f9f 100644 --- a/session/inmemory.go +++ b/session/inmemory.go @@ -223,8 +223,26 @@ func (s *inMemoryService) AppendEvent(ctx context.Context, curSession Session, e return fmt.Errorf("fail to set state on appendEvent: %w", err) } + eventCopy := &Event{ + ID: event.ID, + InvocationID: event.InvocationID, + Timestamp: event.Timestamp, + Author: event.Author, + Branch: event.Branch, + Actions: EventActions{ + StateDelta: maps.Clone(event.Actions.StateDelta), + ArtifactDelta: maps.Clone(event.Actions.ArtifactDelta), + RequestedToolConfirmations: maps.Clone(event.Actions.RequestedToolConfirmations), + TransferToAgent: event.Actions.TransferToAgent, + Escalate: event.Actions.Escalate, + SkipSummarization: event.Actions.SkipSummarization, + }, + LongRunningToolIDs: slices.Clone(event.LongRunningToolIDs), + LLMResponse: event.LLMResponse, + } + // update the in-memory session service - stored_session.events = append(stored_session.events, event) + stored_session.events = append(stored_session.events, eventCopy) stored_session.updatedAt = event.Timestamp if len(event.Actions.StateDelta) > 0 { appDelta, userDelta, sessionDelta := sessionutils.ExtractStateDeltas(event.Actions.StateDelta) @@ -314,6 +332,8 @@ func (s *session) State() State { } func (s *session) Events() Events { + s.mu.RLock() + defer s.mu.RUnlock() return events(s.events) } @@ -329,6 +349,9 @@ func (s *session) appendEvent(event *Event) error { return nil } + s.mu.Lock() + defer s.mu.Unlock() + if err := updateSessionState(s, event); err != nil { return fmt.Errorf("error on appendEvent: %w", err) } @@ -380,18 +403,17 @@ func (s *state) Get(key string) (any, error) { } func (s *state) All() iter.Seq2[string, any] { - return func(yield func(key string, val any) bool) { - s.mu.RLock() + s.mu.RLock() + // Create a copy of the state to iterate over it without holding the lock. + stateCopy := maps.Clone(s.state) + s.mu.RUnlock() - for k, v := range s.state { - s.mu.RUnlock() + return func(yield func(key string, val any) bool) { + for k, v := range stateCopy { if !yield(k, v) { return } - s.mu.RLock() } - - s.mu.RUnlock() } } @@ -434,13 +456,7 @@ func updateSessionState(session *session, event *Event) error { session.state = make(map[string]any) } - state := session.State() - for key, value := range event.Actions.StateDelta { - err := state.Set(key, value) - if err != nil { - return fmt.Errorf("error on updateSessionState state: %w", err) - } - } + maps.Copy(session.state, event.Actions.StateDelta) return nil } diff --git a/session/inmemory_test.go b/session/inmemory_test.go index 2afae9ff6..29733e94c 100644 --- a/session/inmemory_test.go +++ b/session/inmemory_test.go @@ -1058,3 +1058,39 @@ func Test_inMemoryService_CreateConcurrentAccess(t *testing.T) { t.Errorf("expected %d 'already exists' errors, but got %d", expectedErrors, errorCount.Load()) } } + +func TestInMemorySession_AppendEvent_Deadlock(t *testing.T) { + ctx := t.Context() + service := InMemoryService() + + // Create a session + createReq := &CreateRequest{ + AppName: "testapp", + UserID: "testuser", + } + createResp, err := service.Create(ctx, createReq) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + sess := createResp.Session + + // Event with StateDelta to trigger updateSessionState + event := &Event{ + ID: "event1", + Timestamp: time.Now(), + Actions: EventActions{ + StateDelta: map[string]any{ + "test_key": "test_value", + }, + }, + } + + // This call should hang if the deadlock is present + err = service.AppendEvent(ctx, sess, event) + if err != nil { + t.Fatalf("AppendEvent failed: %v", err) + } + + // If it doesn't hang, the test passes (meaning no deadlock) + t.Log("AppendEvent did not deadlock") +}