Skip to content

Commit c76557a

Browse files
committed
feat: add missing tool resolution strategy
1 parent 493c064 commit c76557a

File tree

5 files changed

+128
-10
lines changed

5 files changed

+128
-10
lines changed

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ public InvocationContext(
8686
"this(sessionService, artifactService, memoryService, new"
8787
+ " PluginManager(), liveRequestQueue, branch, invocationId, agent,"
8888
+ " session, userContent, runConfig, endInvocation)",
89-
imports = "com.google.adk.plugins.PluginManager")
89+
imports = {
90+
"com.google.adk.plugins.PluginManager",
91+
})
9092
@Deprecated
9193
public InvocationContext(
9294
BaseSessionService sessionService,

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import com.google.adk.tools.MissingToolResolutionStrategy;
1920
import com.google.auto.value.AutoValue;
2021
import com.google.common.collect.ImmutableList;
2122
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -70,6 +71,8 @@ public enum ToolExecutionMode {
7071

7172
public abstract int maxLlmCalls();
7273

74+
public abstract MissingToolResolutionStrategy missingToolResolutionStrategy();
75+
7376
public abstract Builder toBuilder();
7477

7578
public static Builder builder() {
@@ -78,6 +81,7 @@ public static Builder builder() {
7881
.setResponseModalities(ImmutableList.of())
7982
.setStreamingMode(StreamingMode.NONE)
8083
.setToolExecutionMode(ToolExecutionMode.NONE)
84+
.setMissingToolResolutionStrategy(MissingToolResolutionStrategy.THROW_EXCEPTION)
8185
.setMaxLlmCalls(500);
8286
}
8387

@@ -90,7 +94,8 @@ public static Builder builder(RunConfig runConfig) {
9094
.setResponseModalities(runConfig.responseModalities())
9195
.setSpeechConfig(runConfig.speechConfig())
9296
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
93-
.setInputAudioTranscription(runConfig.inputAudioTranscription());
97+
.setInputAudioTranscription(runConfig.inputAudioTranscription())
98+
.setMissingToolResolutionStrategy(runConfig.missingToolResolutionStrategy());
9499
}
95100

96101
/** Builder for {@link RunConfig}. */
@@ -123,6 +128,10 @@ public abstract Builder setInputAudioTranscription(
123128
@CanIgnoreReturnValue
124129
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
125130

131+
@CanIgnoreReturnValue
132+
public abstract Builder setMissingToolResolutionStrategy(
133+
MissingToolResolutionStrategy missingToolResolutionStrategy);
134+
126135
abstract RunConfig autoBuild();
127136

128137
public RunConfig build() {

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
import com.google.adk.events.EventActions;
3131
import com.google.adk.tools.BaseTool;
3232
import com.google.adk.tools.FunctionTool;
33+
import com.google.adk.tools.MissingToolResolutionStrategy;
3334
import com.google.adk.tools.ToolConfirmation;
3435
import com.google.adk.tools.ToolContext;
35-
import com.google.common.base.VerifyException;
3636
import com.google.common.collect.ImmutableList;
3737
import com.google.common.collect.ImmutableMap;
3838
import com.google.genai.types.Content;
@@ -127,7 +127,7 @@ public static void populateClientFunctionCallId(Event modelResponseEvent) {
127127
/** Handles standard, non-streaming function calls. */
128128
public static Maybe<Event> handleFunctionCalls(
129129
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
130-
return handleFunctionCalls(invocationContext, functionCallEvent, tools, ImmutableMap.of());
130+
return handleFunctionCalls(invocationContext, functionCallEvent, tools);
131131
}
132132

133133
/** Handles standard, non-streaming function calls with tool confirmations. */
@@ -137,10 +137,13 @@ public static Maybe<Event> handleFunctionCalls(
137137
Map<String, BaseTool> tools,
138138
Map<String, ToolConfirmation> toolConfirmations) {
139139
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
140-
140+
MissingToolResolutionStrategy missingToolResolutionStrategy =
141+
invocationContext.runConfig().missingToolResolutionStrategy();
142+
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
141143
for (FunctionCall functionCall : functionCalls) {
142144
if (!tools.containsKey(functionCall.name().get())) {
143-
throw new VerifyException("Tool not found: " + functionCall.name().get());
145+
missingTools.add(
146+
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
144147
}
145148
}
146149

@@ -207,7 +210,11 @@ public static Maybe<Event> handleFunctionCalls(
207210
functionResponseEventsFlowable =
208211
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
209212
}
210-
return functionResponseEventsFlowable
213+
Flowable<Event> missingToolsFlowable =
214+
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
215+
Flowable<Event> allEventsFlowable =
216+
Flowable.concat(missingToolsFlowable, functionResponseEventsFlowable);
217+
return allEventsFlowable
211218
.toList()
212219
.flatMapMaybe(
213220
events -> {
@@ -242,10 +249,14 @@ public static Maybe<Event> handleFunctionCalls(
242249
public static Maybe<Event> handleFunctionCallsLive(
243250
InvocationContext invocationContext, Event functionCallEvent, Map<String, BaseTool> tools) {
244251
ImmutableList<FunctionCall> functionCalls = functionCallEvent.functionCalls();
252+
MissingToolResolutionStrategy missingToolResolutionStrategy =
253+
invocationContext.runConfig().missingToolResolutionStrategy();
245254

255+
ImmutableList.Builder<Maybe<Event>> missingTools = ImmutableList.builder();
246256
for (FunctionCall functionCall : functionCalls) {
247257
if (!tools.containsKey(functionCall.name().get())) {
248-
throw new VerifyException("Tool not found: " + functionCall.name().get());
258+
missingTools.add(
259+
missingToolResolutionStrategy.onMissingTool(invocationContext, functionCall));
249260
}
250261
}
251262

@@ -311,7 +322,6 @@ public static Maybe<Event> handleFunctionCallsLive(
311322
};
312323

313324
Flowable<Event> responseEventsFlowable;
314-
315325
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
316326
responseEventsFlowable =
317327
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
@@ -320,8 +330,12 @@ public static Maybe<Event> handleFunctionCallsLive(
320330
responseEventsFlowable =
321331
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
322332
}
333+
Flowable<Event> missingToolsFlowable =
334+
Flowable.fromIterable(missingTools.build()).concatMapMaybe(maybe -> maybe);
335+
Flowable<Event> allEventsFlowable =
336+
Flowable.concat(missingToolsFlowable, responseEventsFlowable);
323337

324-
return responseEventsFlowable
338+
return allEventsFlowable
325339
.toList()
326340
.flatMapMaybe(
327341
events -> {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.google.adk.tools;
2+
3+
import com.google.adk.agents.InvocationContext;
4+
import com.google.adk.events.Event;
5+
import com.google.common.base.VerifyException;
6+
import com.google.genai.types.FunctionCall;
7+
import io.reactivex.rxjava3.core.Maybe;
8+
import java.util.function.BiFunction;
9+
10+
public interface MissingToolResolutionStrategy {
11+
public static final MissingToolResolutionStrategy THROW_EXCEPTION =
12+
new MissingToolResolutionStrategy() {
13+
@Override
14+
public Maybe<Event> onMissingTool(
15+
InvocationContext invocationContext, FunctionCall functionCall) {
16+
throw new VerifyException(
17+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson()));
18+
}
19+
};
20+
21+
public static final MissingToolResolutionStrategy RETURN_ERROR =
22+
new MissingToolResolutionStrategy() {
23+
@Override
24+
public Maybe<Event> onMissingTool(
25+
InvocationContext invocationContext, FunctionCall functionCall) {
26+
return Maybe.error(
27+
new VerifyException(
28+
"Tool not found: " + functionCall.name().orElse(functionCall.toJson())));
29+
}
30+
};
31+
32+
public static final MissingToolResolutionStrategy IGNORE =
33+
new MissingToolResolutionStrategy() {
34+
@Override
35+
public Maybe<Event> onMissingTool(
36+
InvocationContext invocationContext, FunctionCall functionCall) {
37+
return Maybe.empty();
38+
}
39+
};
40+
41+
public static MissingToolResolutionStrategy respondWithEvent(
42+
BiFunction<InvocationContext, FunctionCall, Maybe<Event>> eventFactory) {
43+
return new MissingToolResolutionStrategy() {
44+
@Override
45+
public Maybe<Event> onMissingTool(
46+
InvocationContext invocationContext, FunctionCall functionCall) {
47+
return eventFactory.apply(invocationContext, functionCall);
48+
}
49+
};
50+
}
51+
52+
public static MissingToolResolutionStrategy respondWithEventSync(
53+
BiFunction<InvocationContext, FunctionCall, Event> eventFactory) {
54+
return respondWithEvent(
55+
(invocationContext, functionCall) ->
56+
Maybe.just(eventFactory.apply(invocationContext, functionCall)));
57+
}
58+
59+
Maybe<Event> onMissingTool(InvocationContext invocationContext, FunctionCall functionCall);
60+
}

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import static org.junit.Assert.assertThrows;
2424

2525
import com.google.adk.agents.InvocationContext;
26+
import com.google.adk.agents.RunConfig;
2627
import com.google.adk.events.Event;
2728
import com.google.adk.testing.TestUtils;
29+
import com.google.adk.tools.MissingToolResolutionStrategy;
2830
import com.google.common.collect.ImmutableList;
2931
import com.google.common.collect.ImmutableMap;
3032
import com.google.genai.types.Content;
@@ -67,6 +69,37 @@ public void handleFunctionCalls_missingTool() {
6769
invocationContext, event, /* tools= */ ImmutableMap.of()));
6870
}
6971

72+
@Test
73+
public void handleFunctionCalls_missingTool_recoveryStrategy() {
74+
InvocationContext invocationContext =
75+
createInvocationContext(
76+
createRootAgent(),
77+
RunConfig.builder()
78+
.setMissingToolResolutionStrategy(
79+
MissingToolResolutionStrategy.respondWithEventSync(
80+
(ctx, call) ->
81+
Event.builder()
82+
.content(
83+
Content.fromParts(
84+
Part.fromText("tool missing: " + call.name().get())))
85+
.build()))
86+
.build());
87+
Event event =
88+
createEvent("event").toBuilder()
89+
.content(
90+
Content.fromParts(
91+
Part.fromText("..."), Part.fromFunctionCall("missing_tool", ImmutableMap.of())))
92+
.build();
93+
94+
Event functionResponseEvent =
95+
Functions.handleFunctionCalls(invocationContext, event, /* tools= */ ImmutableMap.of())
96+
.blockingGet();
97+
98+
assertThat(functionResponseEvent).isNotNull();
99+
assertThat(functionResponseEvent.content().get().parts().get())
100+
.containsExactly(Part.fromText("tool missing: missing_tool"));
101+
}
102+
70103
@Test
71104
public void handleFunctionCalls_singleFunctionCall() {
72105
InvocationContext invocationContext = createInvocationContext(createRootAgent());

0 commit comments

Comments
 (0)