diff --git a/docs/changelog/130544.yaml b/docs/changelog/130544.yaml new file mode 100644 index 0000000000000..415357d929f8d --- /dev/null +++ b/docs/changelog/130544.yaml @@ -0,0 +1,6 @@ +pr: 130544 +summary: Sync Inference with Trained Model stats +area: Machine Learning +type: bug +issues: + - 130339 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java index 53e859b7f7a4d..ee4221157388e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java @@ -114,7 +114,7 @@ public void start(Model model, TimeValue timeout, ActionListener finalL } }).andThen((l2, modelDidPut) -> { var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout); - var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2); + var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(esModel, l2); client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener); }); subscribableListener.addTimeout(timeout, threadPool, inferenceExecutor); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java index ce6c6258d0393..5a81eb6b04bcd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; @@ -43,7 +42,7 @@ protected String modelNotFoundErrorMessage(String modelId) { @Override public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { throw new IllegalStateException("cannot start model that uses an existing deployment"); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java index 276bce6dbe8f8..2c8bf8270fabc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java @@ -9,7 +9,6 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -33,7 +32,7 @@ public ElasticRerankerServiceSettings getServiceSettings() { @Override public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index f1011efd3b12c..6a553480e68cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -21,6 +21,8 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; @@ -85,12 +87,13 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA } public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, + ElasticsearchInternalModel esModel, ActionListener listener ) { return new ActionListener<>() { @Override public void onResponse(CreateTrainedModelAssignmentAction.Response response) { + esModel.updateServiceSettings(response.getTrainedModelAssignment()); listener.onResponse(Boolean.TRUE); } @@ -98,7 +101,7 @@ public void onResponse(CreateTrainedModelAssignmentAction.Response response) { public void onFailure(Exception e) { var cause = ExceptionsHelper.unwrapCause(e); if (cause instanceof ResourceNotFoundException) { - listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId()))); + listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(esModel.internalServiceSettings.modelId()))); return; } else if (cause instanceof ElasticsearchStatusException statusException) { if (statusException.status() == RestStatus.CONFLICT @@ -128,8 +131,18 @@ public ElasticsearchInternalServiceSettings getServiceSettings() { return (ElasticsearchInternalServiceSettings) super.getServiceSettings(); } - public void updateNumAllocations(Integer numAllocations) { - this.internalServiceSettings.setNumAllocations(numAllocations); + public void updateServiceSettings(AssignmentStats assignmentStats) { + this.internalServiceSettings.setAllocations( + assignmentStats.getNumberOfAllocations(), + assignmentStats.getAdaptiveAllocationsSettings() + ); + } + + private void updateServiceSettings(TrainedModelAssignment trainedModelAssignment) { + this.internalServiceSettings.setAllocations( + this.internalServiceSettings.getNumAllocations(), + trainedModelAssignment.getAdaptiveAllocationsSettings() + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 4f2674179be67..b17392311629f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -890,7 +890,7 @@ public void updateModelsWithDynamicFields(List models, ActionListener { for (var deploymentStats : stats.getStats().results()) { var modelsForDeploymentId = modelsByDeploymentIds.get(deploymentStats.getDeploymentId()); - modelsForDeploymentId.forEach(model -> model.updateNumAllocations(deploymentStats.getNumberOfAllocations())); + modelsForDeploymentId.forEach(model -> model.updateServiceSettings(deploymentStats)); } var updatedModels = new ArrayList(); modelsByDeploymentIds.values().forEach(updatedModels::addAll); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 98730f33d10f9..6753f5f16dc8c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -43,7 +43,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings { private Integer numAllocations; private final int numThreads; private final String modelId; - private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; private final String deploymentId; public static ElasticsearchInternalServiceSettings fromPersistedMap(Map map) { @@ -158,8 +158,9 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null; } - public void setNumAllocations(Integer numAllocations) { + public void setAllocations(Integer numAllocations, @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings) { this.numAllocations = numAllocations; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 27709c2067a26..d2c22cdcf6f57 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -108,10 +108,12 @@ import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.doAnswer; @@ -1767,7 +1769,9 @@ private void testUpdateModelsWithDynamicFields(Map> modelsBy modelsByDeploymentId.forEach((deploymentId, models) -> { var expectedNumberOfAllocations = updatedNumberOfAllocations.get(deploymentId); models.forEach(model -> { - verify((ElasticsearchInternalModel) model).updateNumAllocations(expectedNumberOfAllocations); + verify((ElasticsearchInternalModel) model).updateServiceSettings(assertArg(assignmentStats -> { + assertThat(assignmentStats.getNumberOfAllocations(), equalTo(expectedNumberOfAllocations)); + })); verify((ElasticsearchInternalModel) model).mlNodeDeploymentId(); verifyNoMoreInteractions(model); }); @@ -1858,7 +1862,9 @@ public void testUpdateWithMlEnabled() throws IOException, InterruptedException { var latch = new CountDownLatch(1); service.updateModelsWithDynamicFields(models, ActionTestUtils.assertNoFailureListener(r -> latch.countDown())); assertTrue(latch.await(30, TimeUnit.SECONDS)); - verify(model).updateNumAllocations(3); + verify(model).updateServiceSettings( + assertArg(assignmentStats -> { assertThat(assignmentStats.getNumberOfAllocations(), equalTo(3)); }) + ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java index 5b21717ac03e4..3fee80b1fbe5f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java @@ -7,8 +7,17 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class ElserInternalModelTests extends ESTestCase { public void testUpdateNumAllocation() { @@ -21,10 +30,22 @@ public void testUpdateNumAllocation() { null ); - model.updateNumAllocations(1); - assertEquals(1, model.getServiceSettings().getNumAllocations().intValue()); + AssignmentStats assignmentStats = mock(); + when(assignmentStats.getNumberOfAllocations()).thenReturn(1); + model.updateServiceSettings(assignmentStats); + + assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1)); + assertNull(model.getServiceSettings().getAdaptiveAllocationsSettings()); - model.updateNumAllocations(null); - assertNull(model.getServiceSettings().getNumAllocations()); + TrainedModelAssignment trainedModelAssignment = TrainedModelAssignmentTests.randomInstance(); + CreateTrainedModelAssignmentAction.Response response = mock(); + when(response.getTrainedModelAssignment()).thenReturn(trainedModelAssignment); + model.getCreateTrainedModelAssignmentActionListener(model, ActionListener.noop()).onResponse(response); + + assertThat(model.getServiceSettings().getNumAllocations(), equalTo(1)); + assertThat( + model.getServiceSettings().getAdaptiveAllocationsSettings(), + equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings()) + ); } }