diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index c01a9bc4..9491353f 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -17,13 +17,18 @@ package com.google.adk.agents; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; import java.util.Map; import java.util.Objects; import java.util.Optional; @@ -40,17 +45,37 @@ public class InvocationContext { private final PluginManager pluginManager; private final Optional liveRequestQueue; private final Map activeStreamingTools = new ConcurrentHashMap<>(); - - private Optional branch; private final String invocationId; - private BaseAgent agent; private final Session session; - private final Optional userContent; private final RunConfig runConfig; - private boolean endInvocation; + private final ResumabilityConfig resumabilityConfig; private final InvocationCostManager invocationCostManager = new InvocationCostManager(); + private Optional branch; + private BaseAgent agent; + private boolean endInvocation; + + private InvocationContext(Builder builder) { + this.sessionService = builder.sessionService; + this.artifactService = builder.artifactService; + this.memoryService = builder.memoryService; + this.pluginManager = builder.pluginManager; + this.liveRequestQueue = builder.liveRequestQueue; + this.branch = builder.branch; + this.invocationId = builder.invocationId; + this.agent = builder.agent; + this.session = builder.session; + this.userContent = builder.userContent; + this.runConfig = builder.runConfig; + this.endInvocation = builder.endInvocation; + this.resumabilityConfig = builder.resumabilityConfig; + } + + /** + * @deprecated Use {@link #builder()} instead. + */ + @Deprecated(forRemoval = true) public InvocationContext( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -64,30 +89,26 @@ public InvocationContext( Optional userContent, RunConfig runConfig, boolean endInvocation) { - this.sessionService = sessionService; - this.artifactService = artifactService; - this.memoryService = memoryService; - this.pluginManager = pluginManager; - this.liveRequestQueue = liveRequestQueue; - this.branch = branch; - this.invocationId = invocationId; - this.agent = agent; - this.session = session; - this.userContent = userContent; - this.runConfig = runConfig; - this.endInvocation = endInvocation; + this( + builder() + .sessionService(sessionService) + .artifactService(artifactService) + .memoryService(memoryService) + .pluginManager(pluginManager) + .liveRequestQueue(liveRequestQueue) + .branch(branch) + .invocationId(invocationId) + .agent(agent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(endInvocation)); } /** - * @deprecated Use the {@link #InvocationContext} constructor with PluginManager directly instead + * @deprecated Use {@link #builder()} instead. */ - @InlineMe( - replacement = - "this(sessionService, artifactService, memoryService, new" - + " PluginManager(), liveRequestQueue, branch, invocationId, agent," - + " session, userContent, runConfig, endInvocation)", - imports = "com.google.adk.plugins.PluginManager") - @Deprecated + @Deprecated(forRemoval = true) public InvocationContext( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -101,34 +122,36 @@ public InvocationContext( RunConfig runConfig, boolean endInvocation) { this( - sessionService, - artifactService, - memoryService, - new PluginManager(), - liveRequestQueue, - branch, - invocationId, - agent, - session, - userContent, - runConfig, - endInvocation); + builder() + .sessionService(sessionService) + .artifactService(artifactService) + .memoryService(memoryService) + .liveRequestQueue(liveRequestQueue) + .branch(branch) + .invocationId(invocationId) + .agent(agent) + .session(session) + .userContent(userContent) + .runConfig(runConfig) + .endInvocation(endInvocation)); } /** - * @deprecated Use the {@link #InvocationContext} constructor directly instead + * @deprecated Use {@link #builder()} instead. */ @InlineMe( replacement = - "new InvocationContext(sessionService, artifactService, null, new PluginManager()," - + " Optional.empty(), Optional.empty(), invocationId, agent, session," - + " Optional.ofNullable(userContent), runConfig, false)", - imports = { - "com.google.adk.agents.InvocationContext", - "com.google.adk.plugins.PluginManager", - "java.util.Optional" - }) - @Deprecated + "InvocationContext.builder()" + + ".sessionService(sessionService)" + + ".artifactService(artifactService)" + + ".invocationId(invocationId)" + + ".agent(agent)" + + ".session(session)" + + ".userContent(Optional.ofNullable(userContent))" + + ".runConfig(runConfig)" + + ".build()", + imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"}) + @Deprecated(forRemoval = true) public static InvocationContext create( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -137,36 +160,21 @@ public static InvocationContext create( Session session, Content userContent, RunConfig runConfig) { - return new InvocationContext( - sessionService, - artifactService, - /* memoryService= */ null, - new PluginManager(), - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - invocationId, - agent, - session, - Optional.ofNullable(userContent), - runConfig, - false); + return builder() + .sessionService(sessionService) + .artifactService(artifactService) + .invocationId(invocationId) + .agent(agent) + .session(session) + .userContent(Optional.ofNullable(userContent)) + .runConfig(runConfig) + .build(); } /** - * @deprecated Use the {@link #InvocationContext} constructor directly instead + * @deprecated Use {@link #builder()} instead. */ - @InlineMe( - replacement = - "new InvocationContext(sessionService, artifactService, null, new PluginManager()," - + " Optional.ofNullable(liveRequestQueue), Optional.empty()," - + " InvocationContext.newInvocationContextId(), agent, session, Optional.empty()," - + " runConfig, false)", - imports = { - "com.google.adk.agents.InvocationContext", - "com.google.adk.plugins.PluginManager", - "java.util.Optional" - }) - @Deprecated + @Deprecated(forRemoval = true) public static InvocationContext create( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -174,124 +182,183 @@ public static InvocationContext create( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return new InvocationContext( - sessionService, - artifactService, - /* memoryService= */ null, - new PluginManager(), - Optional.ofNullable(liveRequestQueue), - /* branch= */ Optional.empty(), - InvocationContext.newInvocationContextId(), - agent, - session, - Optional.empty(), - runConfig, - false); + return builder() + .sessionService(sessionService) + .artifactService(artifactService) + .agent(agent) + .session(session) + .liveRequestQueue(Optional.ofNullable(liveRequestQueue)) + .runConfig(runConfig) + .build(); + } + + /** Returns a new {@link Builder} for creating {@link InvocationContext} instances. */ + public static Builder builder() { + return new Builder(); } + /** Creates a shallow copy of the given {@link InvocationContext}. */ public static InvocationContext copyOf(InvocationContext other) { InvocationContext newContext = - new InvocationContext( - other.sessionService, - other.artifactService, - other.memoryService, - other.pluginManager, - other.liveRequestQueue, - other.branch, - other.invocationId, - other.agent, - other.session, - other.userContent, - other.runConfig, - other.endInvocation); + builder() + .sessionService(other.sessionService) + .artifactService(other.artifactService) + .memoryService(other.memoryService) + .pluginManager(other.pluginManager) + .liveRequestQueue(other.liveRequestQueue) + .branch(other.branch) + .invocationId(other.invocationId) + .agent(other.agent) + .session(other.session) + .userContent(other.userContent) + .runConfig(other.runConfig) + .endInvocation(other.endInvocation) + .resumabilityConfig(other.resumabilityConfig) + .build(); newContext.activeStreamingTools.putAll(other.activeStreamingTools); return newContext; } + /** Returns the session service for managing session state. */ public BaseSessionService sessionService() { return sessionService; } + /** Returns the artifact service for persisting artifacts. */ public BaseArtifactService artifactService() { return artifactService; } + /** Returns the memory service for accessing agent memory. */ public BaseMemoryService memoryService() { return memoryService; } + /** Returns the plugin manager for accessing tools and plugins. */ public PluginManager pluginManager() { return pluginManager; } + /** Returns a map of tool call IDs to active streaming tools for the current invocation. */ public Map activeStreamingTools() { return activeStreamingTools; } + /** Returns the queue for managing live requests, if available for this invocation. */ public Optional liveRequestQueue() { return liveRequestQueue; } + /** Returns the unique ID for this invocation. */ public String invocationId() { return invocationId; } + /** + * Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation + * history. + */ public void branch(@Nullable String branch) { this.branch = Optional.ofNullable(branch); } + /** + * Returns the branch ID for the current invocation, if one is set. A branch represents a fork in + * the conversation history. + */ public Optional branch() { return branch; } + /** Returns the agent being invoked. */ public BaseAgent agent() { return agent; } + /** Sets the [agent] being invoked. This is useful when delegating to a sub-agent. */ public void agent(BaseAgent agent) { this.agent = agent; } + /** Returns the session associated with this invocation. */ public Session session() { return session; } + /** Returns the user content that triggered this invocation, if any. */ public Optional userContent() { return userContent; } + /** Returns the configuration for the current agent run. */ public RunConfig runConfig() { return runConfig; } + /** + * Returns whether this invocation should be ended, e.g., due to reaching a terminal state or + * error. + */ public boolean endInvocation() { return endInvocation; } + /** Sets whether this invocation should be ended. */ public void setEndInvocation(boolean endInvocation) { this.endInvocation = endInvocation; } + /** Returns the application name associated with the session. */ public String appName() { return session.appName(); } + /** Returns the user ID associated with the session. */ public String userId() { return session.userId(); } + /** Generates a new unique ID for an invocation context. */ public static String newInvocationContextId() { return "e-" + UUID.randomUUID(); } + /** + * Increments the count of LLM calls made during this invocation and throws an exception if the + * limit defined in {@link RunConfig} is exceeded. + * + * @throws LlmCallsLimitExceededException if the call limit is exceeded + */ public void incrementLlmCallsCount() throws LlmCallsLimitExceededException { this.invocationCostManager.incrementAndEnforceLlmCallsLimit(this.runConfig); } + /** Returns whether the current invocation is resumable. */ + public boolean isResumable() { + return resumabilityConfig.isResumable(); + } + + /** Returns whether to pause the invocation right after this [event]. */ + public boolean shouldPauseInvocation(Event event) { + if (!isResumable()) { + return false; + } + + var longRunningToolIds = event.longRunningToolIds().orElse(ImmutableSet.of()); + if (longRunningToolIds.isEmpty()) { + return false; + } + + return event.functionCalls().stream() + .map(FunctionCall::id) + .flatMap(Optional::stream) + .anyMatch(functionCallId -> longRunningToolIds.contains(functionCallId)); + } + private static class InvocationCostManager { private int numberOfLlmCalls = 0; - public void incrementAndEnforceLlmCallsLimit(RunConfig runConfig) + void incrementAndEnforceLlmCallsLimit(RunConfig runConfig) throws LlmCallsLimitExceededException { this.numberOfLlmCalls++; @@ -304,6 +371,231 @@ public void incrementAndEnforceLlmCallsLimit(RunConfig runConfig) } } + /** Builder for {@link InvocationContext}. */ + public static class Builder { + private BaseSessionService sessionService; + private BaseArtifactService artifactService; + private BaseMemoryService memoryService; + private PluginManager pluginManager = new PluginManager(); + private Optional liveRequestQueue = Optional.empty(); + private Optional branch = Optional.empty(); + private String invocationId = newInvocationContextId(); + private BaseAgent agent; + private Session session; + private Optional userContent = Optional.empty(); + private RunConfig runConfig = RunConfig.builder().build(); + private boolean endInvocation = false; + private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); + + /** + * Sets the session service for managing session state. + * + * @param sessionService the session service to use; required. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder sessionService(BaseSessionService sessionService) { + this.sessionService = sessionService; + return this; + } + + /** + * Sets the artifact service for persisting artifacts. + * + * @param artifactService the artifact service to use; required. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder artifactService(BaseArtifactService artifactService) { + this.artifactService = artifactService; + return this; + } + + /** + * Sets the memory service for accessing agent memory. + * + * @param memoryService the memory service to use. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder memoryService(BaseMemoryService memoryService) { + this.memoryService = memoryService; + return this; + } + + /** + * Sets the plugin manager for accessing tools and plugins. + * + * @param pluginManager the plugin manager to use. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder pluginManager(PluginManager pluginManager) { + this.pluginManager = pluginManager; + return this; + } + + /** + * Sets the queue for managing live requests. + * + * @param liveRequestQueue the queue for managing live requests. + * @return this builder instance for chaining. + * @deprecated Use {@link #liveRequestQueue(LiveRequestQueue)} instead. + */ + // TODO: b/462140921 - Builders should not accept Optional parameters. + @Deprecated(forRemoval = true) + @CanIgnoreReturnValue + public Builder liveRequestQueue(Optional liveRequestQueue) { + this.liveRequestQueue = liveRequestQueue; + return this; + } + + /** + * Sets the queue for managing live requests. + * + * @param liveRequestQueue the queue for managing live requests. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder liveRequestQueue(LiveRequestQueue liveRequestQueue) { + this.liveRequestQueue = Optional.of(liveRequestQueue); + return this; + } + + /** + * Sets the branch ID for the invocation. + * + * @param branch the branch ID for the invocation. + * @return this builder instance for chaining. + * @deprecated Use {@link #branch(String)} instead. + */ + // TODO: b/462140921 - Builders should not accept Optional parameters. + @Deprecated(forRemoval = true) + @CanIgnoreReturnValue + public Builder branch(Optional branch) { + this.branch = branch; + return this; + } + + /** + * Sets the branch ID for the invocation. + * + * @param branch the branch ID for the invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder branch(String branch) { + this.branch = Optional.of(branch); + return this; + } + + /** + * Sets the unique ID for the invocation. + * + * @param invocationId the unique ID for the invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder invocationId(String invocationId) { + this.invocationId = invocationId; + return this; + } + + /** + * Sets the agent being invoked. + * + * @param agent the agent being invoked; required. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder agent(BaseAgent agent) { + this.agent = agent; + return this; + } + + /** + * Sets the session associated with this invocation. + * + * @param session the session associated with this invocation; required. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder session(Session session) { + this.session = session; + return this; + } + + /** + * Sets the user content that triggered this invocation. + * + * @param userContent the user content that triggered this invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder userContent(Optional userContent) { + this.userContent = userContent; + return this; + } + + /** + * Sets the user content that triggered this invocation. + * + * @param userContent the user content that triggered this invocation. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder userContent(Content userContent) { + this.userContent = Optional.of(userContent); + return this; + } + + /** + * Sets the configuration for the current agent run. + * + * @param runConfig the configuration for the current agent run. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder runConfig(RunConfig runConfig) { + this.runConfig = runConfig; + return this; + } + + /** + * Sets whether this invocation should be ended. + * + * @param endInvocation whether this invocation should be ended. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder endInvocation(boolean endInvocation) { + this.endInvocation = endInvocation; + return this; + } + + /** + * Sets the resumability configuration for the current agent run. + * + * @param resumabilityConfig the resumability configuration. + * @return this builder instance for chaining. + */ + @CanIgnoreReturnValue + public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { + this.resumabilityConfig = resumabilityConfig; + return this; + } + + /** + * Builds the {@link InvocationContext} instance. + * + * @throws IllegalStateException if any required parameters are missing. + */ + // TODO: b/462183912 - Add validation for required parameters. + public InvocationContext build() { + return new InvocationContext(this); + } + } + @Override public boolean equals(Object o) { if (this == o) { @@ -324,7 +616,8 @@ public boolean equals(Object o) { && Objects.equals(agent, that.agent) && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) - && Objects.equals(runConfig, that.runConfig); + && Objects.equals(runConfig, that.runConfig) + && Objects.equals(resumabilityConfig, that.resumabilityConfig); } @Override @@ -342,6 +635,7 @@ public int hashCode() { session, userContent, runConfig, - endInvocation); + endInvocation, + resumabilityConfig); } } diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index b0840ecb..12597adf 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -604,7 +604,16 @@ private void maybeSaveOutputToState(Event event) { @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { - return llmFlow.run(invocationContext).doOnNext(this::maybeSaveOutputToState); + return llmFlow + .run(invocationContext) + .concatMap( + event -> { + this.maybeSaveOutputToState(event); + if (invocationContext.shouldPauseInvocation(event)) { + return Flowable.just(event).concatWith(Flowable.empty()); + } + return Flowable.just(event); + }); } @Override diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 5c7dad32..9c921e7f 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -156,36 +156,9 @@ protected Flowable postprocess( } return currentLlmResponse.flatMapPublisher( - updatedResponse -> { - Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); - - if (updatedResponse.content().isEmpty() - && updatedResponse.errorCode().isEmpty() - && !updatedResponse.interrupted().orElse(false) - && !updatedResponse.turnComplete().orElse(false)) { - return processorEvents; - } - - Event modelResponseEvent = - buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); - - Flowable modelEventStream = Flowable.just(modelResponseEvent); - - if (modelResponseEvent.functionCalls().isEmpty()) { - return processorEvents.concatWith(modelEventStream); - } - - Maybe maybeFunctionCallEvent; - if (context.runConfig().streamingMode() == StreamingMode.BIDI) { - maybeFunctionCallEvent = - Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()); - } else { - maybeFunctionCallEvent = - Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - } - - return processorEvents.concatWith(modelEventStream).concatWith(maybeFunctionCallEvent); - }); + updatedResponse -> + buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest)); } /** @@ -623,6 +596,45 @@ public void onError(Throwable e) { * * @return A fully constructed {@link Event} representing the LLM response. */ + private Flowable buildPostprocessingEvents( + LlmResponse updatedResponse, + List> eventIterables, + InvocationContext context, + Event baseEventForLlmResponse, + LlmRequest llmRequest) { + Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); + if (updatedResponse.content().isEmpty() + && updatedResponse.errorCode().isEmpty() + && !updatedResponse.interrupted().orElse(false) + && !updatedResponse.turnComplete().orElse(false)) { + return processorEvents; + } + + Event modelResponseEvent = + buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse); + if (modelResponseEvent.functionCalls().isEmpty()) { + return processorEvents.concatWith(Flowable.just(modelResponseEvent)); + } + + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + + Flowable functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + + return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); + } + private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = @@ -640,10 +652,13 @@ private Event buildModelResponseEvent( Event event = eventBuilder.build(); + logger.info("event: {} functionCalls: {}", event, event.functionCalls()); + if (!event.functionCalls().isEmpty()) { Functions.populateClientFunctionCallId(event); Set longRunningToolIds = Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools()); + logger.info("longRunningToolIds: {}", longRunningToolIds); if (!longRunningToolIds.isEmpty()) { event.setLongRunningToolIds(Optional.of(longRunningToolIds)); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/ResumabilityConfig.java b/core/src/main/java/com/google/adk/flows/llmflows/ResumabilityConfig.java new file mode 100644 index 00000000..c5d839ca --- /dev/null +++ b/core/src/main/java/com/google/adk/flows/llmflows/ResumabilityConfig.java @@ -0,0 +1,29 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.flows.llmflows; + +/** + * An app contains Resumability configuration for the agents. + * + * @param isResumable Whether the app is resumable. + */ +public record ResumabilityConfig(boolean isResumable) { + + /** Creates a new {@code ResumabilityConfig} with resumability disabled. */ + public ResumabilityConfig() { + this(false); + } +} diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 08a2bb50..168d1d3f 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -24,18 +24,21 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.memory.BaseMemoryService; import com.google.adk.plugins.BasePlugin; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; import com.google.adk.utils.CollectionUtils; import com.google.common.collect.ImmutableList; -import com.google.errorprone.annotations.InlineMe; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.genai.types.AudioTranscriptionConfig; import com.google.genai.types.Content; import com.google.genai.types.Modality; @@ -62,20 +65,118 @@ public class Runner { private final String appName; private final BaseArtifactService artifactService; private final BaseSessionService sessionService; - private final @Nullable BaseMemoryService memoryService; + @Nullable private final BaseMemoryService memoryService; private final PluginManager pluginManager; + private final ResumabilityConfig resumabilityConfig; + + /** Builder for {@link Runner}. */ + public static class Builder { + private BaseAgent agent; + private String appName; + private BaseArtifactService artifactService = new InMemoryArtifactService(); + private BaseSessionService sessionService = new InMemorySessionService(); + @Nullable private BaseMemoryService memoryService = null; + private List plugins = ImmutableList.of(); + private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); + + @CanIgnoreReturnValue + public Builder agent(BaseAgent agent) { + this.agent = agent; + return this; + } + + @CanIgnoreReturnValue + public Builder appName(String appName) { + this.appName = appName; + return this; + } + + @CanIgnoreReturnValue + public Builder artifactService(BaseArtifactService artifactService) { + this.artifactService = artifactService; + return this; + } + + @CanIgnoreReturnValue + public Builder sessionService(BaseSessionService sessionService) { + this.sessionService = sessionService; + return this; + } + + @CanIgnoreReturnValue + public Builder memoryService(BaseMemoryService memoryService) { + this.memoryService = memoryService; + return this; + } + + @CanIgnoreReturnValue + public Builder plugins(List plugins) { + this.plugins = plugins; + return this; + } + + @CanIgnoreReturnValue + public Builder resumabilityConfig(ResumabilityConfig resumabilityConfig) { + this.resumabilityConfig = resumabilityConfig; + return this; + } + + public Runner build() { + if (agent == null) { + throw new IllegalStateException("Agent must be provided."); + } + if (appName == null) { + throw new IllegalStateException("App name must be provided."); + } + if (artifactService == null) { + throw new IllegalStateException("Artifact service must be provided."); + } + if (sessionService == null) { + throw new IllegalStateException("Session service must be provided."); + } + return new Runner( + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + resumabilityConfig); + } + } + + public static Builder builder() { + return new Builder(); + } - /** Creates a new {@code Runner}. */ + /** + * Creates a new {@code Runner}. + * + * @deprecated Use {@link Runner.Builder} instead. + */ + @Deprecated public Runner( BaseAgent agent, String appName, BaseArtifactService artifactService, BaseSessionService sessionService, @Nullable BaseMemoryService memoryService) { - this(agent, appName, artifactService, sessionService, memoryService, ImmutableList.of()); + this( + agent, + appName, + artifactService, + sessionService, + memoryService, + ImmutableList.of(), + new ResumabilityConfig()); } - /** Creates a new {@code Runner} with a list of plugins. */ + /** + * Creates a new {@code Runner} with a list of plugins. + * + * @deprecated Use {@link Runner.Builder} instead. + */ + @Deprecated public Runner( BaseAgent agent, String appName, @@ -83,21 +184,44 @@ public Runner( BaseSessionService sessionService, @Nullable BaseMemoryService memoryService, List plugins) { + this( + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + new ResumabilityConfig()); + } + + /** + * Creates a new {@code Runner} with a list of plugins and resumability config. + * + * @deprecated Use {@link Runner.Builder} instead. + */ + @Deprecated + public Runner( + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService, + @Nullable BaseMemoryService memoryService, + List plugins, + ResumabilityConfig resumabilityConfig) { this.agent = agent; this.appName = appName; this.artifactService = artifactService; this.sessionService = sessionService; this.memoryService = memoryService; this.pluginManager = new PluginManager(plugins); + this.resumabilityConfig = resumabilityConfig; } /** * Creates a new {@code Runner}. * - * @deprecated Use the constructor with {@code BaseMemoryService} instead even if with a null if - * you don't need the memory service. + * @deprecated Use {@link Runner.Builder} instead. */ - @InlineMe(replacement = "this(agent, appName, artifactService, sessionService, null)") @Deprecated public Runner( BaseAgent agent, @@ -123,7 +247,8 @@ public BaseSessionService sessionService() { return this.sessionService; } - public @Nullable BaseMemoryService memoryService() { + @Nullable + public BaseMemoryService memoryService() { return this.memoryService; } @@ -305,7 +430,7 @@ public Flowable runAsync( newInvocationContextWithId( updatedSession, event.content(), - Optional.empty(), + /* liveRequestQueue= */ Optional.empty(), runConfig, invocationId); contextWithUpdatedSession.agent( @@ -430,20 +555,19 @@ private InvocationContext newInvocationContext( Optional liveRequestQueue, RunConfig runConfig) { BaseAgent rootAgent = this.agent; - InvocationContext invocationContext = - new InvocationContext( - this.sessionService, - this.artifactService, - this.memoryService, - this.pluginManager, - liveRequestQueue, - /* branch= */ Optional.empty(), - InvocationContext.newInvocationContextId(), - rootAgent, - session, - newMessage, - runConfig, - /* endInvocation= */ false); + var invocationContextBuilder = + InvocationContext.builder() + .sessionService(this.sessionService) + .artifactService(this.artifactService) + .memoryService(this.memoryService) + .pluginManager(this.pluginManager) + .agent(rootAgent) + .session(session) + .userContent(newMessage) + .runConfig(runConfig) + .resumabilityConfig(this.resumabilityConfig); + liveRequestQueue.ifPresent(invocationContextBuilder::liveRequestQueue); + var invocationContext = invocationContextBuilder.build(); invocationContext.agent(this.findAgentToRun(session, rootAgent)); return invocationContext; } @@ -460,20 +584,20 @@ private InvocationContext newInvocationContextWithId( RunConfig runConfig, String invocationId) { BaseAgent rootAgent = this.agent; - InvocationContext invocationContext = - new InvocationContext( - this.sessionService, - this.artifactService, - this.memoryService, - this.pluginManager, - liveRequestQueue, - /* branch= */ Optional.empty(), - invocationId, - rootAgent, - session, - newMessage, - runConfig, - /* endInvocation= */ false); + var invocationContextBuilder = + InvocationContext.builder() + .sessionService(this.sessionService) + .artifactService(this.artifactService) + .memoryService(this.memoryService) + .pluginManager(this.pluginManager) + .invocationId(invocationId) + .agent(rootAgent) + .session(session) + .userContent(newMessage) + .runConfig(runConfig) + .resumabilityConfig(this.resumabilityConfig); + liveRequestQueue.ifPresent(invocationContextBuilder::liveRequestQueue); + var invocationContext = invocationContextBuilder.build(); invocationContext.agent(this.findAgentToRun(session, rootAgent)); return invocationContext; } diff --git a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java index 328be196..0cd02202 100644 --- a/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/LongRunningFunctionTool.java @@ -17,18 +17,28 @@ package com.google.adk.tools; import java.lang.reflect.Method; +import javax.annotation.Nullable; /** A function tool that returns the result asynchronously. */ public class LongRunningFunctionTool extends FunctionTool { public static LongRunningFunctionTool create(Method func) { - return new LongRunningFunctionTool(func); + return create(func, /* requireConfirmation= */ false); + } + + public static LongRunningFunctionTool create(Method func, boolean requireConfirmation) { + return new LongRunningFunctionTool(func, requireConfirmation); } public static LongRunningFunctionTool create(Class cls, String methodName) { + return create(cls, methodName, /* requireConfirmation= */ false); + } + + public static LongRunningFunctionTool create( + Class cls, String methodName, boolean requireConfirmation) { for (Method method : cls.getMethods()) { if (method.getName().equals(methodName)) { - return create(method); + return create(method, requireConfirmation); } } throw new IllegalArgumentException( @@ -36,21 +46,36 @@ public static LongRunningFunctionTool create(Class cls, String methodName) { } public static LongRunningFunctionTool create(Object instance, String methodName) { + return create(instance, methodName, /* requireConfirmation= */ false); + } + + public static LongRunningFunctionTool create( + Object instance, String methodName, boolean requireConfirmation) { Class cls = instance.getClass(); for (Method method : cls.getMethods()) { if (method.getName().equals(methodName)) { - return new LongRunningFunctionTool(instance, method); + return create(instance, method, requireConfirmation); } } throw new IllegalArgumentException( String.format("Method %s not found in class %s.", methodName, cls.getName())); } - private LongRunningFunctionTool(Method func) { - super(null, func, /* isLongRunning= */ true, /* requireConfirmation= */ false); + public static LongRunningFunctionTool create(@Nullable Object instance, Method method) { + return create(instance, method, false); + } + + public static LongRunningFunctionTool create( + @Nullable Object instance, Method method, boolean requireConfirmation) { + return new LongRunningFunctionTool(instance, method, requireConfirmation); + } + + private LongRunningFunctionTool(Method func, boolean requireConfirmation) { + super(null, func, /* isLongRunning= */ true, requireConfirmation); } - private LongRunningFunctionTool(Object instance, Method func) { - super(instance, func, /* isLongRunning= */ true, /* requireConfirmation= */ false); + private LongRunningFunctionTool( + @Nullable Object instance, Method func, boolean requireConfirmation) { + super(instance, func, /* isLongRunning= */ true, requireConfirmation); } } diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index cca9dbc9..d85f2eb7 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -20,11 +20,17 @@ import static org.mockito.Mockito.mock; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.memory.BaseMemoryService; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -65,19 +71,18 @@ public void setUp() { @Test public void testCreateWithUserContent() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context).isNotNull(); assertThat(context.sessionService()).isEqualTo(mockSessionService); @@ -95,19 +100,18 @@ public void testCreateWithUserContent() { @Test public void testCreateWithNullUserContent() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - /* userContent= */ Optional.empty(), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.empty()) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context).isNotNull(); assertThat(context.userContent()).isEmpty(); @@ -116,19 +120,18 @@ public void testCreateWithNullUserContent() { @Test public void testCreateWithLiveRequestQueue() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - Optional.of(liveRequestQueue), - /* branch= */ Optional.empty(), - InvocationContext.newInvocationContextId(), - mockAgent, - session, - /* userContent= */ Optional.empty(), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .liveRequestQueue(liveRequestQueue) + .agent(mockAgent) + .session(session) + .userContent(Optional.empty()) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context).isNotNull(); assertThat(context.sessionService()).isEqualTo(mockSessionService); @@ -146,19 +149,18 @@ public void testCreateWithLiveRequestQueue() { @Test public void testCopyOf() { InvocationContext originalContext = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); originalContext.activeStreamingTools().putAll(activeStreamingTools); InvocationContext copiedContext = InvocationContext.copyOf(originalContext); @@ -183,19 +185,18 @@ public void testCopyOf() { @Test public void testGetters() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context.sessionService()).isEqualTo(mockSessionService); assertThat(context.artifactService()).isEqualTo(mockArtifactService); @@ -212,19 +213,18 @@ public void testGetters() { @Test public void testSetAgent() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); BaseAgent newMockAgent = mock(BaseAgent.class); context.agent(newMockAgent); @@ -247,19 +247,18 @@ public void testNewInvocationContextId() { @Test public void testEquals_sameObject() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context.equals(context)).isTrue(); } @@ -267,19 +266,18 @@ public void testEquals_sameObject() { @Test public void testEquals_null() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context.equals(null)).isFalse(); } @@ -287,35 +285,33 @@ public void testEquals_null() { @Test public void testEquals_sameValues() { InvocationContext context1 = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); // Create another context with the same parameters InvocationContext context2 = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context1.equals(context2)).isTrue(); assertThat(context2.equals(context1)).isTrue(); // Check symmetry @@ -324,95 +320,89 @@ public void testEquals_sameValues() { @Test public void testEquals_differentValues() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); // Create contexts with one field different InvocationContext contextWithDiffSessionService = - new InvocationContext( - mock(BaseSessionService.class), // Different mock - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mock(BaseSessionService.class)) // Different mock + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); InvocationContext contextWithDiffInvocationId = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - "another-id", // Different ID - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId("another-id") // Different ID + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); InvocationContext contextWithDiffAgent = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mock(BaseAgent.class), // Different mock - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mock(BaseAgent.class)) // Different mock + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); InvocationContext contextWithUserContentEmpty = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - /* userContent= */ Optional.empty(), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.empty()) + .runConfig(runConfig) + .endInvocation(false) + .build(); InvocationContext contextWithLiveQueuePresent = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - Optional.of(liveRequestQueue), - /* branch= */ Optional.empty(), - InvocationContext.newInvocationContextId(), - mockAgent, - session, - /* userContent= */ Optional.empty(), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .liveRequestQueue(liveRequestQueue) + .agent(mockAgent) + .session(session) + .userContent(Optional.empty()) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context.equals(contextWithDiffSessionService)).isFalse(); assertThat(context.equals(contextWithDiffInvocationId)).isFalse(); @@ -424,52 +414,199 @@ public void testEquals_differentValues() { @Test public void testHashCode_differentValues() { InvocationContext context = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); // Create contexts with one field different InvocationContext contextWithDiffSessionService = - new InvocationContext( - mock(BaseSessionService.class), // Different mock - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - testInvocationId, - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mock(BaseSessionService.class)) // Different mock + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId(testInvocationId) + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); InvocationContext contextWithDiffInvocationId = - new InvocationContext( - mockSessionService, - mockArtifactService, - mockMemoryService, - pluginManager, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - "another-id", // Different ID - mockAgent, - session, - Optional.of(userContent), - runConfig, - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .pluginManager(pluginManager) + .invocationId("another-id") // Different ID + .agent(mockAgent) + .session(session) + .userContent(Optional.of(userContent)) + .runConfig(runConfig) + .endInvocation(false) + .build(); assertThat(context).isNotEqualTo(contextWithDiffSessionService); assertThat(context).isNotEqualTo(contextWithDiffInvocationId); } + + @Test + public void isResumable_whenResumabilityConfigIsNotResumable_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(false)) + .build(); + assertThat(context.isResumable()).isFalse(); + } + + @Test + public void isResumable_whenResumabilityConfigIsResumable_isTrue() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + assertThat(context.isResumable()).isTrue(); + } + + @Test + public void shouldPauseInvocation_whenNotResumable_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(false)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoLongRunningToolIds_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoFunctionCalls_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = Event.builder().longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))).build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndNoMatchingFunctionCallId_isFalse() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc2"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isFalse(); + } + + @Test + public void shouldPauseInvocation_whenResumableAndMatchingFunctionCallId_isTrue() { + InvocationContext context = + InvocationContext.builder() + .sessionService(mockSessionService) + .artifactService(mockArtifactService) + .memoryService(mockMemoryService) + .agent(mockAgent) + .session(session) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Event event = + Event.builder() + .longRunningToolIds(Optional.of(ImmutableSet.of("fc1"))) + .content( + Content.builder() + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder().name("tool1").id("fc1").build()) + .build())) + .build()) + .build(); + assertThat(context.shouldPauseInvocation(event)).isTrue(); + } } diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index fc64d736..52218a0e 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -34,6 +34,7 @@ import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; import com.google.adk.sessions.Session; @@ -73,7 +74,8 @@ public final class RunnerTest { private final Content pluginContent = createContent("from plugin"); private final TestLlm testLlm = createTestLlm(createLlmResponse(createContent("from llm"))); private final LlmAgent agent = createTestAgentBuilder(testLlm).build(); - private final Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + private final Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); private final Session session = runner.sessionService().createSession("test", "user").blockingGet(); private Tracer originalTracer; @@ -156,7 +158,12 @@ public void beforeRunCallback_multiplePluginsFirstOnly() { BasePlugin plugin2 = mockPlugin("test2"); when(plugin2.beforeRunCallback(any())).thenReturn(Maybe.empty()); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin1, plugin2)); + Runner runner = + Runner.builder() + .agent(agent) + .appName("test") + .plugins(ImmutableList.of(plugin1, plugin2)) + .build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); var events = runner @@ -268,7 +275,8 @@ public void onModelErrorCallback_success() { TestLlm failingTestLlm = createTestLlm(Flowable.error(exception)); LlmAgent agent = createTestAgentBuilder(failingTestLlm).build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); var events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); @@ -286,7 +294,8 @@ public void onModelErrorCallback_error() { TestLlm failingTestLlm = createTestLlm(Flowable.error(exception)); LlmAgent agent = createTestAgentBuilder(failingTestLlm).build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); runner.runAsync("user", session.id(), createContent("from user")).test().assertError(exception); @@ -304,7 +313,8 @@ public void beforeToolCallback_success() { .tools(ImmutableList.of(failingEchoTool)) .build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); var events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); @@ -327,7 +337,8 @@ public void afterToolCallback_success() { LlmAgent agent = createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); var events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); @@ -352,7 +363,8 @@ public void onToolErrorCallback_success() { .tools(ImmutableList.of(failingEchoTool)) .build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); var events = runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); @@ -374,7 +386,8 @@ public void onToolErrorCallback_error() { .tools(ImmutableList.of(failingEchoTool)) .build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of(plugin)); + Runner runner = + Runner.builder().agent(agent).appName("test").plugins(ImmutableList.of(plugin)).build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); runner .runAsync("user", session.id(), createContent("from user")) @@ -663,7 +676,7 @@ public void runLive_success() throws Exception { public void runLive_withToolExecution() throws Exception { LlmAgent agentWithTool = createTestAgentBuilder(testLlmWithFunctionCall).tools(ImmutableList.of(echoTool)).build(); - Runner runnerWithTool = new InMemoryRunner(agentWithTool, "test", ImmutableList.of()); + Runner runnerWithTool = Runner.builder().agent(agentWithTool).appName("test").build(); Session sessionWithTool = runnerWithTool.sessionService().createSession("test", "user").blockingGet(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); @@ -690,7 +703,7 @@ public void runLive_llmError() throws Exception { Exception exception = new Exception("LLM test error"); TestLlm failingTestLlm = createTestLlm(Flowable.error(exception)); LlmAgent agent = createTestAgentBuilder(failingTestLlm).build(); - Runner runner = new InMemoryRunner(agent, "test", ImmutableList.of()); + Runner runner = Runner.builder().agent(agent).appName("test").build(); Session session = runner.sessionService().createSession("test", "user").blockingGet(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); TestSubscriber testSubscriber = @@ -710,7 +723,7 @@ public void runLive_toolError() throws Exception { .tools(ImmutableList.of(failingEchoTool)) .build(); Runner runnerWithFailingTool = - new InMemoryRunner(agentWithFailingTool, "test", ImmutableList.of()); + Runner.builder().agent(agentWithFailingTool).appName("test").build(); Session sessionWithFailingTool = runnerWithFailingTool.sessionService().createSession("test", "user").blockingGet(); LiveRequestQueue liveRequestQueue = new LiveRequestQueue(); @@ -742,4 +755,40 @@ public void runLive_createsInvocationSpan() { assertThat(invocationSpan).isPresent(); assertThat(invocationSpan.get().hasEnded()).isTrue(); } + + @Test + public void resumabilityConfig_isResumable_isTrueInInvocationContext() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + Runner runner = + Runner.builder() + .agent(agent) + .appName("test") + .plugins(ImmutableList.of(plugin)) + .resumabilityConfig(new ResumabilityConfig(true)) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); + assertThat(contextCaptor.getValue().isResumable()).isTrue(); + } + + @Test + public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { + ArgumentCaptor contextCaptor = + ArgumentCaptor.forClass(InvocationContext.class); + when(plugin.beforeRunCallback(contextCaptor.capture())).thenReturn(Maybe.empty()); + Runner runner = + Runner.builder() + .agent(agent) + .appName("test") + .plugins(ImmutableList.of(plugin)) + .resumabilityConfig(new ResumabilityConfig(false)) + .build(); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + var unused = + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); + assertThat(contextCaptor.getValue().isResumable()).isFalse(); + } } diff --git a/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java b/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java index 4c1823e9..1b19eb7c 100644 --- a/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/LongRunningFunctionToolTest.java @@ -3,14 +3,19 @@ import static com.google.common.truth.Truth.assertThat; import com.google.adk.agents.LlmAgent; +import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; +import com.google.adk.flows.llmflows.ResumabilityConfig; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.runner.InMemoryRunner; +import com.google.adk.runner.Runner; +import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.Keep; import com.google.genai.types.Content; import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; @@ -27,26 +32,31 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) +// TODO: Add test for raw string return style. FunctionTool currently only supports a map. We need +// to add functionality of returning raw string. public final class LongRunningFunctionToolTest { - // TODO: Add test for raw string return style. - // FunctionTool currently only supports a map. We need to add functionality of returning raw - // string. private TestLlm testLlm; private LlmAgent agent; - private InMemoryRunner runner; + private Runner runner; private Session session; + private InMemorySessionService sessionService; + private InMemoryArtifactService artifactService; + private InMemoryMemoryService memoryService; @Before public void setUp() { TestFunctions.reset(); + sessionService = new InMemorySessionService(); + artifactService = new InMemoryArtifactService(); + memoryService = new InMemoryMemoryService(); } @Test public void asyncFunction_handlesPendingAndResults() throws Exception { FunctionTool longRunningTool = LongRunningFunctionTool.create( - TestFunctions.class.getMethod("increaseByOne", int.class, ToolContext.class)); + null, TestFunctions.class.getMethod("increaseByOne", int.class, ToolContext.class)); FunctionCall modelRequestForFunctionCall = FunctionCall.builder().name("increase_by_one").args(ImmutableMap.of("x", 1)).build(); @@ -66,11 +76,11 @@ public void asyncFunction_handlesPendingAndResults() throws Exception { Content firstUserContent = Content.fromParts(Part.fromText("test1")); - ImmutableMap expectedInitialToolResponseMap; - expectedInitialToolResponseMap = ImmutableMap.of("status", "pending"); - assertInitialInteractionAndEvents( - firstUserContent, modelRequestForFunctionCall, expectedInitialToolResponseMap, "response1"); + firstUserContent, + modelRequestForFunctionCall, + ImmutableMap.of("status", "pending"), + "response1"); assertSubsequentInteraction( "increase_by_one", ImmutableMap.of("status", "still waiting"), "response2", 3); @@ -83,12 +93,14 @@ public void asyncFunction_handlesPendingAndResults() throws Exception { } private static class TestFunctions { - static AtomicInteger functionCalledCount = new AtomicInteger(0); + static final AtomicInteger functionCalledCount = new AtomicInteger(0); static void reset() { functionCalledCount.set(0); } + @Keep // Keep this function to avoid unused function warning. + @SuppressWarnings("unused") // Suppress unused warning for test parameters. @Annotations.Schema(name = "increase_by_one", description = "Test func: increases by one") public static Maybe> increaseByOne(int x, ToolContext toolContext) { functionCalledCount.incrementAndGet(); @@ -126,7 +138,15 @@ private void setUpAgentAndRunner( .tools(ImmutableList.of(tool)) .description(description) .build(); - runner = new InMemoryRunner(agent, "test-user"); + runner = + new Runner( + agent, + "test_app", + artifactService, + sessionService, + memoryService, + null, + new ResumabilityConfig(false)); session = runner .sessionService()