[ML] Fix IndexOutOfBoundsException during inference (#109567)
backport of #109533
davidkyle authored Jun 12, 2024
1 parent 7321380 commit 83967a1
Showing 8 changed files with 155 additions and 29 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/109533.yaml
@@ -0,0 +1,5 @@
pr: 109533
summary: Fix IndexOutOfBoundsException during inference
area: Machine Learning
type: bug
issues: []
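Both internal inference services changed below previously registered their model configuration with an empty TrainedModelInput; the fix registers the conventional "text_field" input and enables validation when building the config. As a minimal, hypothetical sketch of the failure mode this guards against (the exact access site that threw is not part of this diff), an empty field-name list breaks any consumer that reads the first input field:

```java
// Hypothetical, self-contained illustration; not code from the Elasticsearch repository.
import java.util.List;

public class EmptyInputFieldExample {
    public static void main(String[] args) {
        List<String> before = List.of();             // what putModel used to register as field names
        List<String> after = List.of("text_field");  // the conventional default registered by this fix

        System.out.println(after.get(0));            // prints "text_field"
        System.out.println(before.get(0));           // throws IndexOutOfBoundsException
    }
}
```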
TrainedModelConfig.java
@@ -636,8 +636,6 @@ public static class Builder {
private InferenceConfig inferenceConfig;
private TrainedModelLocation location;
private ModelPackageConfig modelPackageConfig;
- private Long perDeploymentMemoryBytes;
- private Long perAllocationMemoryBytes;
private String platformArchitecture;
private TrainedModelPrefixStrings prefixStrings;

ElasticsearchInternalService.java
@@ -344,9 +344,8 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
return;
} else if (model instanceof MultilingualE5SmallModel e5Model) {
String modelId = e5Model.getServiceSettings().getModelId();
- var fieldNames = List.<String>of();
- var input = new TrainedModelInput(fieldNames);
- var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).build();
+ var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
+ var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
executeAsyncWithOrigin(
client,
ElserInternalService.java
@@ -351,9 +351,8 @@ public void putModel(Model model, ActionListener<Boolean> listener) {
return;
} else {
String modelId = ((ElserInternalModel) model).getServiceSettings().getModelId();
- var fieldNames = List.<String>of();
- var input = new TrainedModelInput(fieldNames);
- var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).build();
+ var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
+ var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
executeAsyncWithOrigin(
client,
ElasticsearchInternalServiceTests.java
@@ -28,11 +28,17 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Arrays;
@@ -41,6 +47,7 @@
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -58,6 +65,17 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {

TaskType taskType = TaskType.TEXT_EMBEDDING;
String randomInferenceEntityId = randomAlphaOfLength(10);
private static ThreadPool threadPool;

@Before
public void setUpThreadPool() {
threadPool = new TestThreadPool("test");
}

@After
public void shutdownThreadPool() {
TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
}

public void testParseRequestConfig() {

@@ -480,6 +498,44 @@ public void testChunkInferSetsTokenization() {
}
}

@SuppressWarnings("unchecked")
public void testPutModel() {
var client = mock(Client.class);
ArgumentCaptor<PutTrainedModelAction.Request> argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class);

doAnswer(invocation -> {
var listener = (ActionListener<PutTrainedModelAction.Response>) invocation.getArguments()[2];
listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class)));
return null;
}).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any());

when(client.threadPool()).thenReturn(threadPool);

var service = createService(client);

var model = new MultilingualE5SmallModel(
"my-e5",
TaskType.TEXT_EMBEDDING,
"e5",
new MultilingualE5SmallInternalServiceSettings(1, 1, ".multilingual-e5-small")
);

service.putModel(model, new ActionListener<>() {
@Override
public void onResponse(Boolean success) {
assertTrue(success);
}

@Override
public void onFailure(Exception e) {
fail(e);
}
});

var putConfig = argument.getValue().getTrainedModelConfig();
assertEquals("text_field", putConfig.getInput().getFieldNames().get(0));
}

private ElasticsearchInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
return new ElasticsearchInternalService(context);
ElserInternalServiceTests.java
@@ -26,16 +26,23 @@
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
@@ -51,6 +58,18 @@

public class ElserInternalServiceTests extends ESTestCase {

private static ThreadPool threadPool;

@Before
public void setUpThreadPool() {
threadPool = new TestThreadPool("test");
}

@After
public void shutdownThreadPool() {
TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
}

public static Model randomModelConfig(String inferenceEntityId, TaskType taskType) {
return switch (taskType) {
case SPARSE_EMBEDDING -> new ElserInternalModel(
@@ -460,6 +479,45 @@ public void testChunkInferSetsTokenization() {
}
}

@SuppressWarnings("unchecked")
public void testPutModel() {
var client = mock(Client.class);
ArgumentCaptor<PutTrainedModelAction.Request> argument = ArgumentCaptor.forClass(PutTrainedModelAction.Request.class);

doAnswer(invocation -> {
var listener = (ActionListener<PutTrainedModelAction.Response>) invocation.getArguments()[2];
listener.onResponse(new PutTrainedModelAction.Response(mock(TrainedModelConfig.class)));
return null;
}).when(client).execute(Mockito.same(PutTrainedModelAction.INSTANCE), argument.capture(), any());

when(client.threadPool()).thenReturn(threadPool);

var service = createService(client);

var model = new ElserInternalModel(
"my-elser",
TaskType.SPARSE_EMBEDDING,
"elser",
new ElserInternalServiceSettings(1, 1, ".elser_model_2"),
ElserMlNodeTaskSettings.DEFAULT
);

service.putModel(model, new ActionListener<>() {
@Override
public void onResponse(Boolean success) {
assertTrue(success);
}

@Override
public void onFailure(Exception e) {
fail(e);
}
});

var putConfig = argument.getValue().getTrainedModelConfig();
assertEquals("text_field", putConfig.getInput().getFieldNames().get(0));
}

private ElserInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
return new ElserInternalService(context);
TransportPutTrainedModelAction.java
@@ -323,12 +323,12 @@ protected void masterOperation(
}
}, finalResponseListener::onFailure);

- checkForExistingTask(
+ checkForExistingModelDownloadTask(
client,
trainedModelConfig.getModelId(),
request.isWaitForCompletion(),
finalResponseListener,
- handlePackageAndTagsListener,
+ () -> handlePackageAndTagsListener.onResponse(null),
request.timeout()
);
}
@@ -371,14 +371,26 @@ void callVerifyMlNodesAndModelArchitectures(
}

/**
- * This method is package private for testing
+ * Check if the model is being downloaded.
+ * If the download is in progress then the response will be on
+ * the {@code isBeingDownloadedListener} otherwise {@code createModelAction}
+ * is called to trigger the next step in the model install.
+ * Should only be called for Elasticsearch hosted models.
+ *
+ * @param client Client
+ * @param modelId Model Id
+ * @param isWaitForCompletion Wait for the download to complete
+ * @param isBeingDownloadedListener The listener called if the download is in progress
+ * @param createModelAction If no download is in progress this is called to continue
+ * the model install process.
+ * @param timeout Model download timeout
*/
- static void checkForExistingTask(
+ static void checkForExistingModelDownloadTask(
Client client,
String modelId,
boolean isWaitForCompletion,
- ActionListener<Response> sendResponseListener,
- ActionListener<Void> storeModelListener,
+ ActionListener<Response> isBeingDownloadedListener,
+ Runnable createModelAction,
TimeValue timeout
) {
TaskRetriever.getDownloadTaskInfo(
@@ -389,12 +401,12 @@ static void checkForExistingTask(
() -> "Timed out waiting for model download to complete",
ActionListener.wrap(taskInfo -> {
if (taskInfo != null) {
- getModelInformation(client, modelId, sendResponseListener);
+ getModelInformation(client, modelId, isBeingDownloadedListener);
} else {
// no task exists so proceed with creating the model
- storeModelListener.onResponse(null);
+ createModelAction.run();
}
- }, sendResponseListener::onFailure)
+ }, isBeingDownloadedListener::onFailure)
);
}

@@ -554,5 +566,4 @@ static InferenceConfig parseInferenceConfigFromModelPackage(Map<String, Object>
return inferenceConfig;
}
}

}
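The renamed helper above is now documented in terms of two continuations: isBeingDownloadedListener receives the response when a download task already exists, and createModelAction runs when none exists. A hedged caller sketch follows (assuming the same package and imports as TransportPutTrainedModelAction, since the method is package private; the wrapper method name is illustrative):

```java
// Sketch only; it mirrors the masterOperation change above rather than quoting the repository.
static void continueOrReportDownload(
    Client client,
    PutTrainedModelAction.Request request,
    TrainedModelConfig trainedModelConfig,
    ActionListener<PutTrainedModelAction.Response> finalResponseListener,
    ActionListener<Void> handlePackageAndTagsListener
) {
    TransportPutTrainedModelAction.checkForExistingModelDownloadTask(
        client,
        trainedModelConfig.getModelId(),
        request.isWaitForCompletion(),
        finalResponseListener,                               // a download is already running: answer with its model info
        () -> handlePackageAndTagsListener.onResponse(null), // no download task: continue the model install
        request.timeout()
    );
}
```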
TransportPutTrainedModelActionTests.java
@@ -59,12 +59,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
+ import java.util.concurrent.atomic.AtomicBoolean;

import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.getTaskInfoListOfOne;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockClientWithTasksResponse;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockListTasksClient;
import static org.hamcrest.Matchers.is;
- import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
@@ -161,12 +161,12 @@ public void testCheckForExistingTaskCallsOnFailureForAnError() {

var responseListener = new PlainActionFuture<PutTrainedModelAction.Response>();

- TransportPutTrainedModelAction.checkForExistingTask(
+ TransportPutTrainedModelAction.checkForExistingModelDownloadTask(
client,
"inferenceEntityId",
true,
responseListener,
- new PlainActionFuture<Void>(),
+ () -> {},
TIMEOUT
);

@@ -178,31 +178,31 @@ public void testCheckForExistingTaskCallsOnFailureForAnError() {
public void testCheckForExistingTaskCallsStoreModelListenerWhenNoTasksExist() {
var client = mockClientWithTasksResponse(Collections.emptyList(), threadPool);

- var storeListener = new PlainActionFuture<Void>();
+ var createModelCalled = new AtomicBoolean();

- TransportPutTrainedModelAction.checkForExistingTask(
+ TransportPutTrainedModelAction.checkForExistingModelDownloadTask(
client,
"inferenceEntityId",
true,
new PlainActionFuture<>(),
- storeListener,
+ () -> createModelCalled.set(Boolean.TRUE),
TIMEOUT
);

- assertThat(storeListener.actionGet(TIMEOUT), nullValue());
+ assertTrue(createModelCalled.get());
}

public void testCheckForExistingTaskThrowsNoModelFoundError() {
var client = mockClientWithTasksResponse(getTaskInfoListOfOne(), threadPool);
prepareGetTrainedModelResponse(client, Collections.emptyList());

var respListener = new PlainActionFuture<PutTrainedModelAction.Response>();
- TransportPutTrainedModelAction.checkForExistingTask(
+ TransportPutTrainedModelAction.checkForExistingModelDownloadTask(
client,
"inferenceEntityId",
true,
respListener,
- new PlainActionFuture<>(),
+ () -> {},
TIMEOUT
);

@@ -224,12 +224,12 @@ public void testCheckForExistingTaskReturnsTask() {
prepareGetTrainedModelResponse(client, List.of(trainedModel));

var respListener = new PlainActionFuture<PutTrainedModelAction.Response>();
- TransportPutTrainedModelAction.checkForExistingTask(
+ TransportPutTrainedModelAction.checkForExistingModelDownloadTask(
client,
"inferenceEntityId",
true,
respListener,
- new PlainActionFuture<>(),
+ () -> {},
TIMEOUT
);
