Skip to content

Commit 34868fb

Browse files
committed
Add RAG test using QueryRouter and ContentInjector
1 parent 134f27f commit 34868fb

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package org.acme.example;
2+
3+
import java.util.Collection;
4+
import java.util.Collections;
5+
import java.util.List;
6+
import java.util.function.Supplier;
7+
8+
import jakarta.inject.Singleton;
9+
10+
import dev.langchain4j.data.message.ChatMessage;
11+
import dev.langchain4j.data.message.UserMessage;
12+
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
13+
import dev.langchain4j.rag.RetrievalAugmentor;
14+
import dev.langchain4j.rag.content.Content;
15+
import dev.langchain4j.rag.content.injector.ContentInjector;
16+
import dev.langchain4j.rag.content.retriever.ContentRetriever;
17+
import dev.langchain4j.rag.query.Query;
18+
import dev.langchain4j.rag.query.router.QueryRouter;
19+
import io.quarkiverse.langchain4j.RegisterAiService;
20+
21+
@RegisterAiService(retrievalAugmentor = AiServiceWithQueryRouterAndContentInjector.QueryRouterAugmentor.class)
22+
public interface AiServiceWithQueryRouterAndContentInjector {
23+
24+
String chat(String message);
25+
26+
/**
27+
* Contains a query transformer that transforms the query text to
28+
* lowercase. If the transformed worked properly, the content retriever
29+
* will append content saying "The transformer works!".
30+
*/
31+
@Singleton
32+
class QueryRouterAugmentor implements Supplier<RetrievalAugmentor> {
33+
34+
@Override
35+
public RetrievalAugmentor get() {
36+
return DefaultRetrievalAugmentor.builder()
37+
.queryRouter(new QueryRouter() {
38+
@Override
39+
public Collection<ContentRetriever> route(Query query) {
40+
if (query.text().contains("dog")) {
41+
return Collections.singletonList(dogsRetriever());
42+
} else if (query.text().contains("cat")) {
43+
return Collections.singletonList(catsRetriever());
44+
} else {
45+
return Collections.emptyList();
46+
}
47+
}
48+
})
49+
.contentInjector(new ContentInjector() {
50+
@Override
51+
public ChatMessage inject(List<Content> contents, ChatMessage chatMessage) {
52+
String rewrittenMessage = ((UserMessage) chatMessage).singleText() + " - "
53+
+ contents.get(0).textSegment().text();
54+
return UserMessage.userMessage(rewrittenMessage);
55+
}
56+
})
57+
.build();
58+
}
59+
60+
static ContentRetriever dogsRetriever() {
61+
return new ContentRetriever() {
62+
@Override
63+
public List<Content> retrieve(Query query) {
64+
return Collections.singletonList(Content.from("Dogs bark"));
65+
}
66+
};
67+
}
68+
69+
static ContentRetriever catsRetriever() {
70+
return new ContentRetriever() {
71+
@Override
72+
public List<Content> retrieve(Query query) {
73+
return Collections.singletonList(Content.from("Cats meow"));
74+
}
75+
};
76+
}
77+
}
78+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package org.acme.example.openai;
2+
3+
import static io.quarkiverse.langchain4j.runtime.LangChain4jUtil.chatMessageToText;
4+
5+
import java.util.List;
6+
import java.util.concurrent.atomic.AtomicReference;
7+
8+
import jakarta.inject.Inject;
9+
10+
import org.acme.example.AiServiceWithQueryRouterAndContentInjector;
11+
import org.junit.jupiter.api.Assertions;
12+
import org.junit.jupiter.api.BeforeAll;
13+
import org.junit.jupiter.api.Test;
14+
import org.mockito.Mockito;
15+
import org.mockito.stubbing.Answer;
16+
17+
import dev.langchain4j.data.message.AiMessage;
18+
import dev.langchain4j.data.message.ChatMessage;
19+
import dev.langchain4j.model.chat.ChatModel;
20+
import dev.langchain4j.model.chat.request.ChatRequest;
21+
import dev.langchain4j.model.chat.response.ChatResponse;
22+
import io.quarkus.test.junit.QuarkusMock;
23+
import io.quarkus.test.junit.QuarkusTest;
24+
25+
@QuarkusTest
26+
public class RAGWithQueryRouterAndContentInjectorTest {
27+
28+
@Inject
29+
AiServiceWithQueryRouterAndContentInjector service;
30+
31+
private static AtomicReference<List<ChatMessage>> lastQuery = new AtomicReference<>();
32+
33+
@BeforeAll
34+
public static void initializeModel() {
35+
ChatModel mock = Mockito.mock(ChatModel.class);
36+
Answer<ChatResponse> answer = invocation -> {
37+
lastQuery.set(((ChatRequest) invocation.getArgument(0)).messages());
38+
return ChatResponse.builder().aiMessage(new AiMessage("Mock response")).build();
39+
};
40+
Mockito.when(mock.chat(Mockito.any(ChatRequest.class))).thenAnswer(answer);
41+
QuarkusMock.installMockForType(mock, ChatModel.class);
42+
}
43+
44+
@Test
45+
public void test() {
46+
service.chat("What dogs do?");
47+
String query = chatMessageToText(lastQuery.get().get(0));
48+
Assertions.assertTrue(query.equals("What dogs do? - Dogs bark"), query);
49+
}
50+
51+
}

0 commit comments

Comments
 (0)