Skip to content

Commit 36ff663

Browse files
committed
Add RAG test using QueryRouter and ContentInjector
1 parent 134f27f commit 36ff663

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package org.acme.example;
2+
3+
import dev.langchain4j.data.message.ChatMessage;
4+
import dev.langchain4j.data.message.UserMessage;
5+
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
6+
import dev.langchain4j.rag.RetrievalAugmentor;
7+
import dev.langchain4j.rag.content.Content;
8+
import dev.langchain4j.rag.content.injector.ContentInjector;
9+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
10+
import dev.langchain4j.rag.query.Query;
11+
import dev.langchain4j.rag.query.router.QueryRouter;
12+
import io.quarkiverse.langchain4j.RegisterAiService;
13+
import jakarta.inject.Singleton;
14+
15+
import java.util.Collection;
16+
import java.util.Collections;
17+
import java.util.List;
18+
import java.util.function.Supplier;
19+
20+
@RegisterAiService(retrievalAugmentor = AiServiceWithQueryRouterAndContentInjector.QueryRouterAugmentor.class)
21+
public interface AiServiceWithQueryRouterAndContentInjector {
22+
23+
String chat(String message);
24+
25+
/**
26+
* Contains a query transformer that transforms the query text to
27+
* lowercase. If the transformed worked properly, the content retriever
28+
* will append content saying "The transformer works!".
29+
*/
30+
@Singleton
31+
class QueryRouterAugmentor implements Supplier<RetrievalAugmentor> {
32+
33+
@Override
34+
public RetrievalAugmentor get() {
35+
return DefaultRetrievalAugmentor.builder()
36+
.queryRouter(new QueryRouter() {
37+
@Override
38+
public Collection<ContentRetriever> route(Query query) {
39+
if (query.text().contains("dog")) {
40+
return Collections.singletonList(dogsRetriever());
41+
} else if (query.text().contains("cat")) {
42+
return Collections.singletonList(catsRetriever());
43+
} else {
44+
return Collections.emptyList();
45+
}
46+
}
47+
})
48+
.contentInjector(new ContentInjector() {
49+
@Override
50+
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
51+
String rewrittenMessage = ((UserMessage) chatMessage).singleText() + " - " + contents.get(0).textSegment().text();
52+
return UserMessage.userMessage(rewrittenMessage);
53+
}
54+
})
55+
.build();
56+
}
57+
58+
static ContentRetriever dogsRetriever() {
59+
return new ContentRetriever() {
60+
@Override
61+
public List<Content> retrieve(Query query) {
62+
return Collections.singletonList(Content.from("Dogs bark"));
63+
}
64+
};
65+
}
66+
67+
static ContentRetriever catsRetriever() {
68+
return new ContentRetriever() {
69+
@Override
70+
public List<Content> retrieve(Query query) {
71+
return Collections.singletonList(Content.from("Cats meow"));
72+
}
73+
};
74+
}
75+
}
76+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package org.acme.example.openai;
2+
3+
import dev.langchain4j.data.message.AiMessage;
4+
import dev.langchain4j.data.message.ChatMessage;
5+
import dev.langchain4j.model.chat.ChatModel;
6+
import dev.langchain4j.model.chat.request.ChatRequest;
7+
import dev.langchain4j.model.chat.response.ChatResponse;
8+
import io.quarkus.test.junit.QuarkusMock;
9+
import io.quarkus.test.junit.QuarkusTest;
10+
import jakarta.inject.Inject;
11+
import org.acme.example.AiServiceWithQueryRouterAndContentInjector;
12+
import org.junit.jupiter.api.Assertions;
13+
import org.junit.jupiter.api.BeforeAll;
14+
import org.junit.jupiter.api.Test;
15+
import org.mockito.Mockito;
16+
import org.mockito.stubbing.Answer;
17+
18+
import java.util.List;
19+
import java.util.concurrent.atomic.AtomicReference;
20+
21+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
22+
23+
@QuarkusTest
24+
public class RAGWithQueryRouterAndContentInjectorTest {
25+
26+
@Inject
27+
AiServiceWithQueryRouterAndContentInjector service;
28+
29+
private static AtomicReference<List<ChatMessage>> lastQuery = new AtomicReference<>();
30+
31+
@BeforeAll
32+
public static void initializeModel() {
33+
ChatModel mock = Mockito.mock(ChatModel.class);
34+
Answer<ChatResponse> answer = invocation -> {
35+
lastQuery.set(((ChatRequest) invocation.getArgument(0)).messages());
36+
return ChatResponse.builder().aiMessage(new AiMessage("Mock response")).build();
37+
};
38+
Mockito.when(mock.chat(Mockito.any(ChatRequest.class))).thenAnswer(answer);
39+
QuarkusMock.installMockForType(mock, ChatModel.class);
40+
}
41+
42+
@Test
43+
public void test() {
44+
service.chat("What dogs do?");
45+
String query = chatMessageToText(lastQuery.get().get(0));
46+
Assertions.assertTrue(query.equals("What dogs do? - Dogs bark"), query);
47+
}
48+
49+
}

0 commit comments

Comments
 (0)