Skip to content

Commit 493c064

Browse files
google-genai-botcopybara-github
authored andcommitted
test: add test for rearranging history for gemini3 interleaved
PiperOrigin-RevId: 837534831
1 parent 8628ebb commit 493c064

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

core/src/main/java/com/google/adk/flows/llmflows/Contents.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ private static List<Event> rearrangeEventsForAsyncFunctionResponsesInHistory(
376376

377377
// Gemini 3 requires function calls to be grouped first and only then function responses:
378378
// FC1 FC2 FR1 FR2
379-
boolean shouldBufferResponseEvents = modelName.startsWith("gemini-3");
379+
boolean shouldBufferResponseEvents = modelName.startsWith("gemini-3-");
380380

381381
for (int i = 0; i < events.size(); i++) {
382382
Event event = events.get(i);

core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.adk.artifacts.InMemoryArtifactService;
2727
import com.google.adk.events.Event;
2828
import com.google.adk.models.LlmRequest;
29+
import com.google.adk.models.Model;
2930
import com.google.adk.sessions.InMemorySessionService;
3031
import com.google.adk.sessions.Session;
3132
import com.google.common.collect.ImmutableList;
@@ -42,6 +43,7 @@
4243
import org.junit.Test;
4344
import org.junit.runner.RunWith;
4445
import org.junit.runners.JUnit4;
46+
import org.mockito.Mockito;
4547

4648
/** Unit tests for {@link Contents}. */
4749
@RunWith(JUnit4.class)
@@ -464,6 +466,30 @@ public void processRequest_sequentialFCFR_returnsOriginalList() {
464466
assertThat(result).isEqualTo(eventsToContents(inputEvents));
465467
}
466468

469+
@Test
470+
public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() {
471+
Event u1 = createUserEvent("u1", "Query");
472+
Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1");
473+
Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1");
474+
Event fc2 = createFunctionCallEvent("fc2", "tool2", "call2");
475+
Event fr2 = createFunctionResponseEvent("fr2", "tool2", "call2");
476+
477+
ImmutableList<Event> inputEvents = ImmutableList.of(u1, fc1, fr1, fc2, fr2);
478+
479+
List<Content> result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp");
480+
481+
assertThat(result).hasSize(4);
482+
assertThat(result.get(0)).isEqualTo(u1.content().get());
483+
assertThat(result.get(1)).isEqualTo(fc1.content().get());
484+
assertThat(result.get(2)).isEqualTo(fc2.content().get());
485+
Content mergedContent = result.get(3);
486+
assertThat(mergedContent.parts().get()).hasSize(2);
487+
assertThat(mergedContent.parts().get().get(0).functionResponse().get().name())
488+
.hasValue("tool1");
489+
assertThat(mergedContent.parts().get().get(1).functionResponse().get().name())
490+
.hasValue("tool2");
491+
}
492+
467493
private static Event createUserEvent(String id, String text) {
468494
return Event.builder()
469495
.id(id)
@@ -628,6 +654,38 @@ private List<Content> runContentsProcessorWithIncludeContents(
628654
return result.updatedRequest().contents();
629655
}
630656

657+
private List<Content> runContentsProcessorWithModelName(List<Event> events, String modelName) {
658+
LlmAgent agent =
659+
Mockito.spy(
660+
LlmAgent.builder()
661+
.name(AGENT)
662+
.includeContents(LlmAgent.IncludeContents.DEFAULT)
663+
.build());
664+
Model model = Model.builder().modelName(modelName).build();
665+
Mockito.doReturn(model).when(agent).resolvedModel();
666+
667+
Session session =
668+
Session.builder("test-session")
669+
.appName("test-app")
670+
.userId("test-user")
671+
.events(new ArrayList<>(events))
672+
.build();
673+
InvocationContext context =
674+
InvocationContext.create(
675+
new InMemorySessionService(),
676+
new InMemoryArtifactService(),
677+
"test-invocation",
678+
agent,
679+
session,
680+
/* userContent= */ null,
681+
RunConfig.builder().build());
682+
683+
LlmRequest initialRequest = LlmRequest.builder().build();
684+
RequestProcessor.RequestProcessingResult result =
685+
contentsProcessor.processRequest(context, initialRequest).blockingGet();
686+
return result.updatedRequest().contents();
687+
}
688+
631689
private static ImmutableList<Content> eventsToContents(List<Event> events) {
632690
return events.stream()
633691
.map(Event::content)

0 commit comments

Comments
 (0)