4 changes: 4 additions & 0 deletions .gitignore
@@ -51,3 +51,7 @@ qodana.yaml
__pycache__/
*.pyc
tmp


plans

@@ -42,7 +42,9 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Source;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.Type;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.AnthropicCacheStrategy;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
@@ -460,6 +462,12 @@ Prompt buildRequestPrompt(Prompt prompt) {
this.defaultOptions.getToolCallbacks()));
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
this.defaultOptions.getToolContext()));

// Merge cache strategy and TTL (also @JsonIgnore fields)
requestOptions.setCacheStrategy(runtimeOptions.getCacheStrategy() != null
? runtimeOptions.getCacheStrategy() : this.defaultOptions.getCacheStrategy());
requestOptions.setCacheTtl(runtimeOptions.getCacheTtl() != null ? runtimeOptions.getCacheTtl()
: this.defaultOptions.getCacheTtl());
}
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
@@ -483,69 +491,75 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Type contentBlockType = getContentBlockTypeByMedia(media);
var source = getSourceByMedia(media);
return new ContentBlock(contentBlockType, source);
}).toList();
contents.addAll(mediaContent);
}
}
return new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name()));
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(new ContentBlock(message.getText()));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
}
}
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
.stream()
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
toolResponse.responseData()))
.toList();
return new AnthropicMessage(toolResponses, Role.USER);
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
})
.toList();
// Get caching strategy and options from the request
logger.info("DEBUGINFO: prompt.getOptions() type: {}, value: {}",
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null", prompt.getOptions());

String systemPrompt = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(m -> m.getText())
.collect(Collectors.joining(System.lineSeparator()));
AnthropicChatOptions requestOptions = null;
if (prompt.getOptions() instanceof AnthropicChatOptions) {
requestOptions = (AnthropicChatOptions) prompt.getOptions();
logger.info("DEBUGINFO: Found AnthropicChatOptions - cacheStrategy: {}, cacheTtl: {}",
requestOptions.getCacheStrategy(), requestOptions.getCacheTtl());
}
else {
logger.info("DEBUGINFO: Options is NOT AnthropicChatOptions, it's: {}",
prompt.getOptions() != null ? prompt.getOptions().getClass().getName() : "null");
}

AnthropicCacheStrategy strategy = requestOptions != null ? requestOptions.getCacheStrategy()
: AnthropicCacheStrategy.NONE;
String cacheTtl = requestOptions != null ? requestOptions.getCacheTtl() : "5m";

logger.info("Cache strategy: {}, TTL: {}", strategy, cacheTtl);

// Track how many breakpoints we've used (max 4)
CacheBreakpointTracker breakpointsUsed = new CacheBreakpointTracker();
ChatCompletionRequest.CacheControl cacheControl = null;

if (strategy != AnthropicCacheStrategy.NONE) {
// Create cache control with TTL if specified, otherwise use default 5m
if (cacheTtl != null && !cacheTtl.equals("5m")) {
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral", cacheTtl);
logger.info("Created cache control with TTL: type={}, ttl={}", "ephemeral", cacheTtl);
[Inline review comment from the PR author on the logging above: these need to be changed to 'debug'.]
}
else {
cacheControl = new ChatCompletionRequest.CacheControl("ephemeral");
logger.info("Created cache control with default TTL: type={}, ttl={}", "ephemeral", "5m");
}
}

// Build messages WITHOUT blanket cache control - strategic placement only
List<AnthropicMessage> userMessages = buildMessages(prompt, strategy, cacheControl, breakpointsUsed);

// Process system - as array if caching, string otherwise
Object systemContent = buildSystemContent(prompt, strategy, cacheControl, breakpointsUsed);

// Build base request
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);
systemContent, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);

AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions();
request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);

// Add the tool definitions to the request's tools parameter.
// Add the tool definitions with potential caching
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
if (!CollectionUtils.isEmpty(toolDefinitions)) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
List<AnthropicApi.Tool> tools = getFunctionTools(toolDefinitions);

// Apply caching to tools if strategy includes them
if ((strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()) {
tools = addCacheToLastTool(tools, cacheControl, breakpointsUsed);
}

request = ChatCompletionRequest.from(request).tools(tools).build();
}

// Add beta header for 1-hour TTL if needed
if ("1h".equals(cacheTtl) && requestOptions != null) {
Map<String, String> headers = new HashMap<>(requestOptions.getHttpHeaders());
headers.put("anthropic-beta", AnthropicApi.BETA_EXTENDED_CACHE_TTL);
requestOptions.setHttpHeaders(headers);
}

return request;
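
For orientation, a minimal caller-side sketch of how the new cache options reach createRequest. The setters, enum value, TTL string, and beta-header behavior match the diff above; the chatModel instance, system text, and the surrounding method are assumptions for illustration only:

// Illustrative usage, assuming standard Spring AI imports (Prompt, SystemMessage, UserMessage, ChatResponse).
ChatResponse callWithPromptCaching(AnthropicChatModel chatModel, String systemInstructions) {
	AnthropicChatOptions options = AnthropicChatOptions.builder().build();
	options.setCacheStrategy(AnthropicCacheStrategy.SYSTEM_AND_TOOLS); // cache system prompt and tool definitions
	options.setCacheTtl("1h"); // exercises the anthropic-beta extended-TTL header added in createRequest

	Prompt prompt = new Prompt(
			List.of(new SystemMessage(systemInstructions), new UserMessage("What changed since yesterday?")),
			options);

	// createRequest(prompt, false) emits cache_control on the system block and on the last tool definition.
	return chatModel.call(prompt);
}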
Expand All @@ -561,6 +575,154 @@ private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> toolDefini
}).toList();
}

/**
* Build messages strategically, applying cache control only where specified by the
* strategy.
*/
private List<AnthropicMessage> buildMessages(Prompt prompt, AnthropicCacheStrategy strategy,
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {

List<Message> allMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.toList();

// Find the last user message (current question) for CONVERSATION_HISTORY strategy
int lastUserIndex = -1;
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) {
for (int i = allMessages.size() - 1; i >= 0; i--) {
if (allMessages.get(i).getMessageType() == MessageType.USER) {
lastUserIndex = i;
break;
}
}
}

List<AnthropicMessage> result = new ArrayList<>();
for (int i = 0; i < allMessages.size(); i++) {
Message message = allMessages.get(i);
boolean shouldApplyCache = false;

// Apply cache to history tail (message before current question) for
// CONVERSATION_HISTORY
if (strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY && breakpointsUsed.canUse()) {
if (lastUserIndex > 0) {
// Cache the message immediately before the last user message
// (multi-turn conversation)
shouldApplyCache = (i == lastUserIndex - 1);
}
if (shouldApplyCache) {
breakpointsUsed.use();
}
}

if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>();

// Apply cache control strategically, not to all user messages
if (shouldApplyCache && cacheControl != null) {
contents.add(new ContentBlock(message.getText(), cacheControl));
}
else {
contents.add(new ContentBlock(message.getText()));
}

if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Type contentBlockType = getContentBlockTypeByMedia(media);
var source = getSourceByMedia(media);
return new ContentBlock(contentBlockType, source);
}).toList();
contents.addAll(mediaContent);
}
}
result.add(new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name())));
}
else if (message.getMessageType() == MessageType.ASSISTANT) {
AssistantMessage assistantMessage = (AssistantMessage) message;
List<ContentBlock> contentBlocks = new ArrayList<>();
if (StringUtils.hasText(message.getText())) {
contentBlocks.add(new ContentBlock(message.getText()));
}
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
contentBlocks.add(new ContentBlock(Type.TOOL_USE, toolCall.id(), toolCall.name(),
ModelOptionsUtils.jsonToMap(toolCall.arguments())));
}
}
result.add(new AnthropicMessage(contentBlocks, Role.ASSISTANT));
}
else if (message.getMessageType() == MessageType.TOOL) {
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
.stream()
.map(toolResponse -> new ContentBlock(Type.TOOL_RESULT, toolResponse.id(),
toolResponse.responseData()))
.toList();
result.add(new AnthropicMessage(toolResponses, Role.USER));
}
else {
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
}
}
return result;
}
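
A small illustrative sketch (message types are standard Spring AI classes; the conversation is hypothetical and not part of this change) of how the CONVERSATION_HISTORY placement above resolves for a three-turn prompt:

// Hypothetical history passed to buildMessages(...) after SYSTEM messages are filtered out:
List<Message> turns = List.of(
		new UserMessage("Here is the full policy document: ..."),   // i == 0
		new AssistantMessage("Condensed summary of the policy."),   // i == 1 (lastUserIndex - 1)
		new UserMessage("Which section covers refunds?"));          // i == 2 (lastUserIndex, current question)

// The breakpoint is reserved at i == 1, the turn just before the current question. As written,
// the cache_control block is only attached inside the USER branch, so an assistant turn at that
// position consumes a breakpoint without carrying the marker; the system and tool breakpoints
// are applied separately in buildSystemContent(...) and addCacheToLastTool(...).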

/**
* Build system content - as array if caching, string otherwise.
*/
private Object buildSystemContent(Prompt prompt, AnthropicCacheStrategy strategy,
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {

String systemText = prompt.getInstructions()
.stream()
.filter(m -> m.getMessageType() == MessageType.SYSTEM)
.map(Message::getText)
.collect(Collectors.joining(System.lineSeparator()));

if (!StringUtils.hasText(systemText)) {
return null;
}

// Use array format when caching system
if ((strategy == AnthropicCacheStrategy.SYSTEM_ONLY || strategy == AnthropicCacheStrategy.SYSTEM_AND_TOOLS
|| strategy == AnthropicCacheStrategy.CONVERSATION_HISTORY) && breakpointsUsed.canUse()
&& cacheControl != null) {

logger.info("Applying cache control to system message - strategy: {}, cacheControl: {}", strategy,
cacheControl);
List<ContentBlock> systemBlocks = List.of(new ContentBlock(systemText, cacheControl));
breakpointsUsed.use();
return systemBlocks;
}

// Use string format when not caching (backward compatible)
return systemText;
}

/**
* Add cache control to the last tool for deterministic caching.
*/
private List<AnthropicApi.Tool> addCacheToLastTool(List<AnthropicApi.Tool> tools,
ChatCompletionRequest.CacheControl cacheControl, CacheBreakpointTracker breakpointsUsed) {

if (tools == null || tools.isEmpty() || !breakpointsUsed.canUse() || cacheControl == null) {
return tools;
}

List<AnthropicApi.Tool> modifiedTools = new ArrayList<>();
for (int i = 0; i < tools.size(); i++) {
AnthropicApi.Tool tool = tools.get(i);
if (i == tools.size() - 1) {
// Add cache control to last tool
tool = new AnthropicApi.Tool(tool.name(), tool.description(), tool.inputSchema(), cacheControl);
breakpointsUsed.use();
}
modifiedTools.add(tool);
}
return modifiedTools;
}

@Override
public ChatOptions getDefaultOptions() {
return AnthropicChatOptions.fromOptions(this.defaultOptions);
@@ -642,4 +804,36 @@ public AnthropicChatModel build() {

}

/**
* Tracks the cache breakpoints used in a single request (Anthropic allows at most 4). A new
* instance is created for each request so breakpoint counts are never shared across calls.
*/
private class CacheBreakpointTracker {

private int count = 0;

private boolean hasWarned = false;

public boolean canUse() {
return this.count < 4;
}

public void use() {
if (this.count < 4) {
this.count++;
}
else if (!this.hasWarned) {
logger.warn(
"Anthropic cache breakpoint limit (4) reached. Additional cache_control directives will be ignored. "
+ "Consider using fewer cache strategies or simpler content structure.");
this.hasWarned = true;
}
}

public int getCount() {
return this.count;
}

}

}