diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 4bfd0b25..52f6f384 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -20,6 +20,8 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; @@ -35,6 +37,7 @@ public class ParallelAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class); + private final Scheduler scheduler; /** * Constructor for ParallelAgent. @@ -44,24 +47,34 @@ public class ParallelAgent extends BaseAgent { * @param subAgents The list of sub-agents to run in parallel. * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. + * @param scheduler The scheduler to use for parallel execution. */ private ParallelAgent( String name, String description, List subAgents, List beforeAgentCallback, - List afterAgentCallback) { + List afterAgentCallback, + Scheduler scheduler) { super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.scheduler = scheduler; } /** Builder for {@link ParallelAgent}. */ public static class Builder extends BaseAgent.Builder { + private Scheduler scheduler = Schedulers.io(); + + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + @Override public ParallelAgent build() { return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler); } } @@ -131,7 +144,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { List> agentFlowables = new ArrayList<>(); for (BaseAgent subAgent : currentSubAgents) { - agentFlowables.add(subAgent.runAsync(invocationContext)); + agentFlowables.add(subAgent.runAsync(invocationContext).subscribeOn(scheduler)); } return Flowable.merge(agentFlowables); } diff --git a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java index a6afb579..e51240c4 100644 --- a/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/ParallelAgentTest.java @@ -25,7 +25,10 @@ import com.google.genai.types.Content; import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; import io.reactivex.rxjava3.schedulers.Schedulers; +import io.reactivex.rxjava3.schedulers.TestScheduler; +import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,10 +39,16 @@ public final class ParallelAgentTest { static class TestingAgent extends BaseAgent { private final long delayMillis; + private final Scheduler scheduler; private TestingAgent(String name, String description, long delayMillis) { + this(name, description, delayMillis, Schedulers.computation()); + } + + private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) { super(name, description, ImmutableList.of(), null, null); this.delayMillis = delayMillis; + this.scheduler = scheduler; } @Override @@ -55,7 +64,7 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { .build()); if (delayMillis > 0) { - return event.delay(delayMillis, MILLISECONDS, Schedulers.computation()); + return event.delay(delayMillis, MILLISECONDS, scheduler); } return event; } @@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() { assertThat(events).isEmpty(); } + + static class BlockingAgent extends BaseAgent { + private final long sleepMillis; + + private BlockingAgent(String name, long sleepMillis) { + super(name, "Blocking Agent", ImmutableList.of(), null, null); + this.sleepMillis = sleepMillis; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return Flowable.fromCallable( + () -> { + Thread.sleep(sleepMillis); + return Event.builder() + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Done"))) + .build(); + }); + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + @Test + public void runAsync_blockingSubAgents_shouldExecuteInParallel() { + long sleepTime = 1000; + BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime); + BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime); + + ParallelAgent parallelAgent = + ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + long startTime = System.currentTimeMillis(); + List events = parallelAgent.runAsync(invocationContext).toList().blockingGet(); + long duration = System.currentTimeMillis() - startTime; + + assertThat(events).hasSize(2); + // If parallel, duration should be less than 1.5 * sleepTime (1500ms). + assertThat(duration).isAtLeast(sleepTime); + assertThat(duration).isLessThan((long) (1.5 * sleepTime)); + } + + @Test + public void runAsync_withTestScheduler_usesVirtualTime() { + TestScheduler testScheduler = new TestScheduler(); + long delayMillis = 1000; + TestingAgent agent = + new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler); + + ParallelAgent parallelAgent = + ParallelAgent.builder() + .name("parallel_agent") + .subAgents(agent) + .scheduler(testScheduler) + .build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + TestSubscriber testSubscriber = parallelAgent.runAsync(invocationContext).test(); + + testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS); + testSubscriber.assertNoValues(); + testSubscriber.assertNotComplete(); + testScheduler.advanceTimeBy(200, MILLISECONDS); + testSubscriber.assertValueCount(1); + testSubscriber.assertComplete(); + } }