From 8b199c9793134d7ad550702f9720f0418301e8ca Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Fri, 2 May 2025 17:38:01 +0100 Subject: [PATCH] Make Gemini API version configurable --- .../gemini/aiservices/GeminiResource.java | 2 +- ...hatLanguageModelAuthProviderSmokeTest.java | 2 +- .../AiGeminiChatLanguageModelSmokeTest.java | 2 +- ...eminiChatLanguageModelV1BetaSmokeTest.java | 112 +++++++++++++++++ ...niEmbeddingModelAuthProviderSmokeTest.java | 4 +- .../AiGeminiEmbeddingModelSmokeTest.java | 4 +- ...AiGeminiEmbeddingModelV1BetaSmokeTest.java | 115 ++++++++++++++++++ .../gemini/AiGeminiChatLanguageModel.java | 10 ++ .../gemini/AiGeminiEmbeddingModel.java | 10 ++ .../ai/runtime/gemini/AiGeminiRecorder.java | 2 + .../ai/runtime/gemini/AiGeminiRestApi.java | 2 +- .../config/LangChain4jAiGeminiConfig.java | 6 + .../VertexAiGeminiChatLanguageModel.java | 10 ++ .../gemini/VertexAiGeminiEmbeddingModel.java | 10 ++ .../gemini/VertexAiGeminiRecorder.java | 2 + .../runtime/gemini/VertxAiGeminiRestApi.java | 2 +- .../LangChain4jVertexAiGeminiConfig.java | 6 + 17 files changed, 292 insertions(+), 9 deletions(-) create mode 100644 model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelV1BetaSmokeTest.java create mode 100644 model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelV1BetaSmokeTest.java diff --git a/integration-tests/ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java b/integration-tests/ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java index 8f0b621a8..46274dc80 100644 --- a/integration-tests/ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java +++ b/integration-tests/ai-gemini/src/main/java/org/acme/example/gemini/aiservices/GeminiResource.java @@ -16,7 +16,7 @@ public class GeminiResource { @POST - @Path("v1beta/models/gemini-1.5-flash:generateContent") + @Path("v1/models/gemini-1.5-flash:generateContent") @Produces("application/json") @Consumes("application/json") public String generateResponse(String generateRequest, @RestQuery String key) { diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelAuthProviderSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelAuthProviderSmokeTest.java index 130dfe2cc..ab6079c5c 100644 --- a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelAuthProviderSmokeTest.java +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelAuthProviderSmokeTest.java @@ -43,7 +43,7 @@ void test() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:generateContent", CHAT_MODEL_ID))) + String.format("/v1/models/%s:generateContent", CHAT_MODEL_ID))) .withHeader("Authorization", equalTo("Bearer " + API_KEY)) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelSmokeTest.java index 6d9f7f222..288e9bfed 100644 --- a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelSmokeTest.java +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelSmokeTest.java @@ -41,7 +41,7 @@ void test() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:generateContent?key=%s", + String.format("/v1/models/%s:generateContent?key=%s", CHAT_MODEL_ID, API_KEY))) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelV1BetaSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelV1BetaSmokeTest.java new file mode 100644 index 000000000..4936137cd --- /dev/null +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiChatLanguageModelV1BetaSmokeTest.java @@ -0,0 +1,112 @@ +package io.quarkiverse.langchain4j.ai.gemini.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.github.tomakehurst.wiremock.verification.LoggedRequest; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import io.quarkiverse.langchain4j.ai.runtime.gemini.AiGeminiChatLanguageModel; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +public class AiGeminiChatLanguageModelV1BetaSmokeTest extends WiremockAware { + + private static final String API_VERSION = "v1Beta"; + private static final String API_KEY = "dummy"; + private static final String CHAT_MODEL_ID = "gemini-1.5-flash"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.base-url", WiremockAware.wiremockUrlForConfig()) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-version", API_VERSION) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-key", API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.log-requests", "true"); + + @Inject + ChatLanguageModel chatLanguageModel; + + @Test + void test() { + assertThat(ClientProxy.unwrap(chatLanguageModel)).isInstanceOf(AiGeminiChatLanguageModel.class); + + wiremock().register( + post(urlEqualTo( + String.format("/%s/models/%s:generateContent?key=%s", + API_VERSION, CHAT_MODEL_ID, API_KEY))) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Nice to meet you" + } + ] + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.044847902, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.05592617 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.18877223, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.027324531 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15278918, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.045437217 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.15869519, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.036838707 + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 37, + "totalTokenCount": 48 + } + } + """))); + + String response = chatLanguageModel.chat("hello"); + assertThat(response).isEqualTo("Nice to meet you"); + + LoggedRequest loggedRequest = singleLoggedRequest(); + assertThat(loggedRequest.getHeader("User-Agent")).isEqualTo("Quarkus REST Client"); + String requestBody = new String(loggedRequest.getBody()); + assertThat(requestBody).contains("hello"); + } + +} diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelAuthProviderSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelAuthProviderSmokeTest.java index b0c335934..e290b9cb6 100644 --- a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelAuthProviderSmokeTest.java +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelAuthProviderSmokeTest.java @@ -44,7 +44,7 @@ public class AiGeminiEmbeddingModelAuthProviderSmokeTest extends WiremockAware { void testBatch() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:batchEmbedContents", EMBED_MODEL_ID))) + String.format("/v1/models/%s:batchEmbedContents", EMBED_MODEL_ID))) .withHeader("Authorization", equalTo("Bearer " + API_KEY)) .willReturn(aResponse() .withHeader("Content-Type", "application/json") @@ -89,7 +89,7 @@ void test() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:embedContent", EMBED_MODEL_ID))) + String.format("/v1/models/%s:embedContent", EMBED_MODEL_ID))) .withHeader("Authorization", equalTo("Bearer " + API_KEY)) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelSmokeTest.java index 8998f0fca..6482b9b9e 100644 --- a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelSmokeTest.java +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelSmokeTest.java @@ -42,7 +42,7 @@ public class AiGeminiEmbeddingModelSmokeTest extends WiremockAware { void testBatch() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:batchEmbedContents?key=%s", + String.format("/v1/models/%s:batchEmbedContents?key=%s", EMBED_MODEL_ID, API_KEY))) .willReturn(aResponse() .withHeader("Content-Type", "application/json") @@ -87,7 +87,7 @@ void test() { wiremock().register( post(urlEqualTo( - String.format("/v1beta/models/%s:embedContent?key=%s", + String.format("/v1/models/%s:embedContent?key=%s", EMBED_MODEL_ID, API_KEY))) .willReturn(aResponse() .withHeader("Content-Type", "application/json") diff --git a/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelV1BetaSmokeTest.java b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelV1BetaSmokeTest.java new file mode 100644 index 000000000..8b45ec673 --- /dev/null +++ b/model-providers/google/gemini/ai-gemini/deployment/src/test/java/io/quarkiverse/langchain4j/ai/gemini/deployment/AiGeminiEmbeddingModelV1BetaSmokeTest.java @@ -0,0 +1,115 @@ +package io.quarkiverse.langchain4j.ai.gemini.deployment; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import jakarta.inject.Inject; + +import org.jboss.shrinkwrap.api.ShrinkWrap; +import org.jboss.shrinkwrap.api.spec.JavaArchive; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.output.Response; +import io.quarkiverse.langchain4j.ai.runtime.gemini.AiGeminiEmbeddingModel; +import io.quarkiverse.langchain4j.testing.internal.WiremockAware; +import io.quarkus.arc.ClientProxy; +import io.quarkus.test.QuarkusUnitTest; + +public class AiGeminiEmbeddingModelV1BetaSmokeTest extends WiremockAware { + + private static final String API_VERSION = "v1Beta"; + private static final String API_KEY = "dummy"; + private static final String EMBED_MODEL_ID = "text-embedding-004"; + + @RegisterExtension + static final QuarkusUnitTest unitTest = new QuarkusUnitTest() + .setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.base-url", WiremockAware.wiremockUrlForConfig()) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-version", API_VERSION) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.api-key", API_KEY) + .overrideRuntimeConfigKey("quarkus.langchain4j.ai.gemini.log-requests", "true"); + + @Inject + EmbeddingModel embeddingModel; + + @Test + void testBatch() { + wiremock().register( + post(urlEqualTo( + String.format("/%s/models/%s:batchEmbedContents?key=%s", + API_VERSION, EMBED_MODEL_ID, API_KEY))) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "embeddings": [ + { + "values": [ + -0.010632273, + 0.019375853, + 0.020965198, + 0.0007706437, + -0.061464068, + -0.007153866, + -0.028534686 + ] + }, + { + "values": [ + 0.018468002, + 0.0054281265, + -0.017658807, + 0.013859263, + 0.05341865, + 0.026714388, + 0.0018762478 + ] + } + ] + } + """))); + + List textSegments = List.of(TextSegment.from("Hello"), TextSegment.from("Bye")); + Response> response = embeddingModel.embedAll(textSegments); + + assertThat(response.content()).hasSize(2); + } + + @Test + void test() { + assertThat(ClientProxy.unwrap(embeddingModel)).isInstanceOf(AiGeminiEmbeddingModel.class); + + wiremock().register( + post(urlEqualTo( + String.format("/%s/models/%s:embedContent?key=%s", + API_VERSION, EMBED_MODEL_ID, API_KEY))) + .willReturn(aResponse() + .withHeader("Content-Type", "application/json") + .withBody(""" + { + "embedding": { + "values": [ + 0.013168517, + -0.00871193, + -0.046782672, + 0.00069969177, + -0.009518872, + -0.008720178, + 0.06010358 + ] + } + } + """))); + + float[] response = embeddingModel.embed("Hello World").content().vector(); + assertThat(response).hasSize(7); + } +} diff --git a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiChatLanguageModel.java b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiChatLanguageModel.java index 3a81b8bd3..2b825c907 100644 --- a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiChatLanguageModel.java +++ b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiChatLanguageModel.java @@ -35,6 +35,10 @@ private AiGeminiChatLanguageModel(Builder builder) { try { String baseUrl = builder.baseUrl.orElse("https://generativelanguage.googleapis.com"); + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + baseUrl += builder.apiVersion; var restApiBuilder = QuarkusRestClientBuilder.newBuilder() .baseUri(new URI(baseUrl)) .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) @@ -70,6 +74,7 @@ public static final class Builder { private String configName; private Optional baseUrl = Optional.empty(); + private String apiVersion; private String modelId; private String key; private Double temperature; @@ -92,6 +97,11 @@ public Builder baseUrl(Optional baseUrl) { return this; } + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + public Builder key(String key) { this.key = key; return this; diff --git a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiEmbeddingModel.java b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiEmbeddingModel.java index a61484174..52123abab 100644 --- a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiEmbeddingModel.java +++ b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiEmbeddingModel.java @@ -32,6 +32,10 @@ public AiGeminiEmbeddingModel(Builder builder) { try { String baseUrl = builder.baseUrl.orElse("https://generativelanguage.googleapis.com"); + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + baseUrl += builder.apiVersion; var restApiBuilder = QuarkusRestClientBuilder.newBuilder() .baseUri(new URI(baseUrl)) .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) @@ -68,6 +72,7 @@ public static Builder builder() { public static final class Builder { private String configName; private Optional baseUrl = Optional.empty(); + private String apiVersion; private String modelId; private String key; private Integer dimension; @@ -81,6 +86,11 @@ public Builder configName(String configName) { return this; } + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + public Builder baseUrl(Optional baseUrl) { this.baseUrl = baseUrl; return this; diff --git a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRecorder.java b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRecorder.java index 161de8ac1..e802b5172 100644 --- a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRecorder.java +++ b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRecorder.java @@ -41,6 +41,7 @@ public Function, EmbeddingModel> embe var builder = AiGeminiEmbeddingModel.builder() .configName(configName) .baseUrl(baseUrl) + .apiVersion(aiConfig.apiVersion()) .key(apiKey) .modelId(embeddingModelConfig.modelId()) .logRequests(firstOrDefault(false, embeddingModelConfig.logRequests(), aiConfig.logRequests())) @@ -84,6 +85,7 @@ public Function, ChatLanguageModel String apiKey = aiConfig.apiKey().orElse(null); var builder = AiGeminiChatLanguageModel.builder() .baseUrl(baseUrl) + .apiVersion(aiConfig.apiVersion()) .key(apiKey) .modelId(chatModelConfig.modelId()) .maxOutputTokens(chatModelConfig.maxOutputTokens()) diff --git a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRestApi.java b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRestApi.java index 827be3fcf..d1f45c24d 100644 --- a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRestApi.java +++ b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/AiGeminiRestApi.java @@ -27,7 +27,7 @@ import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpClientResponse; -@Path("v1beta/models/") +@Path("models/") public interface AiGeminiRestApi { @Path("{modelId}:batchEmbedContents") diff --git a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/config/LangChain4jAiGeminiConfig.java b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/config/LangChain4jAiGeminiConfig.java index fcec7f563..38bd4a077 100644 --- a/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/config/LangChain4jAiGeminiConfig.java +++ b/model-providers/google/gemini/ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/ai/runtime/gemini/config/LangChain4jAiGeminiConfig.java @@ -53,6 +53,12 @@ interface AiGeminiConfig { */ Optional baseUrl(); + /** + * The API version to use for this operation. + */ + @WithDefault("v1") + String apiVersion(); + /** * Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Vertex AI Gemini * provider. diff --git a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java index cd2f833d2..5fded8e40 100644 --- a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java +++ b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiChatLanguageModel.java @@ -37,6 +37,10 @@ private VertexAiGeminiChatLanguageModel(Builder builder) { try { String baseUrl = builder.baseUrl.orElse(String.format("https://%s-aiplatform.googleapis.com", builder.location)); + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + baseUrl += builder.apiVersion; var restApiBuilder = QuarkusRestClientBuilder.newBuilder() .baseUri(new URI(baseUrl)) .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) @@ -66,6 +70,7 @@ public static Builder builder() { public static final class Builder { private Optional baseUrl = Optional.empty(); + private String apiVersion; private String projectId; private String location; private String modelId; @@ -85,6 +90,11 @@ public Builder baseUrl(Optional baseUrl) { return this; } + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + public Builder projectId(String projectId) { this.projectId = projectId; return this; diff --git a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiEmbeddingModel.java b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiEmbeddingModel.java index 7ba89f2dd..d465429d9 100644 --- a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiEmbeddingModel.java +++ b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiEmbeddingModel.java @@ -34,6 +34,10 @@ public VertexAiGeminiEmbeddingModel(Builder builder) { try { String baseUrl = builder.baseUrl.orElse("https://generativelanguage.googleapis.com"); + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + baseUrl += builder.apiVersion; var restApiBuilder = QuarkusRestClientBuilder.newBuilder() .baseUri(new URI(baseUrl)) .connectTimeout(builder.timeout.toSeconds(), TimeUnit.SECONDS) @@ -68,6 +72,7 @@ public static Builder builder() { public static final class Builder { private Optional baseUrl = Optional.empty(); + private String apiVersion; private String projectId; private String location; private String modelId; @@ -83,6 +88,11 @@ public Builder baseUrl(Optional baseUrl) { return this; } + public Builder apiVersion(String apiVersion) { + this.apiVersion = apiVersion; + return this; + } + public Builder projectId(String projectId) { this.projectId = projectId; return this; diff --git a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java index 19a96739a..f0db57731 100644 --- a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java +++ b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertexAiGeminiRecorder.java @@ -44,6 +44,7 @@ public Supplier embeddingModel(LangChain4jVertexAiGeminiConfig c } var builder = VertexAiGeminiEmbeddingModel.builder() .baseUrl(baseUrl) + .apiVersion(vertexAiConfig.apiVersion()) .location(location) .projectId(projectId) .publisher(vertexAiConfig.publisher()) @@ -83,6 +84,7 @@ public Function, ChatLanguageModel } var builder = VertexAiGeminiChatLanguageModel.builder() .baseUrl(baseUrl) + .apiVersion(vertexAiConfig.apiVersion()) .location(location) .projectId(projectId) .publisher(vertexAiConfig.publisher()) diff --git a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java index a14708f15..055175eaa 100644 --- a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java +++ b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/VertxAiGeminiRestApi.java @@ -30,7 +30,7 @@ import io.vertx.core.http.HttpClientRequest; import io.vertx.core.http.HttpClientResponse; -@Path("v1/projects/{projectId}/locations/{location}/publishers/{publisher}/models") +@Path("projects/{projectId}/locations/{location}/publishers/{publisher}/models") public interface VertxAiGeminiRestApi { @Path("{modelId}:generateContent") diff --git a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java index 09b1e2ee9..4c95014b5 100644 --- a/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java +++ b/model-providers/google/vertex-ai-gemini/runtime/src/main/java/io/quarkiverse/langchain4j/vertexai/runtime/gemini/config/LangChain4jVertexAiGeminiConfig.java @@ -60,6 +60,12 @@ interface VertexAiGeminiConfig { */ Optional baseUrl(); + /** + * The API version to use for this operation. + */ + @WithDefault("v1") + String apiVersion(); + /** * Whether to enable the integration. Defaults to {@code true}, which means requests are made to the Vertex AI Gemini * provider.