From a2cad3e4eaaf0208fb8385342d3e7ea1335af6e7 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Wed, 9 Jul 2025 10:59:58 +0800 Subject: [PATCH 1/2] refactor: GH-3620 Add `BedrockChatOptions` to Bedrock Signed-off-by: Sun Yuhan --- .../BedrockConverseProxyChatProperties.java | 12 +- .../tool/FunctionCallWithFunctionBeanIT.java | 8 +- .../FunctionCallWithPromptFunctionIT.java | 6 +- .../bedrock/converse/BedrockChatOptions.java | 346 ++++++++++++++++++ .../converse/BedrockProxyChatModel.java | 37 +- .../converse/BedrockChatOptionsTests.java | 109 ++++++ .../converse/BedrockConverseChatClientIT.java | 7 +- .../BedrockConverseTestConfiguration.java | 3 +- .../BedrockConverseUsageAggregationTests.java | 5 +- .../converse/BedrockProxyChatModelIT.java | 15 +- .../BedrockProxyChatModelObservationIT.java | 9 +- .../client/BedrockNovaChatClientIT.java | 3 +- 12 files changed, 506 insertions(+), 54 deletions(-) create mode 100644 models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java create mode 100644 models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatProperties.java b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatProperties.java index 370e826bd17..a7e430a1588 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatProperties.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/converse/autoconfigure/BedrockConverseProxyChatProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package org.springframework.ai.model.bedrock.converse.autoconfigure; -import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.bedrock.converse.BedrockChatOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; import org.springframework.util.Assert; @@ -33,14 +33,14 @@ public class BedrockConverseProxyChatProperties { public static final String CONFIG_PREFIX = "spring.ai.bedrock.converse.chat"; @NestedConfigurationProperty - private ToolCallingChatOptions options = ToolCallingChatOptions.builder().temperature(0.7).maxTokens(300).build(); + private BedrockChatOptions options = BedrockChatOptions.builder().temperature(0.7).maxTokens(300).build(); - public ToolCallingChatOptions getOptions() { + public BedrockChatOptions getOptions() { return this.options; } - public void setOptions(ToolCallingChatOptions options) { - Assert.notNull(options, "ToolCallingChatOptions must not be null"); + public void setOptions(BedrockChatOptions options) { + Assert.notNull(options, "BedrockChatOptions must not be null"); this.options = options; } diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java index e78cc1e30f2..223fdab4eed 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithFunctionBeanIT.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.bedrock.converse.BedrockChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; @@ -32,7 +33,6 @@ import org.springframework.ai.model.bedrock.autoconfigure.BedrockTestUtils; import org.springframework.ai.model.bedrock.autoconfigure.RequiresAwsCredentials; import org.springframework.ai.model.bedrock.converse.autoconfigure.BedrockConverseProxyChatAutoConfiguration; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; @@ -64,14 +64,14 @@ void functionCallTest() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); ChatResponse response = chatModel.call(new Prompt(List.of(userMessage), - ToolCallingChatOptions.builder().toolNames("weatherFunction").build())); + BedrockChatOptions.builder().toolNames("weatherFunction").build())); logger.info("Response: {}", response); assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15"); response = chatModel.call(new Prompt(List.of(userMessage), - ToolCallingChatOptions.builder().toolNames("weatherFunction3").build())); + BedrockChatOptions.builder().toolNames("weatherFunction3").build())); logger.info("Response: {}", response); @@ -93,7 +93,7 @@ void functionStreamTest() { "What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius."); Flux responses = chatModel.stream(new Prompt(List.of(userMessage), - ToolCallingChatOptions.builder().toolNames("weatherFunction").build())); + BedrockChatOptions.builder().toolNames("weatherFunction").build())); String content = responses.collectList() .block() diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java index ecc6033f6d7..4974513311f 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/test/java/org/springframework/ai/model/bedrock/converse/autoconfigure/tool/FunctionCallWithPromptFunctionIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.bedrock.converse.BedrockChatOptions; import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; @@ -29,7 +30,6 @@ import org.springframework.ai.model.bedrock.autoconfigure.BedrockTestUtils; import org.springframework.ai.model.bedrock.autoconfigure.RequiresAwsCredentials; import org.springframework.ai.model.bedrock.converse.autoconfigure.BedrockConverseProxyChatAutoConfiguration; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -56,7 +56,7 @@ void functionCallTest() { UserMessage userMessage = new UserMessage( "What's the weather like in San Francisco, in Paris and in Tokyo? Return the temperature in Celsius."); - var promptOptions = ToolCallingChatOptions.builder() + var promptOptions = BedrockChatOptions.builder() .toolCallbacks( List.of(FunctionToolCallback.builder("CurrentWeatherService", new MockWeatherService()) .description("Get the weather in location. Return temperature in 36°F or 36°C format.") diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java new file mode 100644 index 00000000000..d6fe4bde1fd --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java @@ -0,0 +1,346 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * The options to be used when sending a chat request to the Bedrock API. + * + * @author Sun Yuhan + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public class BedrockChatOptions implements ToolCallingChatOptions { + + @JsonProperty("model") + private String model; + + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + + @JsonProperty("max_tokens") + private Integer maxTokens; + + @JsonProperty("presence_penalty") + private Double presencePenalty; + + @JsonProperty("stop_sequences") + private List stopSequences; + + @JsonProperty("temperature") + private Double temperature; + + @JsonProperty("top_k") + private Integer topK; + + @JsonProperty("top_p") + private Double topP; + + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + @JsonIgnore + private Set toolNames = new HashSet<>(); + + @JsonIgnore + private Map toolContext = new HashMap<>(); + + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + public static Builder builder() { + return new Builder(); + } + + public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) { + fromOptions.getToolNames(); + return builder().model(fromOptions.getModel()) + .frequencyPenalty(fromOptions.getFrequencyPenalty()) + .maxTokens(fromOptions.getMaxTokens()) + .presencePenalty(fromOptions.getPresencePenalty()) + .stopSequences( + fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null) + .temperature(fromOptions.getTemperature()) + .topK(fromOptions.getTopK()) + .topP(fromOptions.getTopP()) + .toolCallbacks(new ArrayList<>(fromOptions.getToolCallbacks())) + .toolNames(new HashSet<>(fromOptions.getToolNames())) + .toolContext(new HashMap<>(fromOptions.getToolContext())) + .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .build(); + } + + @Override + public String getModel() { + return this.model; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + public void setFrequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + public void setMaxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + public void setPresencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + @Override + public List getStopSequences() { + return this.stopSequences; + } + + public void setStopSequences(List stopSequences) { + this.stopSequences = stopSequences; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + public void setTemperature(Double temperature) { + this.temperature = temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + public void setTopK(Integer topK) { + this.topK = topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + + public void setTopP(Double topP) { + this.topP = topP; + } + + @Override + @JsonIgnore + public List getToolCallbacks() { + return this.toolCallbacks; + } + + @Override + @JsonIgnore + public void setToolCallbacks(List toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements"); + this.toolCallbacks = toolCallbacks; + } + + @Override + @JsonIgnore + public Set getToolNames() { + return Set.copyOf(this.toolNames); + } + + @Override + @JsonIgnore + public void setToolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + Assert.noNullElements(toolNames, "toolNames cannot contain null elements"); + toolNames.forEach(toolName -> Assert.hasText(toolName, "toolNames cannot contain empty elements")); + this.toolNames = toolNames; + } + + @Override + @JsonIgnore + public Map getToolContext() { + return this.toolContext; + } + + @Override + @JsonIgnore + public void setToolContext(Map toolContext) { + this.toolContext = toolContext; + } + + @Override + @Nullable + public Boolean getInternalToolExecutionEnabled() { + return this.internalToolExecutionEnabled; + } + + @Override + @JsonIgnore + public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + } + + @Override + @SuppressWarnings("unchecked") + public BedrockChatOptions copy() { + return fromOptions(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof BedrockChatOptions that)) { + return false; + } + return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP) && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, this.stopSequences, + this.temperature, this.topK, this.topP, this.toolCallbacks, this.toolNames, this.toolContext, + this.internalToolExecutionEnabled); + } + + public static class Builder { + + private final BedrockChatOptions options = new BedrockChatOptions(); + + public Builder model(String model) { + this.options.model = model; + return this; + } + + public Builder frequencyPenalty(Double frequencyPenalty) { + this.options.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.options.maxTokens = maxTokens; + return this; + } + + public Builder presencePenalty(Double presencePenalty) { + this.options.presencePenalty = presencePenalty; + return this; + } + + public Builder stopSequences(List stopSequences) { + this.options.stopSequences = stopSequences; + return this; + } + + public Builder temperature(Double temperature) { + this.options.temperature = temperature; + return this; + } + + public Builder topK(Integer topK) { + this.options.topK = topK; + return this; + } + + public Builder topP(Double topP) { + this.options.topP = topP; + return this; + } + + public Builder toolCallbacks(List toolCallbacks) { + this.options.setToolCallbacks(toolCallbacks); + return this; + } + + public Builder toolCallbacks(ToolCallback... toolCallbacks) { + Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); + this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks)); + return this; + } + + public Builder toolNames(Set toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.setToolNames(toolNames); + return this; + } + + public Builder toolNames(String... toolNames) { + Assert.notNull(toolNames, "toolNames cannot be null"); + this.options.toolNames.addAll(Set.of(toolNames)); + return this; + } + + public Builder toolContext(Map toolContext) { + if (this.options.toolContext == null) { + this.options.toolContext = toolContext; + } + else { + this.options.toolContext.putAll(toolContext); + } + return this; + } + + public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { + this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + return this; + } + + public BedrockChatOptions build() { + return this.options; + } + + } + +} diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 484e979385e..071e77a78cb 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -33,6 +33,11 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; +import org.springframework.ai.model.tool.ToolExecutionResult; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; @@ -96,11 +101,6 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.content.Media; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolCallingChatOptions; -import org.springframework.ai.model.tool.ToolCallingManager; -import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate; -import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.tool.definition.ToolDefinition; @@ -134,6 +134,7 @@ * @author Alexandros Pappas * @author Jihoon Kim * @author Soby Chacko + * @author Sun Yuhan * @since 1.0.0 */ public class BedrockProxyChatModel implements ChatModel { @@ -148,7 +149,7 @@ public class BedrockProxyChatModel implements ChatModel { private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; - private ToolCallingChatOptions defaultOptions; + private final BedrockChatOptions defaultOptions; /** * Observation registry used for instrumentation. @@ -169,14 +170,14 @@ public class BedrockProxyChatModel implements ChatModel { private ChatModelObservationConvention observationConvention; public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockChatOptions defaultOptions, ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) { this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry, toolCallingManager, new DefaultToolExecutionEligibilityPredicate()); } public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, - BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, ToolCallingChatOptions defaultOptions, + BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockChatOptions defaultOptions, ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) { @@ -193,8 +194,8 @@ public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient, this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate; } - private static ToolCallingChatOptions from(ChatOptions options) { - return ToolCallingChatOptions.builder() + private static BedrockChatOptions from(ChatOptions options) { + return BedrockChatOptions.builder() .model(options.getModel()) .maxTokens(options.getMaxTokens()) .stopSequences(options.getStopSequences()) @@ -267,10 +268,10 @@ public ChatOptions getDefaultOptions() { } Prompt buildRequestPrompt(Prompt prompt) { - ToolCallingChatOptions runtimeOptions = null; + BedrockChatOptions runtimeOptions = null; if (prompt.getOptions() != null) { - if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { - runtimeOptions = toolCallingChatOptions.copy(); + if (prompt.getOptions() instanceof BedrockChatOptions bedrockChatOptions) { + runtimeOptions = bedrockChatOptions.copy(); } else { runtimeOptions = from(prompt.getOptions()); @@ -278,7 +279,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } // Merge runtime options with the default options - ToolCallingChatOptions updatedRuntimeOptions = null; + BedrockChatOptions updatedRuntimeOptions = null; if (runtimeOptions == null) { updatedRuntimeOptions = this.defaultOptions.copy(); } @@ -292,7 +293,7 @@ Prompt buildRequestPrompt(Prompt prompt) { if (runtimeOptions.getTopK() != null) { logger.warn("The topK option is not supported by BedrockProxyChatModel. Ignoring."); } - updatedRuntimeOptions = ToolCallingChatOptions.builder() + updatedRuntimeOptions = BedrockChatOptions.builder() .model(runtimeOptions.getModel() != null ? runtimeOptions.getModel() : this.defaultOptions.getModel()) .maxTokens(runtimeOptions.getMaxTokens() != null ? runtimeOptions.getMaxTokens() : this.defaultOptions.getMaxTokens()) @@ -388,7 +389,7 @@ else if (message.getMessageType() == MessageType.TOOL) { .map(sysMessage -> SystemContentBlock.builder().text(sysMessage.getText()).build()) .toList(); - ToolCallingChatOptions updatedRuntimeOptions = prompt.getOptions().copy(); + BedrockChatOptions updatedRuntimeOptions = prompt.getOptions().copy(); ToolConfiguration toolConfiguration = null; @@ -787,7 +788,7 @@ public static final class Builder { private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate(); - private ToolCallingChatOptions defaultOptions = ToolCallingChatOptions.builder().build(); + private BedrockChatOptions defaultOptions = BedrockChatOptions.builder().build(); private ObservationRegistry observationRegistry = ObservationRegistry.NOOP; @@ -835,7 +836,7 @@ public Builder timeout(Duration timeout) { return this; } - public Builder defaultOptions(ToolCallingChatOptions defaultOptions) { + public Builder defaultOptions(BedrockChatOptions defaultOptions) { Assert.notNull(defaultOptions, "'defaultOptions' must not be null."); this.defaultOptions = defaultOptions; return this; diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java new file mode 100644 index 00000000000..aed48c1a3b5 --- /dev/null +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.bedrock.converse; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link BedrockChatOptions}. + * + * @author Sun Yuhan + */ +class BedrockChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + BedrockChatOptions options = BedrockChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.0) + .maxTokens(100) + .presencePenalty(0.0) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature", + "topP", "topK") + .containsExactly("test-model", 0.0, 100, 0.0, List.of("stop1", "stop2"), 0.7, 0.8, 50); + } + + @Test + void testCopy() { + BedrockChatOptions original = BedrockChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.0) + .maxTokens(100) + .presencePenalty(0.0) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .toolContext(Map.of("key1", "value1")) + .build(); + + BedrockChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + BedrockChatOptions options = new BedrockChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.0); + options.setMaxTokens(100); + options.setPresencePenalty(0.0); + options.setTemperature(0.7); + options.setTopK(50); + options.setTopP(0.8); + options.setStopSequences(List.of("stop1", "stop2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.0); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getPresencePenalty()).isEqualTo(0.0); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + } + + @Test + void testDefaultValues() { + BedrockChatOptions options = new BedrockChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getStopSequences()).isNull(); + } + +} diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index f6c7733a431..fa5a292a535 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -36,7 +36,6 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.test.CurlyBracketEscaper; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; @@ -367,7 +366,7 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() - .options(ToolCallingChatOptions.builder().model(modelName).build()) + .options(BedrockChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?") .media(MimeTypeUtils.IMAGE_PNG, new ClassPathResource("/test.png"))) .call() @@ -388,7 +387,7 @@ void multiModalityImageUrl2(String modelName) throws IOException { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to - .options(ToolCallingChatOptions.builder().model(modelName).build()) + .options(BedrockChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() .content(); @@ -408,7 +407,7 @@ void multiModalityImageUrl(String modelName) throws IOException { // @formatter:off String response = ChatClient.create(this.chatModel).prompt() // TODO consider adding model(...) method to ChatClient as a shortcut to - .options(ToolCallingChatOptions.builder().model(modelName).build()) + .options(BedrockChatOptions.builder().model(modelName).build()) .user(u -> u.text("Explain what do you see on this picture?").media(MimeTypeUtils.IMAGE_PNG, url)) .call() .content(); diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java index a42a4beecba..05992349a01 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseTestConfiguration.java @@ -21,7 +21,6 @@ import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -42,7 +41,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { .region(Region.US_EAST_1) // .region(Region.US_EAST_1) .timeout(Duration.ofSeconds(120)) - .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) + .defaultOptions(BedrockChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java index 09d7c83d675..9cc13f17572 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseUsageAggregationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,7 +36,6 @@ import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.function.FunctionToolCallback; @@ -143,7 +142,7 @@ public void callWithToolUse() { .build(); var result = this.chatModel.call(new Prompt("What is the weather in Paris?", - ToolCallingChatOptions.builder().toolCallbacks(toolCallback).build())); + BedrockChatOptions.builder().toolCallbacks(toolCallback).build())); assertThat(result).isNotNull(); assertThat(result.getResult().getOutput().getText()) diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java index 65d85c1a2a4..bd3e07c1d77 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,6 @@ import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -90,7 +89,7 @@ void roleTest(String modelName) { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), - ToolCallingChatOptions.builder().model(modelName).build()); + BedrockChatOptions.builder().model(modelName).build()); ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResults()).hasSize(1); assertThat(response.getMetadata().getUsage().getCompletionTokens()).isGreaterThan(0); @@ -126,7 +125,7 @@ void testMessageHistory() { @Test void streamingWithTokenUsage() { - var promptOptions = ToolCallingChatOptions.builder().temperature(0.0).build(); + var promptOptions = BedrockChatOptions.builder().temperature(0.0).build(); var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions); var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage(); @@ -265,7 +264,7 @@ void functionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = ToolCallingChatOptions.builder() + var promptOptions = BedrockChatOptions.builder() .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description("Get the weather in location. Return in 36°C format") .inputType(MockWeatherService.Request.class) @@ -290,7 +289,7 @@ void streamFunctionCallTest() { List messages = new ArrayList<>(List.of(userMessage)); - var promptOptions = ToolCallingChatOptions.builder() + var promptOptions = BedrockChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") .toolCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) .description( @@ -317,7 +316,7 @@ void validateCallResponseMetadata() { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(ToolCallingChatOptions.builder().model(model).build()) + .options(BedrockChatOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .call() .chatResponse(); @@ -332,7 +331,7 @@ void validateStreamCallResponseMetadata() { String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() - .options(ToolCallingChatOptions.builder().model(model).build()) + .options(BedrockChatOptions.builder().model(model).build()) .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") .stream() .chatResponse() diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java index fb1b9c3077e..1824a1b84d2 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModelObservationIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,6 @@ import org.springframework.ai.chat.observation.ChatModelObservationDocumentation.LowCardinalityKeyNames; import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention; import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.observation.conventions.AiOperationType; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.beans.factory.annotation.Autowired; @@ -67,7 +66,7 @@ void beforeEach() { @Test void observationForChatOperation() { - var options = ToolCallingChatOptions.builder() + var options = BedrockChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) @@ -89,7 +88,7 @@ void observationForChatOperation() { @Test void observationForStreamingChatOperation() { - var options = ToolCallingChatOptions.builder() + var options = BedrockChatOptions.builder() .model("anthropic.claude-3-5-sonnet-20240620-v1:0") .maxTokens(2048) .stopSequences(List.of("this-is-the-end")) @@ -173,7 +172,7 @@ public BedrockProxyChatModel bedrockConverseChatModel(ObservationRegistry observ .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) .region(Region.US_EAST_1) .observationRegistry(observationRegistry) - .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) + .defaultOptions(BedrockChatOptions.builder().model(modelId).build()) .build(); } diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java index 61d3b1da882..fa99516ff16 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/client/BedrockNovaChatClientIT.java @@ -27,6 +27,7 @@ import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.bedrock.converse.BedrockChatOptions; import reactor.core.publisher.Flux; import software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -268,7 +269,7 @@ public BedrockProxyChatModel bedrockConverseChatModel() { .credentialsProvider(EnvironmentVariableCredentialsProvider.create()) .region(Region.US_EAST_1) .timeout(Duration.ofSeconds(120)) - .defaultOptions(ToolCallingChatOptions.builder().model(modelId).build()) + .defaultOptions(BedrockChatOptions.builder().model(modelId).build()) .build(); } From 71acd70b232f11fda3da0239b69be7f398183f19 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Thu, 10 Jul 2025 20:48:58 +0800 Subject: [PATCH 2/2] fix: Adjust the parameter format to camel case style. Signed-off-by: Sun Yuhan --- .../ai/bedrock/converse/BedrockChatOptions.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java index d6fe4bde1fd..776cba66d58 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java @@ -44,25 +44,25 @@ public class BedrockChatOptions implements ToolCallingChatOptions { @JsonProperty("model") private String model; - @JsonProperty("frequency_penalty") + @JsonProperty("frequencyPenalty") private Double frequencyPenalty; - @JsonProperty("max_tokens") + @JsonProperty("maxTokens") private Integer maxTokens; - @JsonProperty("presence_penalty") + @JsonProperty("presencePenalty") private Double presencePenalty; - @JsonProperty("stop_sequences") + @JsonProperty("stopSequences") private List stopSequences; @JsonProperty("temperature") private Double temperature; - @JsonProperty("top_k") + @JsonProperty("topK") private Integer topK; - @JsonProperty("top_p") + @JsonProperty("topP") private Double topP; @JsonIgnore