Skip to content

[ML] Sync Inference with Trained Model stats #130544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/130544.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 130544
summary: Sync Inference with Trained Model stats
area: Machine Learning
type: bug
issues:
- 130339
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
}
}).<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
});
subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
Expand Down Expand Up @@ -43,7 +42,7 @@ protected String modelNotFoundErrorMessage(String modelId) {

@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {
throw new IllegalStateException("cannot start model that uses an existing deployment");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand All @@ -33,7 +32,7 @@ public ElasticRerankerServiceSettings getServiceSettings() {

@Override
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
Expand Down Expand Up @@ -85,20 +87,21 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA
}

public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
Model model,
ElasticsearchInternalModel esModel,
ActionListener<Boolean> listener
) {
return new ActionListener<>() {
@Override
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
esModel.updateServiceSettings(response.getTrainedModelAssignment());
listener.onResponse(Boolean.TRUE);
}

@Override
public void onFailure(Exception e) {
var cause = ExceptionsHelper.unwrapCause(e);
if (cause instanceof ResourceNotFoundException) {
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId())));
listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId())));
return;
} else if (cause instanceof ElasticsearchStatusException statusException) {
if (statusException.status() == RestStatus.CONFLICT
Expand Down Expand Up @@ -128,8 +131,18 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
return (ElasticsearchInternalServiceSettings) super.getServiceSettings();
}

public void updateNumAllocations(Integer numAllocations) {
this.internalServiceSettings.setNumAllocations(numAllocations);
/**
 * Synchronizes this model's internal service settings with the supplied deployment stats.
 * Both the number of allocations and the adaptive allocations settings reported by
 * {@code assignmentStats} overwrite the values currently held by the settings object.
 *
 * @param assignmentStats stats whose allocation values are copied into the service settings
 */
public void updateServiceSettings(AssignmentStats assignmentStats) {
    var reportedAllocations = assignmentStats.getNumberOfAllocations();
    var reportedAdaptiveSettings = assignmentStats.getAdaptiveAllocationsSettings();
    internalServiceSettings.setAllocations(reportedAllocations, reportedAdaptiveSettings);
}

/**
 * Refreshes the adaptive allocations settings from a trained model assignment.
 * The number of allocations is read back from the current service settings and passed
 * through unchanged, so only the adaptive allocations settings are replaced.
 *
 * @param trainedModelAssignment assignment supplying the adaptive allocations settings
 */
private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) {
    var settings = this.internalServiceSettings;
    // Re-apply the existing allocation count; the setter always writes both values.
    var currentAllocations = settings.getNumAllocations();
    settings.setAllocations(currentAllocations, trainedModelAssignment.getAdaptiveAllocationsSettings());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
ActionListener.wrap(stats -> {
for (var deploymentStats : stats.getStats().results()) {
var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId());
modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations()));
modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats));
}
var updatedModels = new ArrayList<Model>();
modelsByDeploymentIds.values().forEach(updatedModels::addAll);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
private Integer numAllocations;
private final int numThreads;
private final String modelId;
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
private AdaptiveAllocationsSettings adaptiveAllocationsSettings;
private final String deploymentId;

public static ElasticsearchInternalServiceSettings fromPersistedMap(Map<String, Object> map) {
Expand Down Expand Up @@ -158,8 +158,9 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null;
}

public void setNumAllocations(Integer numAllocations) {
/**
 * Overwrites the allocation-related settings held by this instance.
 *
 * @param numAllocations              the new number of allocations; stored as given
 * @param adaptiveAllocationsSettings the new adaptive allocations settings; may be {@code null}
 */
public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) {
    this.adaptiveAllocationsSettings = adaptiveAllocationsSettings;
    this.numAllocations = numAllocations;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,12 @@
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.assertArg;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.doAnswer;
Expand Down Expand Up @@ -1767,7 +1769,9 @@ private void testUpdateModelsWithDynamicFields(Map<String, List<Model>> modelsBy
modelsByDeploymentId.forEach((deploymentId, models) -> {
var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId);
models.forEach(model -> {
verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations);
verify((ElasticsearchInternalModel) model).updateServiceSettings(assertArg(assignmentStats -> {
assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations));
}));
verify((ElasticsearchInternalModel) model).mlNodeDeploymentId();
verifyNoMoreInteractions(model);
});
Expand Down Expand Up @@ -1858,7 +1862,9 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException {
var latch = new CountDownLatch(1);
service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown()));
assertTrue(latch.await(30, TimeUnit.SECONDS));
verify(model).updateNumAllocations(3);
verify(model).updateServiceSettings(
assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); })
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ElserInternalModelTests extends ESTestCase {
public void testUpdateNumAllocation() {
Expand All @@ -21,10 +30,22 @@ public void testUpdateNumAllocation() {
null
);

model.updateNumAllocations(1);
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());
AssignmentStats assignmentStats = mock();
when(assignmentStats.getNumberOfAllocations()).thenReturn(1);
model.updateServiceSettings(assignmentStats);

assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings());

model.updateNumAllocations(null);
assertNull(model.getServiceSettings().getNumAllocations());
TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance();
CreateTrainedModelAssignmentAction.Response response = mock();
when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment);
model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response);

assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1));
assertThat(
model.getServiceSettings().getAdaptiveAllocationsSettings(),
equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings())
);
}
}
Loading