|
26 | 26 | import com.google.adk.artifacts.InMemoryArtifactService; |
27 | 27 | import com.google.adk.events.Event; |
28 | 28 | import com.google.adk.models.LlmRequest; |
| 29 | +import com.google.adk.models.Model; |
29 | 30 | import com.google.adk.sessions.InMemorySessionService; |
30 | 31 | import com.google.adk.sessions.Session; |
31 | 32 | import com.google.common.collect.ImmutableList; |
|
42 | 43 | import org.junit.Test; |
43 | 44 | import org.junit.runner.RunWith; |
44 | 45 | import org.junit.runners.JUnit4; |
| 46 | +import org.mockito.Mockito; |
45 | 47 |
|
46 | 48 | /** Unit tests for {@link Contents}. */ |
47 | 49 | @RunWith(JUnit4.class) |
@@ -464,6 +466,30 @@ public void processRequest_sequentialFCFR_returnsOriginalList() { |
464 | 466 | assertThat(result).isEqualTo(eventsToContents(inputEvents)); |
465 | 467 | } |
466 | 468 |
|
| 469 | + @Test |
| 470 | + public void rearrangeHistory_gemini3interleavedFCFR_groupsFcAndFr() { |
| 471 | + Event u1 = createUserEvent("u1", "Query"); |
| 472 | + Event fc1 = createFunctionCallEvent("fc1", "tool1", "call1"); |
| 473 | + Event fr1 = createFunctionResponseEvent("fr1", "tool1", "call1"); |
| 474 | + Event fc2 = createFunctionCallEvent("fc2", "tool2", "call2"); |
| 475 | + Event fr2 = createFunctionResponseEvent("fr2", "tool2", "call2"); |
| 476 | + |
| 477 | + ImmutableList<Event> inputEvents = ImmutableList.of(u1, fc1, fr1, fc2, fr2); |
| 478 | + |
| 479 | + List<Content> result = runContentsProcessorWithModelName(inputEvents, "gemini-3-flash-exp"); |
| 480 | + |
| 481 | + assertThat(result).hasSize(4); |
| 482 | + assertThat(result.get(0)).isEqualTo(u1.content().get()); |
| 483 | + assertThat(result.get(1)).isEqualTo(fc1.content().get()); |
| 484 | + assertThat(result.get(2)).isEqualTo(fc2.content().get()); |
| 485 | + Content mergedContent = result.get(3); |
| 486 | + assertThat(mergedContent.parts().get()).hasSize(2); |
| 487 | + assertThat(mergedContent.parts().get().get(0).functionResponse().get().name()) |
| 488 | + .hasValue("tool1"); |
| 489 | + assertThat(mergedContent.parts().get().get(1).functionResponse().get().name()) |
| 490 | + .hasValue("tool2"); |
| 491 | + } |
| 492 | + |
467 | 493 | private static Event createUserEvent(String id, String text) { |
468 | 494 | return Event.builder() |
469 | 495 | .id(id) |
@@ -628,6 +654,38 @@ private List<Content> runContentsProcessorWithIncludeContents( |
628 | 654 | return result.updatedRequest().contents(); |
629 | 655 | } |
630 | 656 |
|
| 657 | + private List<Content> runContentsProcessorWithModelName(List<Event> events, String modelName) { |
| 658 | + LlmAgent agent = |
| 659 | + Mockito.spy( |
| 660 | + LlmAgent.builder() |
| 661 | + .name(AGENT) |
| 662 | + .includeContents(LlmAgent.IncludeContents.DEFAULT) |
| 663 | + .build()); |
| 664 | + Model model = Model.builder().modelName(modelName).build(); |
| 665 | + Mockito.doReturn(model).when(agent).resolvedModel(); |
| 666 | + |
| 667 | + Session session = |
| 668 | + Session.builder("test-session") |
| 669 | + .appName("test-app") |
| 670 | + .userId("test-user") |
| 671 | + .events(new ArrayList<>(events)) |
| 672 | + .build(); |
| 673 | + InvocationContext context = |
| 674 | + InvocationContext.create( |
| 675 | + new InMemorySessionService(), |
| 676 | + new InMemoryArtifactService(), |
| 677 | + "test-invocation", |
| 678 | + agent, |
| 679 | + session, |
| 680 | + /* userContent= */ null, |
| 681 | + RunConfig.builder().build()); |
| 682 | + |
| 683 | + LlmRequest initialRequest = LlmRequest.builder().build(); |
| 684 | + RequestProcessor.RequestProcessingResult result = |
| 685 | + contentsProcessor.processRequest(context, initialRequest).blockingGet(); |
| 686 | + return result.updatedRequest().contents(); |
| 687 | + } |
| 688 | + |
631 | 689 | private static ImmutableList<Content> eventsToContents(List<Event> events) { |
632 | 690 | return events.stream() |
633 | 691 | .map(Event::content) |
|
0 commit comments