Skip to content

Commit 54b6e92

Browse files
author
lucas
committed
Enhance AgentTool to support custom services and plugins
1 parent e6755a2 commit 54b6e92

File tree

2 files changed

+137
-6
lines changed

2 files changed

+137
-6
lines changed

core/src/main/java/com/google/adk/tools/AgentTool.java

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,23 @@
2121
import com.google.adk.SchemaUtils;
2222
import com.google.adk.agents.BaseAgent;
2323
import com.google.adk.agents.LlmAgent;
24+
import com.google.adk.artifacts.BaseArtifactService;
25+
import com.google.adk.artifacts.InMemoryArtifactService;
2426
import com.google.adk.events.Event;
25-
import com.google.adk.runner.InMemoryRunner;
27+
import com.google.adk.memory.BaseMemoryService;
28+
import com.google.adk.memory.InMemoryMemoryService;
29+
import com.google.adk.plugins.BasePlugin;
2630
import com.google.adk.runner.Runner;
31+
import com.google.adk.sessions.BaseSessionService;
32+
import com.google.adk.sessions.InMemorySessionService;
2733
import com.google.common.collect.ImmutableList;
2834
import com.google.common.collect.ImmutableMap;
2935
import com.google.genai.types.Content;
3036
import com.google.genai.types.FunctionDeclaration;
3137
import com.google.genai.types.Part;
3238
import com.google.genai.types.Schema;
3339
import io.reactivex.rxjava3.core.Single;
40+
import java.util.List;
3441
import java.util.Map;
3542
import java.util.Optional;
3643

@@ -39,19 +46,42 @@ public class AgentTool extends BaseTool {
3946

4047
private final BaseAgent agent;
4148
private final boolean skipSummarization;
49+
private final List<BasePlugin> plugins;
50+
private final BaseSessionService sessionService;
51+
private final BaseArtifactService artifactService;
52+
private final BaseMemoryService memoryService;
53+
54+
public static AgentTool create(
55+
BaseAgent agent,
56+
BaseSessionService sessionService,
57+
BaseArtifactService artifactService,
58+
BaseMemoryService memoryService,
59+
List<BasePlugin> plugins) {
60+
return new AgentTool(agent, false, sessionService, artifactService, memoryService, plugins);
61+
}
4262

4363
public static AgentTool create(BaseAgent agent, boolean skipSummarization) {
44-
return new AgentTool(agent, skipSummarization);
64+
return new AgentTool(agent, skipSummarization, null, null, null, ImmutableList.of());
4565
}
4666

4767
public static AgentTool create(BaseAgent agent) {
48-
return new AgentTool(agent, false);
68+
return new AgentTool(agent, false, null, null, null, ImmutableList.of());
4969
}
5070

51-
protected AgentTool(BaseAgent agent, boolean skipSummarization) {
71+
protected AgentTool(
72+
BaseAgent agent,
73+
boolean skipSummarization,
74+
BaseSessionService sessionService,
75+
BaseArtifactService artifactService,
76+
BaseMemoryService memoryService,
77+
List<BasePlugin> plugins) {
5278
super(agent.name(), agent.description());
5379
this.agent = agent;
5480
this.skipSummarization = skipSummarization;
81+
this.sessionService = sessionService;
82+
this.artifactService = artifactService;
83+
this.memoryService = memoryService;
84+
this.plugins = plugins != null ? plugins : ImmutableList.of();
5585
}
5686

5787
@Override
@@ -104,12 +134,34 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
104134
content = Content.fromParts(Part.fromText(input.toString()));
105135
}
106136

107-
Runner runner = new InMemoryRunner(this.agent, toolContext.agentName());
137+
// Determine effective services: use injected singletons if present, otherwise create fresh
138+
// instances per run (default behavior)
139+
BaseSessionService effectiveSessionService =
140+
this.sessionService != null ? this.sessionService : new InMemorySessionService();
141+
BaseArtifactService effectiveArtifactService =
142+
this.artifactService != null ? this.artifactService : new InMemoryArtifactService();
143+
BaseMemoryService effectiveMemoryService =
144+
this.memoryService != null ? this.memoryService : new InMemoryMemoryService();
145+
146+
Runner runner =
147+
new Runner(
148+
this.agent,
149+
toolContext.agentName(),
150+
effectiveArtifactService,
151+
effectiveSessionService,
152+
effectiveMemoryService,
153+
this.plugins);
154+
155+
String userId = "tmp-user";
156+
if (toolContext.userId() != null) {
157+
userId = toolContext.userId();
158+
}
159+
108160
// Session state is final, can't update to toolContext state
109161
// session.toBuilder().setState(toolContext.getState());
110162
return runner
111163
.sessionService()
112-
.createSession(toolContext.agentName(), "tmp-user", toolContext.state(), null)
164+
.createSession(toolContext.agentName(), userId, toolContext.state(), null)
113165
.flatMapPublisher(session -> runner.runAsync(session.userId(), session.id(), content))
114166
.lastElement()
115167
.map(Optional::of)

core/src/test/java/com/google/adk/tools/AgentToolTest.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323

2424
import com.google.adk.agents.InvocationContext;
2525
import com.google.adk.agents.LlmAgent;
26+
import com.google.adk.artifacts.InMemoryArtifactService;
27+
import com.google.adk.memory.InMemoryMemoryService;
2628
import com.google.adk.models.LlmResponse;
29+
import com.google.adk.plugins.BasePlugin;
30+
import com.google.adk.sessions.InMemorySessionService;
31+
import com.google.adk.sessions.ListSessionsResponse;
2732
import com.google.adk.sessions.Session;
2833
import com.google.adk.testing.TestLlm;
2934
import com.google.common.collect.ImmutableList;
@@ -33,8 +38,10 @@
3338
import com.google.genai.types.Part;
3439
import com.google.genai.types.Schema;
3540
import io.reactivex.rxjava3.core.Flowable;
41+
import io.reactivex.rxjava3.core.Maybe;
3642
import java.util.Map;
3743
import java.util.Optional;
44+
import java.util.concurrent.atomic.AtomicBoolean;
3845
import org.junit.Test;
3946
import org.junit.runner.RunWith;
4047
import org.junit.runners.JUnit4;
@@ -344,6 +351,78 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception {
344351
.containsExactly(Content.fromParts(Part.fromText("magic")));
345352
}
346353

354+
@Test
355+
public void create_withServicesAndPlugins_initializesCorrectly() {
356+
LlmAgent testAgent =
357+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
358+
.name("agent name")
359+
.description("agent description")
360+
.build();
361+
362+
AgentTool agentTool =
363+
AgentTool.create(
364+
testAgent,
365+
new InMemorySessionService(),
366+
new InMemoryArtifactService(),
367+
new InMemoryMemoryService(),
368+
ImmutableList.of());
369+
370+
assertThat(agentTool).isNotNull();
371+
assertThat(agentTool.declaration()).isPresent();
372+
}
373+
374+
@Test
375+
public void runAsync_withServicesAndPlugins_usesThem() {
376+
LlmAgent testAgent =
377+
createTestAgentBuilder(
378+
createTestLlm(
379+
LlmResponse.builder()
380+
.content(Content.fromParts(Part.fromText("Sub-agent executed")))
381+
.build()))
382+
.name("sub-agent")
383+
.description("sub-agent description")
384+
.build();
385+
386+
InMemorySessionService sessionService = new InMemorySessionService();
387+
TestPlugin testPlugin = new TestPlugin();
388+
389+
AgentTool agentTool =
390+
AgentTool.create(
391+
testAgent,
392+
sessionService,
393+
new InMemoryArtifactService(),
394+
new InMemoryMemoryService(),
395+
ImmutableList.of(testPlugin));
396+
397+
ToolContext toolContext = createToolContext(testAgent);
398+
399+
Map<String, Object> result =
400+
agentTool.runAsync(ImmutableMap.of("request", "start"), toolContext).blockingGet();
401+
402+
assertThat(result).containsEntry("result", "Sub-agent executed");
403+
404+
assertThat(testPlugin.wasCalled.get()).isTrue();
405+
406+
ListSessionsResponse sessionsResponse =
407+
sessionService.listSessions("sub-agent", "tmp-user").blockingGet();
408+
assertThat(sessionsResponse.sessions()).isNotEmpty();
409+
}
410+
411+
private static class TestPlugin extends BasePlugin {
412+
final AtomicBoolean wasCalled = new AtomicBoolean(false);
413+
414+
TestPlugin() {
415+
super("test-plugin");
416+
}
417+
418+
@Override
419+
public Maybe<Content> onUserMessageCallback(
420+
InvocationContext invocationContext, Content userMessage) {
421+
wasCalled.set(true);
422+
return Maybe.empty();
423+
}
424+
}
425+
347426
private static ToolContext createToolContext(LlmAgent agent) {
348427
return ToolContext.builder(
349428
new InvocationContext(

0 commit comments

Comments
 (0)