Skip to content

Refactor response handlers to improve error handling and streamline mid-stream error processing #128923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;

import java.util.Locale;
import java.util.Objects;
import java.util.function.Function;

Expand All @@ -34,17 +36,22 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String SERVER_ERROR_OBJECT = "Received an error response";
public static final String BAD_REQUEST = "Received a bad request status code";
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";
protected static final String STREAM_ERROR = "stream_error";

protected final String requestType;
protected final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

public BaseResponseHandler(String requestType, ResponseParser parseFunction, Function<HttpResult, ErrorResponse> errorParseFunction) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor is only used by subclasses, so it is now protected.

protected BaseResponseHandler(
String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction
) {
this(requestType, parseFunction, errorParseFunction, false);
}

public BaseResponseHandler(
protected BaseResponseHandler(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor is only used by subclasses, so it is now protected.

String requestType,
ResponseParser parseFunction,
Function<HttpResult, ErrorResponse> errorParseFunction,
Expand Down Expand Up @@ -109,19 +116,136 @@ private void checkForErrorObject(Request request, HttpResult result) {
}

protected Exception buildError(String message, Request request, HttpResult result) {
var errorEntityMsg = errorParseFunction.apply(result);
return buildError(message, request, result, errorEntityMsg);
var errorResponse = errorParseFunction.apply(result);
return buildError(message, request, result, errorResponse);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

errorParseFunction creates an ErrorResponse instance, not a message, so the variable name errorEntityMsg was misleading here. Renamed it to errorResponse.

}

protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
var responseStatusCode = result.response().getStatusLine().getStatusCode();
return new ElasticsearchStatusException(
errorMessage(message, request, result, errorResponse, responseStatusCode),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the unused `result` parameter.

extractErrorMessage(message, request, errorResponse, responseStatusCode),
toRestStatus(responseStatusCode)
);
}

protected String errorMessage(String message, Request request, HttpResult result, ErrorResponse errorResponse, int statusCode) {
/**
 * Builds the exception for an error response received from the external service
 * while handling a streaming request.
 * <p>
 * If the parsed error response reports that an error structure was found and it implements
 * {@link UnifiedChatCompletionExceptionConvertible}, the response builds the exception itself.
 * In every other case, including a {@code null} error response, the result comes from
 * {@link #buildDefaultChatCompletionError(ErrorResponse, String, RestStatus)}.
 * Only streaming requests support this format.
 *
 * @param message the error message to include in the exception
 * @param request the request that caused the error; asserted to be a streaming request
 * @param result the HTTP result containing the error response; its status code determines the REST status
 * @param errorResponse the parsed error response from the HTTP result, may be {@code null}
 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
 */
protected UnifiedChatCompletionException buildChatCompletionError(
    String message,
    Request request,
    HttpResult result,
    ErrorResponse errorResponse
) {
    assert request.isStreaming() : "Only streaming requests support this format";
    var statusCode = result.response().getStatusLine().getStatusCode();
    var errorMessage = extractErrorMessage(message, request, errorResponse, statusCode);
    var restStatus = toRestStatus(statusCode);

    // The instanceof check goes first because it is null-safe: a null errorResponse fails the
    // pattern match and falls through to the default error instead of throwing a NullPointerException.
    if (errorResponse instanceof UnifiedChatCompletionExceptionConvertible chatCompletionExceptionConvertible
        && errorResponse.errorStructureFound()) {
        return chatCompletionExceptionConvertible.toUnifiedChatCompletionException(errorMessage, restStatus);
    }
    return buildDefaultChatCompletionError(errorResponse, errorMessage, restStatus);
}

/**
 * Creates the fallback {@link UnifiedChatCompletionException} for a streaming request.
 * It is used when an error response was received but the response cannot build the exception itself.
 * The error type is derived from the error response, and the error code is the lower-cased
 * name of the REST status.
 * Only streaming requests should use this method.
 *
 * @param errorResponse the error response parsed from the HTTP result
 * @param errorMessage the error message to include in the exception
 * @param restStatus the REST status code of the response
 * @return an instance of {@link UnifiedChatCompletionException} with details from the error response
 */
protected static UnifiedChatCompletionException buildDefaultChatCompletionError(
    ErrorResponse errorResponse,
    String errorMessage,
    RestStatus restStatus
) {
    var errorType = createErrorType(errorResponse);
    var errorCode = restStatus.name().toLowerCase(Locale.ROOT);
    return new UnifiedChatCompletionException(restStatus, errorMessage, errorType, errorCode);
}

/**
* Builds a mid-stream error for a streaming request with a custom error type.
* This method is used when an error occurs while processing a streaming response and allows for custom error handling.
* Only streaming requests should use this method.
*
* @param inferenceEntityId the ID of the inference entity
* @param message the error message
* @param e the exception that caused the error, can be null
* @param midStreamErrorExtractor a function that extracts the mid-stream error response from the message
* @return a {@link UnifiedChatCompletionException} representing the mid-stream error
*/
protected UnifiedChatCompletionException buildMidStreamChatCompletionError(
String inferenceEntityId,
String message,
Exception e,
Function<String, ErrorResponse> midStreamErrorExtractor
) {
// Extract the error response from the message using the provided method
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if commenting here is excessive.

var error = midStreamErrorExtractor.apply(message);
// Check if the error response matches the expected type
if (error.errorStructureFound() && error instanceof MidStreamUnifiedChatCompletionExceptionConvertible midStreamError) {
// If it matches, we can build a custom mid-stream error exception
return midStreamError.toUnifiedChatCompletionException(inferenceEntityId);
} else if (e != null) {
// If the error response does not match, we can still return an exception based on the original throwable
return UnifiedChatCompletionException.fromThrowable(e);
} else {
// If no specific error response is found, we return a default mid-stream error
return buildDefaultMidStreamChatCompletionError(inferenceEntityId, error);
}
}

/**
 * Creates the fallback mid-stream error for a streaming request.
 * The exception always carries {@link RestStatus#INTERNAL_SERVER_ERROR}, a message naming the
 * inference entity, an error type derived from the error response, and the
 * {@code stream_error} code.
 * Only streaming requests should use this method.
 *
 * @param inferenceEntityId the ID of the inference entity
 * @param errorResponse the error response extracted from the message
 * @return a {@link UnifiedChatCompletionException} representing the default mid-stream error
 */
protected static UnifiedChatCompletionException buildDefaultMidStreamChatCompletionError(
    String inferenceEntityId,
    ErrorResponse errorResponse
) {
    var errorMessage = format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, inferenceEntityId);
    var errorType = createErrorType(errorResponse);
    return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, errorMessage, errorType, STREAM_ERROR);
}

/**
* Creates a string representation of the error type based on the provided ErrorResponse.
* This method is used to generate a human-readable error type for logging or exception messages.
*
* @param errorResponse the ErrorResponse object
* @return a string representing the error type
*/
private static String createErrorType(ErrorResponse errorResponse) {
return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this from several handlers so it can be used directly.

}

protected static String extractErrorMessage(String message, Request request, ErrorResponse errorResponse, int statusCode) {
return (errorResponse == null
|| errorResponse.errorStructureFound() == false
|| Strings.isNullOrEmpty(errorResponse.getErrorMessage()))
Expand All @@ -135,7 +259,7 @@ protected String errorMessage(String message, Request request, HttpResult result
);
}

public static RestStatus toRestStatus(int statusCode) {
protected static RestStatus toRestStatus(int statusCode) {
RestStatus code = null;
if (statusCode < 500) {
code = RestStatus.fromCode(statusCode);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.retry;

import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;

/**
 * Implemented by types that can convert themselves into a
 * {@link UnifiedChatCompletionException} when given the ID of an inference entity.
 * NOTE(review): judging by the name, this is meant for errors that arrive in the middle
 * of a streaming response - confirm against the callers.
 */
public interface MidStreamUnifiedChatCompletionExceptionConvertible {

/**
 * Converts this object into a {@link UnifiedChatCompletionException}.
 *
 * @param inferenceEntityId the ID of the inference entity the request was made for
 * @return the exception representing this object
 */
UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.retry;

import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;

/**
 * Implemented by types that can convert themselves into a
 * {@link UnifiedChatCompletionException} when given an error message and a REST status.
 */
public interface UnifiedChatCompletionExceptionConvertible {

/**
 * Converts this object into a {@link UnifiedChatCompletionException}.
 *
 * @param errorMessage the error message to use for the exception
 * @param restStatus the REST status to use for the exception
 * @return the exception representing this object
 */
UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus);

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,26 @@
package org.elasticsearch.xpack.inference.external.response.streaming;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.MidStreamUnifiedChatCompletionExceptionConvertible;
import org.elasticsearch.xpack.inference.external.http.retry.UnifiedChatCompletionExceptionConvertible;
import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;

import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.SERVER_ERROR_OBJECT;

/**
* Represents an error response from a streaming inference service.
* This class extends {@link ErrorResponse} and provides additional fields
Expand All @@ -38,17 +45,21 @@
* </code></pre>
* TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication.
*/
public class StreamingErrorResponse extends ErrorResponse {
public class OpenAiStreamingChatCompletionErrorResponse extends ErrorResponse
implements
UnifiedChatCompletionExceptionConvertible,
MidStreamUnifiedChatCompletionExceptionConvertible {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> Optional.ofNullable((StreamingErrorResponse) args[0])
);
private static final ConstructingObjectParser<StreamingErrorResponse, Void> ERROR_BODY_PARSER = new ConstructingObjectParser<>(
"streaming_error",
true,
args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
args -> Optional.ofNullable((OpenAiStreamingChatCompletionErrorResponse) args[0])
);
private static final ConstructingObjectParser<OpenAiStreamingChatCompletionErrorResponse, Void> ERROR_BODY_PARSER =
new ConstructingObjectParser<>(
"streaming_error",
true,
args -> new OpenAiStreamingChatCompletionErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3])
);

static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message"));
Expand Down Expand Up @@ -105,13 +116,34 @@ public static ErrorResponse fromString(String response) {
private final String param;
private final String type;

StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
/**
 * Creates an error response from the individual fields of the error body.
 *
 * @param errorMessage the error message, handed to the {@link ErrorResponse} superclass
 * @param code the error code, may be {@code null}
 * @param param the error parameter, may be {@code null}
 * @param type the error type, must not be {@code null}
 * @throws NullPointerException if {@code type} is {@code null}
 */
OpenAiStreamingChatCompletionErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) {
    super(errorMessage);
    // The mandatory field is validated before the optional ones are stored.
    this.type = Objects.requireNonNull(type);
    this.param = param;
    this.code = code;
}

/**
 * Converts this error response into a {@link UnifiedChatCompletionException} that uses the
 * supplied message and status together with this response's type, code and param.
 *
 * @param errorMessage the error message to use for the exception
 * @param restStatus the REST status to use for the exception
 * @return the exception built from this error response
 */
@Override
public UnifiedChatCompletionException toUnifiedChatCompletionException(String errorMessage, RestStatus restStatus) {
    var errorType = type();
    var errorCode = code();
    var errorParam = param();
    return new UnifiedChatCompletionException(restStatus, errorMessage, errorType, errorCode, errorParam);
}

/**
 * Converts this error response into a {@link UnifiedChatCompletionException} with status
 * {@link RestStatus#INTERNAL_SERVER_ERROR}. The message names the inference entity and
 * includes this response's error message; type, code and param are taken from this response.
 *
 * @param inferenceEntityId the ID of the inference entity the request was made for
 * @return the exception built from this error response
 */
@Override
public UnifiedChatCompletionException toUnifiedChatCompletionException(String inferenceEntityId) {
    var detailedMessage = format(
        "%s for request from inference entity id [%s]. Error message: [%s]",
        SERVER_ERROR_OBJECT,
        inferenceEntityId,
        getErrorMessage()
    );
    return new UnifiedChatCompletionException(RestStatus.INTERNAL_SERVER_ERROR, detailedMessage, type(), code(), param());
}

@Nullable
public String code() {
return code;
Expand Down
Loading