Skip to content

Commit 71fc7e5

Browse files
committed
enhanced-deepseek cot
1 parent af07517 commit 71fc7e5

File tree

8 files changed

+480
-11
lines changed

8 files changed

+480
-11
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-deepseek/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@
7272
<optional>true</optional>
7373
</dependency>
7474

75+
<dependency>
76+
<groupId>org.springframework.ai</groupId>
77+
<artifactId>spring-ai-autoconfigure-model-chat-client</artifactId>
78+
<version>${project.parent.version}</version>
79+
</dependency>
80+
7581
<!-- Test dependencies -->
7682
<dependency>
7783
<groupId>org.springframework.ai</groupId>

models/spring-ai-deepseek/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@
5050
<artifactId>slf4j-api</artifactId>
5151
</dependency>
5252

53+
<dependency>
54+
<groupId>org.springframework.ai</groupId>
55+
<artifactId>spring-ai-client-chat</artifactId>
56+
<version>${project.parent.version}</version>
57+
</dependency>
58+
5359
<!-- test dependencies -->
5460
<dependency>
5561
<groupId>org.springframework.ai</groupId>

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekAssistantMessage.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.List;
2020
import java.util.Map;
2121
import java.util.Objects;
22-
2322
import org.springframework.ai.chat.messages.AssistantMessage;
2423
import org.springframework.ai.content.Media;
2524

@@ -38,6 +37,11 @@ public DeepSeekAssistantMessage(String content, String reasoningContent) {
3837
this.reasoningContent = reasoningContent;
3938
}
4039

40+
public DeepSeekAssistantMessage(String content, String reasoningContent, Map<String, Object> properties) {
41+
super(content, properties);
42+
this.reasoningContent = reasoningContent;
43+
}
44+
4145
public DeepSeekAssistantMessage(String content, Map<String, Object> properties) {
4246
super(content, properties);
4347
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package org.springframework.ai.deepseek;
2+
3+
import java.util.HashMap;
4+
import java.util.Map;
5+
import java.util.concurrent.atomic.AtomicReference;
6+
import java.util.function.Consumer;
7+
import org.springframework.ai.chat.client.ChatClientResponse;
8+
import reactor.core.publisher.Flux;
9+
10+
public class DeepSeekChatClientMessageAggregator {
11+
12+
public Flux<ChatClientResponse> aggregateChatClientResponse(
13+
Flux<ChatClientResponse> chatClientResponses,
14+
Consumer<ChatClientResponse> aggregationHandler) {
15+
16+
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());
17+
18+
return new DeepSeekMessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
19+
context.get().putAll(chatClientResponse.context());
20+
return chatClientResponse.chatResponse();
21+
}), aggregatedChatResponse -> {
22+
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
23+
.chatResponse(aggregatedChatResponse).context(context.get()).build();
24+
aggregationHandler.accept(aggregatedChatClientResponse);
25+
}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse)
26+
.context(context.get()).build());
27+
}
28+
}

models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,14 @@
1616

1717
package org.springframework.ai.deepseek;
1818

19-
import java.util.List;
20-
import java.util.Map;
21-
import java.util.concurrent.ConcurrentHashMap;
22-
2319
import io.micrometer.observation.Observation;
2420
import io.micrometer.observation.ObservationRegistry;
2521
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.concurrent.ConcurrentHashMap;
2625
import org.slf4j.Logger;
2726
import org.slf4j.LoggerFactory;
28-
import reactor.core.publisher.Flux;
29-
import reactor.core.publisher.Mono;
30-
import reactor.core.scheduler.Schedulers;
31-
3227
import org.springframework.ai.chat.messages.AssistantMessage;
3328
import org.springframework.ai.chat.messages.MessageType;
3429
import org.springframework.ai.chat.messages.ToolResponseMessage;
@@ -40,7 +35,6 @@
4035
import org.springframework.ai.chat.model.ChatModel;
4136
import org.springframework.ai.chat.model.ChatResponse;
4237
import org.springframework.ai.chat.model.Generation;
43-
import org.springframework.ai.chat.model.MessageAggregator;
4438
import org.springframework.ai.chat.model.StreamingChatModel;
4539
import org.springframework.ai.chat.observation.ChatModelObservationContext;
4640
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
@@ -69,6 +63,9 @@
6963
import org.springframework.retry.support.RetryTemplate;
7064
import org.springframework.util.Assert;
7165
import org.springframework.util.CollectionUtils;
66+
import reactor.core.publisher.Flux;
67+
import reactor.core.publisher.Mono;
68+
import reactor.core.scheduler.Schedulers;
7269

7370
/**
7471
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal DeepSeek}
@@ -312,7 +309,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
312309
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
313310
// @formatter:on
314311

315-
return new MessageAggregator().aggregate(flux, observationContext::setResponse);
312+
return new DeepSeekMessageAggregator().aggregate(flux, observationContext::setResponse);
316313

317314
});
318315
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package org.springframework.ai.deepseek;
2+
3+
import java.util.HashMap;
4+
import java.util.List;
5+
import java.util.Map;
6+
import java.util.concurrent.atomic.AtomicReference;
7+
import java.util.function.Consumer;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
11+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
12+
import org.springframework.ai.chat.metadata.EmptyRateLimit;
13+
import org.springframework.ai.chat.metadata.PromptMetadata;
14+
import org.springframework.ai.chat.metadata.RateLimit;
15+
import org.springframework.ai.chat.metadata.Usage;
16+
import org.springframework.ai.chat.model.ChatResponse;
17+
import org.springframework.ai.chat.model.Generation;
18+
import org.springframework.ai.chat.model.MessageAggregator;
19+
import org.springframework.util.StringUtils;
20+
import reactor.core.publisher.Flux;
21+
22+
/**
23+
* deepseek消息聚合器
24+
* lucas
25+
*/
26+
public class DeepSeekMessageAggregator extends MessageAggregator {
27+
private static final Logger logger = LoggerFactory.getLogger(DeepSeekMessageAggregator.class);
28+
@Override
29+
public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
30+
Consumer<ChatResponse> onAggregationComplete) {
31+
32+
// Assistant Message
33+
AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(
34+
new StringBuilder());
35+
// Reasoning Message
36+
AtomicReference<StringBuilder> reasoningContentRef = new AtomicReference<>(
37+
new StringBuilder());
38+
AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();
39+
40+
// ChatGeneration Metadata
41+
AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
42+
ChatGenerationMetadata.NULL);
43+
44+
// Usage
45+
AtomicReference<Integer> metadataUsagePromptTokensRef = new AtomicReference<Integer>(0);
46+
AtomicReference<Integer> metadataUsageGenerationTokensRef = new AtomicReference<Integer>(0);
47+
AtomicReference<Integer> metadataUsageTotalTokensRef = new AtomicReference<Integer>(0);
48+
49+
AtomicReference<PromptMetadata> metadataPromptMetadataRef = new AtomicReference<>(
50+
PromptMetadata.empty());
51+
AtomicReference<RateLimit> metadataRateLimitRef = new AtomicReference<>(new EmptyRateLimit());
52+
53+
AtomicReference<String> metadataIdRef = new AtomicReference<>("");
54+
AtomicReference<String> metadataModelRef = new AtomicReference<>("");
55+
56+
return fluxChatResponse.doOnSubscribe(subscription -> {
57+
messageTextContentRef.set(new StringBuilder());
58+
reasoningContentRef.set(new StringBuilder());
59+
messageMetadataMapRef.set(new HashMap<>());
60+
metadataIdRef.set("");
61+
metadataModelRef.set("");
62+
metadataUsagePromptTokensRef.set(0);
63+
metadataUsageGenerationTokensRef.set(0);
64+
metadataUsageTotalTokensRef.set(0);
65+
metadataPromptMetadataRef.set(PromptMetadata.empty());
66+
metadataRateLimitRef.set(new EmptyRateLimit());
67+
68+
}).doOnNext(chatResponse -> {
69+
70+
if (chatResponse.getResult() != null) {
71+
if (chatResponse.getResult().getMetadata() != null
72+
&& chatResponse.getResult().getMetadata() != ChatGenerationMetadata.NULL) {
73+
generationMetadataRef.set(chatResponse.getResult().getMetadata());
74+
}
75+
if (chatResponse.getResult().getOutput().getText() != null) {
76+
messageTextContentRef.get().append(chatResponse.getResult().getOutput().getText());
77+
}
78+
if (chatResponse.getResult()
79+
.getOutput() instanceof DeepSeekAssistantMessage deepSeekAssistantMessage) {
80+
reasoningContentRef.get().append(deepSeekAssistantMessage.getReasoningContent());
81+
}
82+
messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
83+
}
84+
if (chatResponse.getMetadata() != null) {
85+
if (chatResponse.getMetadata().getUsage() != null) {
86+
Usage usage = chatResponse.getMetadata().getUsage();
87+
metadataUsagePromptTokensRef.set(
88+
usage.getPromptTokens() > 0 ? usage.getPromptTokens()
89+
: metadataUsagePromptTokensRef.get());
90+
metadataUsageGenerationTokensRef.set(
91+
usage.getCompletionTokens() > 0 ? usage.getCompletionTokens()
92+
: metadataUsageGenerationTokensRef.get());
93+
metadataUsageTotalTokensRef
94+
.set(usage.getTotalTokens() > 0 ? usage.getTotalTokens()
95+
: metadataUsageTotalTokensRef.get());
96+
}
97+
if (chatResponse.getMetadata().getPromptMetadata() != null
98+
&& chatResponse.getMetadata().getPromptMetadata().iterator().hasNext()) {
99+
metadataPromptMetadataRef.set(chatResponse.getMetadata().getPromptMetadata());
100+
}
101+
if (chatResponse.getMetadata().getRateLimit() != null
102+
&& !(metadataRateLimitRef.get() instanceof EmptyRateLimit)) {
103+
metadataRateLimitRef.set(chatResponse.getMetadata().getRateLimit());
104+
}
105+
if (StringUtils.hasText(chatResponse.getMetadata().getId())) {
106+
metadataIdRef.set(chatResponse.getMetadata().getId());
107+
}
108+
if (StringUtils.hasText(chatResponse.getMetadata().getModel())) {
109+
metadataModelRef.set(chatResponse.getMetadata().getModel());
110+
}
111+
}
112+
}).doOnComplete(() -> {
113+
114+
var usage = new DefaultUsage(metadataUsagePromptTokensRef.get(),
115+
metadataUsageGenerationTokensRef.get(),
116+
metadataUsageTotalTokensRef.get());
117+
118+
var chatResponseMetadata = ChatResponseMetadata.builder()
119+
.id(metadataIdRef.get())
120+
.model(metadataModelRef.get())
121+
.rateLimit(metadataRateLimitRef.get())
122+
.usage(usage)
123+
.promptMetadata(metadataPromptMetadataRef.get())
124+
.build();
125+
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
126+
new DeepSeekAssistantMessage(messageTextContentRef.get().toString(),
127+
reasoningContentRef.get().toString(), messageMetadataMapRef.get()),
128+
generationMetadataRef.get())), chatResponseMetadata));
129+
130+
messageTextContentRef.set(new StringBuilder());
131+
reasoningContentRef.set(new StringBuilder());
132+
messageMetadataMapRef.set(new HashMap<>());
133+
metadataIdRef.set("");
134+
metadataModelRef.set("");
135+
metadataUsagePromptTokensRef.set(0);
136+
metadataUsageGenerationTokensRef.set(0);
137+
metadataUsageTotalTokensRef.set(0);
138+
metadataPromptMetadataRef.set(PromptMetadata.empty());
139+
metadataRateLimitRef.set(new EmptyRateLimit());
140+
141+
}).doOnError(e -> logger.error("Aggregation Error", e));
142+
}
143+
}

0 commit comments

Comments
 (0)