Skip to content

Commit 998cdd5

Browse files
leo-hoetlhoet-googlejonathan-buttner
authored andcommitted
Implemented completion task for Google VertexAI (#128694)
* Google Vertex AI completion model, response entity and tests * Fixed GoogleVertexAiServiceTest for Service configuration * Changelog * Removed downcasting and using `moveToFirstToken` * Create GoogleVertexAiChatCompletionResponseHandler for streaming and non-streaming responses * Added unit tests * PR feedback * Removed googlevertexaicompletion model. Using just GoogleVertexAiChatCompletionModel for completion and chat completion * Renamed uri -> nonStreamingUri. Added streamingUri and getters in GoogleVertexAiChatCompletionModel * Moved rateLimitGroupHashing to subclasses of GoogleVertexAiModel * Fixed rate limit hash of GoogleVertexAiRerankModel and refactored uri for GoogleVertexAiUnifiedChatCompletionRequest --------- Co-authored-by: lhoet-google <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 1289922 commit 998cdd5

23 files changed

+520
-56
lines changed

docs/changelog/128694.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 128694
2+
summary: "Adding Google VertexAI completion integration"
3+
area: Inference
4+
type: enhancement
5+
issues: [ ]

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public abstract class GoogleVertexAiModel extends RateLimitGroupingModel {
2424

2525
private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
2626

27-
protected URI uri;
27+
protected URI nonStreamingUri;
2828

2929
public GoogleVertexAiModel(
3030
ModelConfigurations configurations,
@@ -39,14 +39,14 @@ public GoogleVertexAiModel(
3939
public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
4040
super(model, serviceSettings);
4141

42-
uri = model.uri();
42+
nonStreamingUri = model.nonStreamingUri();
4343
rateLimitServiceSettings = model.rateLimitServiceSettings();
4444
}
4545

4646
public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
4747
super(model, taskSettings);
4848

49-
uri = model.uri();
49+
nonStreamingUri = model.nonStreamingUri();
5050
rateLimitServiceSettings = model.rateLimitServiceSettings();
5151
}
5252

@@ -56,17 +56,8 @@ public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
5656
return rateLimitServiceSettings;
5757
}
5858

59-
public URI uri() {
60-
return uri;
61-
}
62-
63-
@Override
64-
public int rateLimitGroupingHash() {
65-
// In VertexAI rate limiting is scoped to the project, region and model. URI already has this information so we are using that.
66-
// API Key does not affect the quota
67-
// https://ai.google.dev/gemini-api/docs/rate-limits
68-
// https://cloud.google.com/vertex-ai/docs/quotas
69-
return Objects.hash(uri);
59+
public URI nonStreamingUri() {
60+
return nonStreamingUri;
7061
}
7162

7263
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiResponseHandler.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,19 @@
77

88
package org.elasticsearch.xpack.inference.services.googlevertexai;
99

10+
import org.elasticsearch.inference.InferenceServiceResults;
11+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
1012
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1113
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
1214
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1315
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1416
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1517
import org.elasticsearch.xpack.inference.external.request.Request;
18+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
19+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
1620
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiErrorResponseEntity;
1721

22+
import java.util.concurrent.Flow;
1823
import java.util.function.Function;
1924

2025
import static org.elasticsearch.core.Strings.format;
@@ -66,4 +71,14 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr
6671
private static String resourceNotFoundError(Request request) {
6772
return format("Resource not found at [%s]", request.getURI());
6873
}
74+
75+
@Override
76+
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
77+
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
78+
var googleVertexAiProcessor = new GoogleVertexAiStreamingProcessor();
79+
80+
flow.subscribe(serverSentEventProcessor);
81+
serverSentEventProcessor.subscribe(googleVertexAiProcessor);
82+
return new StreamingChatCompletionResults(googleVertexAiProcessor);
83+
}
6984
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiSecretSettings.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ public static Map<String, SettingsConfiguration> get() {
124124
var configurationMap = new HashMap<String, SettingsConfiguration>();
125125
configurationMap.put(
126126
SERVICE_ACCOUNT_JSON,
127-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION))
128-
.setDescription("API Key for the provider you're connecting to.")
127+
new SettingsConfiguration.Builder(
128+
EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.CHAT_COMPLETION, TaskType.COMPLETION)
129+
).setDescription("API Key for the provider you're connecting to.")
129130
.setLabel("Credentials JSON")
130131
.setRequired(true)
131132
.setSensitive(true)

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ public class GoogleVertexAiService extends SenderService {
7575
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
7676
TaskType.TEXT_EMBEDDING,
7777
TaskType.RERANK,
78-
TaskType.CHAT_COMPLETION
78+
TaskType.CHAT_COMPLETION,
79+
TaskType.COMPLETION
7980
);
8081

8182
public static final EnumSet<InputType> VALID_INPUT_TYPE_VALUES = EnumSet.of(
@@ -87,13 +88,13 @@ public class GoogleVertexAiService extends SenderService {
8788
InputType.INTERNAL_SEARCH
8889
);
8990

90-
private final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
91+
public static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
9192
"Google VertexAI chat completion"
9293
);
9394

9495
@Override
9596
public Set<TaskType> supportedStreamingTasks() {
96-
return EnumSet.of(TaskType.CHAT_COMPLETION);
97+
return EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.COMPLETION);
9798
}
9899

99100
public GoogleVertexAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
@@ -358,7 +359,7 @@ private static GoogleVertexAiModel createModel(
358359
context
359360
);
360361

361-
case CHAT_COMPLETION -> new GoogleVertexAiChatCompletionModel(
362+
case CHAT_COMPLETION, COMPLETION -> new GoogleVertexAiChatCompletionModel(
362363
inferenceEntityId,
363364
taskType,
364365
NAME,
@@ -396,10 +397,11 @@ public static InferenceServiceConfiguration get() {
396397

397398
configurationMap.put(
398399
LOCATION,
399-
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION)).setDescription(
400-
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
401-
+ "For more information, refer to the {geminiVertexAIDocs}."
402-
)
400+
new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.COMPLETION))
401+
.setDescription(
402+
"Please provide the GCP region where the Vertex AI API(s) is enabled. "
403+
+ "For more information, refer to the {geminiVertexAIDocs}."
404+
)
403405
.setLabel("GCP Region")
404406
.setRequired(true)
405407
.setSensitive(false)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.googlevertexai;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.common.Strings;
12+
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xcontent.XContentFactory;
16+
import org.elasticsearch.xcontent.XContentParser;
17+
import org.elasticsearch.xcontent.XContentParserConfiguration;
18+
import org.elasticsearch.xcontent.XContentType;
19+
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
20+
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
21+
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
22+
23+
import java.io.IOException;
24+
import java.util.Deque;
25+
import java.util.Objects;
26+
import java.util.stream.Stream;
27+
28+
public class GoogleVertexAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {
29+
30+
@Override
31+
protected void next(Deque<ServerSentEvent> item) throws Exception {
32+
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
33+
var results = parseEvent(item, GoogleVertexAiStreamingProcessor::parse, parserConfig);
34+
35+
if (results.isEmpty()) {
36+
upstream().request(1);
37+
} else {
38+
downstream().onNext(new StreamingChatCompletionResults.Results(results));
39+
}
40+
}
41+
42+
public static Stream<StreamingChatCompletionResults.Result> parse(XContentParserConfiguration parserConfig, ServerSentEvent event) {
43+
String data = event.data();
44+
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, data)) {
45+
var chunk = GoogleVertexAiUnifiedStreamingProcessor.GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);
46+
47+
return chunk.choices()
48+
.stream()
49+
.map(choice -> choice.delta())
50+
.filter(Objects::nonNull)
51+
.map(delta -> delta.content())
52+
.filter(content -> Strings.isNullOrEmpty(content) == false)
53+
.map(StreamingChatCompletionResults.Result::new);
54+
55+
} catch (IOException e) {
56+
throw new ElasticsearchStatusException(
57+
"Failed to parse event from inference provider: {}",
58+
RestStatus.INTERNAL_SERVER_ERROR,
59+
e,
60+
event
61+
);
62+
}
63+
}
64+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
2424
import org.elasticsearch.xpack.inference.external.http.HttpResult;
2525
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
26-
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
2726
import org.elasticsearch.xpack.inference.external.request.Request;
2827
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
2928
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
29+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
3030

3131
import java.nio.charset.StandardCharsets;
3232
import java.util.Locale;
@@ -43,10 +43,8 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
4343
private static final String ERROR_MESSAGE_FIELD = "message";
4444
private static final String ERROR_STATUS_FIELD = "status";
4545

46-
private static final ResponseParser noopParseFunction = (a, b) -> null;
47-
4846
public GoogleVertexAiUnifiedChatCompletionResponseHandler(String requestType) {
49-
super(requestType, noopParseFunction, GoogleVertexAiErrorResponse::fromResponse, true);
47+
super(requestType, GoogleVertexAiCompletionResponseEntity::fromResponse, GoogleVertexAiErrorResponse::fromResponse, true);
5048
}
5149

5250
@Override
@@ -64,6 +62,7 @@ public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpR
6462
@Override
6563
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
6664
assert request.isStreaming() : "Only streaming requests support this format";
65+
6766
var responseStatusCode = result.response().getStatusLine().getStatusCode();
6867
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
6968
var restStatus = toRestStatus(responseStatusCode);
@@ -111,7 +110,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex
111110
}
112111
}
113112

114-
private static class GoogleVertexAiErrorResponse extends ErrorResponse {
113+
public static class GoogleVertexAiErrorResponse extends ErrorResponse {
115114
private static final Logger logger = LogManager.getLogger(GoogleVertexAiErrorResponse.class);
116115
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
117116
"google_vertex_ai_error_wrapper",
@@ -138,7 +137,7 @@ private static class GoogleVertexAiErrorResponse extends ErrorResponse {
138137
);
139138
}
140139

141-
static ErrorResponse fromResponse(HttpResult response) {
140+
public static ErrorResponse fromResponse(HttpResult response) {
142141
try (
143142
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
144143
.createParser(XContentParserConfiguration.EMPTY, response.body())

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionCreator.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1919
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiEmbeddingsRequestManager;
2020
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiRerankRequestManager;
21+
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiResponseHandler;
2122
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiUnifiedChatCompletionResponseHandler;
2223
import org.elasticsearch.xpack.inference.services.googlevertexai.completion.GoogleVertexAiChatCompletionModel;
2324
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
2425
import org.elasticsearch.xpack.inference.services.googlevertexai.request.GoogleVertexAiUnifiedChatCompletionRequest;
2526
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
27+
import org.elasticsearch.xpack.inference.services.googlevertexai.response.GoogleVertexAiCompletionResponseEntity;
2628

2729
import java.util.Map;
2830
import java.util.Objects;
@@ -36,9 +38,13 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor
3638

3739
private final ServiceComponents serviceComponents;
3840

39-
static final ResponseHandler COMPLETION_HANDLER = new GoogleVertexAiUnifiedChatCompletionResponseHandler(
40-
"Google VertexAI chat completion"
41+
static final ResponseHandler CHAT_COMPLETION_HANDLER = new GoogleVertexAiResponseHandler(
42+
"Google VertexAI completion",
43+
GoogleVertexAiCompletionResponseEntity::fromResponse,
44+
GoogleVertexAiUnifiedChatCompletionResponseHandler.GoogleVertexAiErrorResponse::fromResponse,
45+
true
4146
);
47+
4248
static final String USER_ROLE = "user";
4349

4450
public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceComponents) {
@@ -67,12 +73,12 @@ public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Obje
6773

6874
@Override
6975
public ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings) {
70-
7176
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX);
77+
7278
var manager = new GenericRequestManager<>(
7379
serviceComponents.threadPool(),
7480
model,
75-
COMPLETION_HANDLER,
81+
CHAT_COMPLETION_HANDLER,
7682
inputs -> new GoogleVertexAiUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
7783
ChatCompletionInput.class
7884
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiActionVisitor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ public interface GoogleVertexAiActionVisitor {
2121
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
2222

2323
ExecutableAction create(GoogleVertexAiChatCompletionModel model, Map<String, Object> taskSettings);
24+
2425
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/completion/GoogleVertexAiChatCompletionModel.java

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
import static org.elasticsearch.core.Strings.format;
3131

3232
public class GoogleVertexAiChatCompletionModel extends GoogleVertexAiModel {
33+
34+
private final URI streamingURI;
35+
3336
public GoogleVertexAiChatCompletionModel(
3437
String inferenceEntityId,
3538
TaskType taskType,
@@ -63,7 +66,8 @@ public GoogleVertexAiChatCompletionModel(
6366
serviceSettings
6467
);
6568
try {
66-
this.uri = buildUri(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
69+
this.streamingURI = buildUriStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
70+
this.nonStreamingUri = buildUriNonStreaming(serviceSettings.location(), serviceSettings.projectId(), serviceSettings.modelId());
6771
} catch (URISyntaxException e) {
6872
throw new RuntimeException(e);
6973
}
@@ -114,7 +118,28 @@ public GoogleVertexAiSecretSettings getSecretSettings() {
114118
return (GoogleVertexAiSecretSettings) super.getSecretSettings();
115119
}
116120

117-
public static URI buildUri(String location, String projectId, String model) throws URISyntaxException {
121+
public URI streamingURI() {
122+
return this.streamingURI;
123+
}
124+
125+
public static URI buildUriNonStreaming(String location, String projectId, String model) throws URISyntaxException {
126+
return new URIBuilder().setScheme("https")
127+
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
128+
.setPathSegments(
129+
GoogleVertexAiUtils.V1,
130+
GoogleVertexAiUtils.PROJECTS,
131+
projectId,
132+
GoogleVertexAiUtils.LOCATIONS,
133+
GoogleVertexAiUtils.GLOBAL,
134+
GoogleVertexAiUtils.PUBLISHERS,
135+
GoogleVertexAiUtils.PUBLISHER_GOOGLE,
136+
GoogleVertexAiUtils.MODELS,
137+
format("%s:%s", model, GoogleVertexAiUtils.GENERATE_CONTENT)
138+
)
139+
.build();
140+
}
141+
142+
public static URI buildUriStreaming(String location, String projectId, String model) throws URISyntaxException {
118143
return new URIBuilder().setScheme("https")
119144
.setHost(format("%s%s", location, GoogleVertexAiUtils.GOOGLE_VERTEX_AI_HOST_SUFFIX))
120145
.setPathSegments(
@@ -131,4 +156,25 @@ public static URI buildUri(String location, String projectId, String model) thro
131156
.setCustomQuery(GoogleVertexAiUtils.QUERY_PARAM_ALT_SSE)
132157
.build();
133158
}
159+
160+
@Override
161+
public int rateLimitGroupingHash() {
162+
// In VertexAI rate limiting is scoped to the project, region, model and endpoint.
163+
// API Key does not affect the quota
164+
// https://ai.google.dev/gemini-api/docs/rate-limits
165+
// https://cloud.google.com/vertex-ai/docs/quotas
166+
var projectId = getServiceSettings().projectId();
167+
var location = getServiceSettings().location();
168+
var modelId = getServiceSettings().modelId();
169+
170+
// Since we don't know beforehand which API is going to be used, we take a conservative approach and
171+
// count both endpoints for the rate limit
172+
return Objects.hash(
173+
projectId,
174+
location,
175+
modelId,
176+
GoogleVertexAiUtils.GENERATE_CONTENT,
177+
GoogleVertexAiUtils.STREAM_GENERATE_CONTENT
178+
);
179+
}
134180
}

0 commit comments

Comments
 (0)