Skip to content
Merged
46 changes: 32 additions & 14 deletions agent/workflowagents/parallelagent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,27 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] {
}

go func() {
_ = errGroup.Wait() // this error is already sent to the user via iterator
if err := errGroup.Wait(); err != nil {
select {
case resultsChan <- result{err: err}:
case <-doneChan:
}
}
close(resultsChan)
}()

return func(yield func(*session.Event, error) bool) {
defer close(doneChan)

for res := range resultsChan {
if !yield(res.event, res.err) {
shouldContinue := yield(res.event, res.err)

// Signal sub-agent that event processing (including session append) is complete
if res.ackChan != nil {
close(res.ackChan)
}

if !shouldContinue {
break
}
}
Expand All @@ -117,30 +129,36 @@ func run(ctx agent.InvocationContext) iter.Seq2[*session.Event, error] {

func runSubAgent(ctx agent.InvocationContext, agent agent.Agent, results chan<- result, done <-chan bool) error {
for event, err := range agent.Run(ctx) {
if err != nil {
return err
}

ackChan := make(chan struct{})

select {
case <-done:
return nil
case <-ctx.Done():
select {
case <-done:
case results <- result{
err: ctx.Err(),
}:
}
return ctx.Err()
case results <- result{
event: event,
err: err,
event: event,
ackChan: ackChan,
}:
if err != nil {
return err
// Wait for runner to finish processing before continuing to next iteration
select {
case <-ackChan:
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
}
return nil
}

type result struct {
event *session.Event
err error
event *session.Event
err error
ackChan chan struct{}
}
Loading
Loading