Skip to content

Commit f4d5372

Browse files
Merge pull request #17942 from sIvanovKonstantyn/master2
BAEL-8840 - A Guide to Spring AI Advisors
2 parents faa3ca3 + b1286ac commit f4d5372

File tree

6 files changed

+301
-5
lines changed

6 files changed

+301
-5
lines changed

spring-ai/pom.xml

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@
114114

115115
<properties>
116116
<spring-boot.version>3.3.2</spring-boot.version>
117-
<spring-ai.version>1.0.0-M1</spring-ai.version>
117+
<spring-ai.version>1.0.0-M3</spring-ai.version>
118+
<junit-jupiter.version>5.9.0</junit-jupiter.version>
118119
</properties>
119120

120121
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.baeldung.springai.advisors;
2+
3+
import org.slf4j.Logger;
4+
import org.slf4j.LoggerFactory;
5+
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
6+
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
7+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
8+
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
9+
10+
public class CustomLoggingAdvisor implements CallAroundAdvisor {
11+
private final static Logger logger = LoggerFactory.getLogger(CustomLoggingAdvisor.class);
12+
13+
@Override
14+
public AdvisedResponse aroundCall(AdvisedRequest advisedRequest, CallAroundAdvisorChain chain) {
15+
16+
advisedRequest = this.before(advisedRequest);
17+
18+
AdvisedResponse advisedResponse = chain.nextAroundCall(advisedRequest);
19+
20+
this.observeAfter(advisedResponse);
21+
22+
return advisedResponse;
23+
}
24+
25+
private void observeAfter(AdvisedResponse advisedResponse) {
26+
logger.info(advisedResponse.response()
27+
.getResult()
28+
.getOutput()
29+
.getContent());
30+
31+
}
32+
33+
private AdvisedRequest before(AdvisedRequest advisedRequest) {
34+
logger.info(advisedRequest.userText());
35+
return advisedRequest;
36+
}
37+
38+
@Override
39+
public String getName() {
40+
return "CustomLoggingAdvisor";
41+
}
42+
43+
@Override
44+
public int getOrder() {
45+
return Integer.MAX_VALUE;
46+
}
47+
}

spring-ai/src/main/java/com/baeldung/springai/web/ExceptionTranslator.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package com.baeldung.springai.web;
22

3-
import org.springframework.ai.openai.api.common.OpenAiApiException;
3+
import org.springframework.ai.openai.api.common.OpenAiApiClientErrorException;
44
import org.springframework.http.HttpStatus;
55
import org.springframework.http.ProblemDetail;
66
import org.springframework.web.bind.annotation.ExceptionHandler;
@@ -14,8 +14,8 @@ public class ExceptionTranslator extends ResponseEntityExceptionHandler {
1414

1515
public static final String OPEN_AI_CLIENT_RAISED_EXCEPTION = "Open AI client raised exception";
1616

17-
@ExceptionHandler(OpenAiApiException.class)
18-
ProblemDetail handleOpenAiHttpException(OpenAiApiException ex) {
17+
@ExceptionHandler(OpenAiApiClientErrorException.class)
18+
ProblemDetail handleOpenAiHttpException(OpenAiApiClientErrorException ex) {
1919
HttpStatus status = Optional
2020
.ofNullable(HttpStatus.resolve(400))
2121
.orElse(HttpStatus.BAD_REQUEST);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.baeldung.springai.advisors;
2+
3+
import org.springframework.ai.document.Document;
4+
import org.springframework.ai.embedding.EmbeddingModel;
5+
import org.springframework.ai.vectorstore.SearchRequest;
6+
import org.springframework.ai.vectorstore.SimpleVectorStore;
7+
import org.springframework.ai.vectorstore.VectorStore;
8+
import org.springframework.beans.factory.annotation.Qualifier;
9+
import org.springframework.context.annotation.Bean;
10+
import org.springframework.context.annotation.Configuration;
11+
import org.springframework.data.util.Pair;
12+
13+
import java.util.Comparator;
14+
import java.util.List;
15+
16+
@Configuration
17+
public class SimpleVectorStoreConfiguration {
18+
19+
@Bean
20+
public VectorStore vectorStore(@Qualifier("openAiEmbeddingModel")EmbeddingModel embeddingModel) {
21+
return new SimpleVectorStore(embeddingModel) {
22+
@Override
23+
public List<Document> doSimilaritySearch(SearchRequest request) {
24+
float[] userQueryEmbedding = embeddingModel.embed(request.query);
25+
return this.store.values()
26+
.stream()
27+
.map(entry -> Pair.of(entry.getId(),
28+
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding())))
29+
.filter(s -> s.getSecond() >= request.getSimilarityThreshold())
30+
.sorted(Comparator.comparing(Pair::getSecond))
31+
.limit(request.getTopK())
32+
.map(s -> this.store.get(s.getFirst()))
33+
.toList();
34+
}
35+
};
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
package com.baeldung.springai.advisors;
2+
3+
import org.junit.jupiter.api.BeforeEach;
4+
import org.junit.jupiter.api.Test;
5+
import org.junit.jupiter.api.extension.ExtendWith;
6+
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
7+
import org.springframework.ai.autoconfigure.vectorstore.mongo.MongoDBAtlasVectorStoreAutoConfiguration;
8+
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreAutoConfiguration;
9+
import org.springframework.ai.chat.client.ChatClient;
10+
import org.springframework.ai.chat.client.advisor.*;
11+
import org.springframework.ai.chat.memory.ChatMemory;
12+
import org.springframework.ai.chat.memory.InMemoryChatMemory;
13+
import org.springframework.ai.chat.model.ChatModel;
14+
import org.springframework.ai.document.Document;
15+
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
16+
import org.springframework.ai.vectorstore.VectorStore;
17+
import org.springframework.beans.factory.annotation.Autowired;
18+
import org.springframework.beans.factory.annotation.Qualifier;
19+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
20+
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
21+
import org.springframework.boot.test.context.SpringBootTest;
22+
import org.springframework.test.context.junit.jupiter.SpringExtension;
23+
24+
import java.util.List;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/*
29+
* To set the test environment:
30+
* populate OPENAI_API_KEY env. variable with active Open AI API key
31+
* */
32+
@SpringBootTest(classes = {ChatModel.class, SimpleVectorStoreConfiguration.class})
33+
@EnableAutoConfiguration(exclude = {RedisVectorStoreAutoConfiguration.class,
34+
MistralAiAutoConfiguration.class, MongoDBAtlasVectorStoreAutoConfiguration.class, MongoAutoConfiguration.class})
35+
@ExtendWith(SpringExtension.class)
36+
public class SpringAILiveTest {
37+
38+
@Autowired
39+
@Qualifier("openAiChatModel")
40+
ChatModel chatModel;
41+
@Autowired
42+
VectorStore vectorStore;
43+
ChatClient chatClient;
44+
45+
@BeforeEach
46+
void setup() {
47+
chatClient = ChatClient.builder(chatModel).build();
48+
}
49+
50+
@Test
51+
void givenMessageChatMemoryAdvisor_whenAskingChatToIncrementTheResponseWithNewName_thenNamesFromTheChatHistoryExistInResponse() {
52+
ChatMemory chatMemory = new InMemoryChatMemory();
53+
MessageChatMemoryAdvisor chatMemoryAdvisor = new MessageChatMemoryAdvisor(chatMemory);
54+
55+
String responseContent = chatClient.prompt()
56+
.user("Add this name to a list and return all the values: Bob")
57+
.advisors(chatMemoryAdvisor)
58+
.call()
59+
.content();
60+
61+
assertThat(responseContent)
62+
.contains("Bob");
63+
64+
responseContent = chatClient.prompt()
65+
.user("Add this name to a list and return all the values: John")
66+
.advisors(chatMemoryAdvisor)
67+
.call()
68+
.content();
69+
70+
assertThat(responseContent)
71+
.contains("Bob")
72+
.contains("John");
73+
74+
responseContent = chatClient.prompt()
75+
.user("Add this name to a list and return all the values: Anna")
76+
.advisors(chatMemoryAdvisor)
77+
.call()
78+
.content();
79+
80+
assertThat(responseContent)
81+
.contains("Bob")
82+
.contains("John")
83+
.contains("Anna");
84+
}
85+
86+
@Test
87+
void givenPromptChatMemoryAdvisor_whenAskingChatToIncrementTheResponseWithNewName_thenNamesFromTheChatHistoryExistInResponse() {
88+
ChatMemory chatMemory = new InMemoryChatMemory();
89+
PromptChatMemoryAdvisor chatMemoryAdvisor = new PromptChatMemoryAdvisor(chatMemory);
90+
91+
String responseContent = chatClient.prompt()
92+
.user("Add this name to a list and return all the values: Bob")
93+
.advisors(chatMemoryAdvisor)
94+
.call()
95+
.content();
96+
97+
assertThat(responseContent)
98+
.contains("Bob");
99+
100+
responseContent = chatClient.prompt()
101+
.user("Add this name to a list and return all the values: John")
102+
.advisors(chatMemoryAdvisor)
103+
.call()
104+
.content();
105+
106+
assertThat(responseContent)
107+
.contains("Bob")
108+
.contains("John");
109+
110+
responseContent = chatClient.prompt()
111+
.user("Add this name to a list and return all the values: Anna")
112+
.advisors(chatMemoryAdvisor)
113+
.call()
114+
.content();
115+
116+
assertThat(responseContent)
117+
.contains("Bob")
118+
.contains("John")
119+
.contains("Anna");
120+
}
121+
122+
@Test
123+
void givenVectorStoreChatMemoryAdvisor_whenAskingChatToIncrementTheResponseWithNewName_thenNamesFromTheChatHistoryExistInResponse() {
124+
VectorStoreChatMemoryAdvisor chatMemoryAdvisor = new VectorStoreChatMemoryAdvisor(vectorStore);
125+
126+
String responseContent = chatClient.prompt()
127+
.user("Find cats from our chat history, add Lion there and return a list")
128+
.advisors(chatMemoryAdvisor)
129+
.call()
130+
.content();
131+
132+
assertThat(responseContent)
133+
.contains("Lion");
134+
135+
responseContent = chatClient.prompt()
136+
.user("Find cats from our chat history, add Puma there and return a list")
137+
.advisors(chatMemoryAdvisor)
138+
.call()
139+
.content();
140+
141+
assertThat(responseContent)
142+
.contains("Lion")
143+
.contains("Puma");
144+
145+
responseContent = chatClient.prompt()
146+
.user("Find cats from our chat history, add Leopard there and return a list")
147+
.advisors(chatMemoryAdvisor)
148+
.call()
149+
.content();
150+
151+
assertThat(responseContent)
152+
.contains("Lion")
153+
.contains("Puma")
154+
.contains("Leopard");
155+
}
156+
157+
@Test
158+
void givenQuestionAnswerAdvisor_whenAskingQuestion_thenAnswerShouldBeProvidedBasedOnVectorStoreInformation() {
159+
160+
Document document = new Document("The sky is green");
161+
List<Document> documents = new TokenTextSplitter().apply(List.of(document));
162+
vectorStore.add(documents);
163+
QuestionAnswerAdvisor questionAnswerAdvisor = new QuestionAnswerAdvisor(vectorStore);
164+
165+
String responseContent = chatClient.prompt()
166+
.user("What is the sky color?")
167+
.advisors(questionAnswerAdvisor)
168+
.call()
169+
.content();
170+
171+
assertThat(responseContent)
172+
.containsIgnoringCase("green");
173+
}
174+
175+
@Test
176+
void givenSafeGuardAdvisor_whenSendPromptWithSensitiveWord_thenExpectedMessageShouldBeReturned() {
177+
178+
List<String> forbiddenWords = List.of("Word2");
179+
SafeGuardAdvisor safeGuardAdvisor = new SafeGuardAdvisor(forbiddenWords);
180+
181+
String responseContent = chatClient.prompt()
182+
.user("Please split the 'Word2' into characters")
183+
.advisors(safeGuardAdvisor)
184+
.call()
185+
.content();
186+
187+
assertThat(responseContent)
188+
.contains("I'm unable to respond to that due to sensitive content");
189+
}
190+
191+
@Test
192+
void givenCustomLoggingAdvisor_whenSendPrompt_thenPromptTextAndResponseShouldBeLogged() {
193+
194+
CustomLoggingAdvisor customLoggingAdvisor = new CustomLoggingAdvisor();
195+
196+
String responseContent = chatClient.prompt()
197+
.user("Count from 1 to 10")
198+
.advisors(customLoggingAdvisor)
199+
.call()
200+
.content();
201+
202+
assertThat(responseContent)
203+
.contains("1")
204+
.contains("10");
205+
}
206+
}

spring-ai/src/test/java/com/baeldung/springai/rag/mongodb/RAGMongoDBApplicationManualTest.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package com.baeldung.springai.rag.mongodb;
22

3+
import com.baeldung.springai.rag.mongodb.config.VectorStoreConfig;
34
import org.junit.jupiter.api.Test;
45
import org.junit.jupiter.api.extension.ExtendWith;
56
import org.slf4j.Logger;
67
import org.slf4j.LoggerFactory;
8+
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
9+
import org.springframework.ai.autoconfigure.vectorstore.redis.RedisVectorStoreAutoConfiguration;
710
import org.springframework.beans.factory.annotation.Autowired;
11+
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
812
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
913
import org.springframework.boot.test.context.SpringBootTest;
1014
import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -21,7 +25,8 @@
2125
* */
2226
@AutoConfigureMockMvc
2327
@ExtendWith(SpringExtension.class)
24-
@SpringBootTest
28+
@EnableAutoConfiguration(exclude = {RedisVectorStoreAutoConfiguration.class, MistralAiAutoConfiguration.class})
29+
@SpringBootTest(classes = VectorStoreConfig.class)
2530
class RAGMongoDBApplicationManualTest {
2631
private static Logger logger = LoggerFactory.getLogger(RAGMongoDBApplicationManualTest.class);
2732

0 commit comments

Comments
 (0)