Skip to content

Commit 7a1bdaf

Browse files
committed
rest stream predict api
Signed-off-by: Jing Zhang <[email protected]>
1 parent 26fc493 commit 7a1bdaf

File tree

10 files changed

+378
-8
lines changed

10 files changed

+378
-8
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

+7
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import org.opensearch.ml.common.AccessMode;
3737
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
3838

39+
import com.google.gson.JsonObject;
40+
import com.google.gson.JsonParser;
41+
3942
import lombok.Builder;
4043
import lombok.EqualsAndHashCode;
4144
import lombok.NoArgsConstructor;
@@ -346,6 +349,10 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
346349

347350
if (!isJson(payload)) {
348351
throw new IllegalArgumentException("Invalid payload: " + payload);
352+
} else if (parameters.containsKey("stream")) {
353+
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
354+
jsonObject.addProperty("stream", true);
355+
payload = jsonObject.toString();
349356
}
350357
return (T) payload;
351358
}

ml-algorithms/build.gradle

+11
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,22 @@ plugins {
1414
}
1515

1616
repositories {
17+
// NOTE(review): mavenLocal() is declared ahead of mavenCentral(), so artifacts present in a
// developer's local Maven repository may be picked up in place of published ones — confirm
// this is intended for more than local development of the streaming dependencies.
mavenLocal()
1718
mavenCentral()
1819
}
1920

2021
dependencies {
2122
implementation project(path: ":${rootProject.name}-spi", configuration: 'shadow')
2223
implementation project(path: ":${rootProject.name}-common", configuration: 'shadow')
2324
implementation project(':opensearch-ml-memory')
25+
26+
implementation "org.opensearch:opensearch-arrow-spi:${opensearch_version}"
27+
implementation "org.apache.arrow:arrow-vector:${versions.arrow}"
28+
implementation "org.apache.arrow:arrow-format:${versions.arrow}"
29+
implementation "org.apache.arrow:arrow-memory-core:${versions.arrow}"
30+
runtimeOnly "org.apache.arrow:arrow-memory-netty:${versions.arrow}"
31+
runtimeOnly "org.apache.arrow:arrow-memory-netty-buffer-patch:${versions.arrow}"
32+
2433
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
2534
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
2635
testImplementation "org.opensearch.test:framework:${opensearch_version}"
@@ -88,6 +97,8 @@ dependencies {
8897
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
8998
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
9099
testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
100+
api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.12.0'
101+
implementation group: 'com.squareup.okhttp3', name: 'okhttp-sse', version: '4.12.0'
91102
}
92103

93104
lombok {

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import java.nio.file.Path;
1212
import java.util.Locale;
1313
import java.util.Map;
14+
import java.util.function.Supplier;
1415

16+
import org.opensearch.arrow.spi.StreamManager;
1517
import org.opensearch.core.action.ActionListener;
1618
import org.opensearch.ml.common.FunctionName;
1719
import org.opensearch.ml.common.MLModel;
@@ -49,11 +51,14 @@ public class MLEngine {
4951

5052
private Encryptor encryptor;
5153

52-
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
54+
private Supplier<StreamManager> streamManager;
55+
56+
public MLEngine(Path opensearchDataFolder, Encryptor encryptor, Supplier<StreamManager> streamManager) {
5357
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
5458
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
5559
this.mlConfigPath = mlCachePath.resolve("config");
5660
this.encryptor = encryptor;
61+
this.streamManager = streamManager;
5762
}
5863

5964
public String getPrebuiltModelMetaListPath() {
@@ -141,7 +146,7 @@ public Map<String, String> getConnectorCredential(Connector connector) {
141146

142147
public Predictable deploy(MLModel mlModel, Map<String, Object> params) {
143148
Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
144-
predictable.initModel(mlModel, params, encryptor);
149+
predictable.initModel(mlModel, params, encryptor, streamManager);
145150
return predictable;
146151
}
147152

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package org.opensearch.ml.engine;
77

88
import java.util.Map;
9+
import java.util.function.Supplier;
910

11+
import org.opensearch.arrow.spi.StreamManager;
1012
import org.opensearch.core.action.ActionListener;
1113
import org.opensearch.ml.common.MLModel;
1214
import org.opensearch.ml.common.input.MLInput;
@@ -47,7 +49,13 @@ default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> action
4749
* @param params other parameters
4850
* @param encryptor encryptor
4951
*/
50-
void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor);
52+
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
53+
54+
};
55+
56+
default void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor, Supplier<StreamManager> streamManager) {
57+
58+
};
5159

5260
/**
5361
* Close resources like deployed model.

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

+19
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
import java.util.Map;
1818
import java.util.Optional;
1919
import java.util.concurrent.atomic.AtomicBoolean;
20+
import java.util.function.Supplier;
2021

2122
import org.apache.logging.log4j.Logger;
2223
import org.opensearch.ExceptionsHelper;
2324
import org.opensearch.OpenSearchStatusException;
2425
import org.opensearch.action.bulk.BackoffPolicy;
2526
import org.opensearch.action.support.GroupedActionListener;
2627
import org.opensearch.action.support.RetryableAction;
28+
import org.opensearch.arrow.spi.StreamManager;
2729
import org.opensearch.cluster.service.ClusterService;
2830
import org.opensearch.common.collect.Tuple;
2931
import org.opensearch.common.unit.TimeValue;
@@ -221,6 +223,8 @@ && getUserRateLimiterMap().get(user.getName()) != null
221223
}
222224
if (getConnectorClientConfig().getMaxRetryTimes() != 0) {
223225
invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener);
226+
} else if (parameters.containsKey("stream")) {
227+
invokeRemoteServiceStream(action, mlInput, parameters, payload, executionContext, actionListener);
224228
} else {
225229
invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener);
226230
}
@@ -337,4 +341,19 @@ class RetryableActionExtensionArgs {
337341
private final ExecutionContext executionContext;
338342
private final String payload;
339343
}
344+
345+
/**
 * Invokes the remote service in streaming mode.
 *
 * <p>In the dispatch visible in this change, this method is chosen when the request
 * parameters contain the key {@code "stream"} and the connector's max retry times is 0;
 * when retries are configured the retry path is taken instead. It receives the same
 * arguments that the dispatcher passes to {@code invokeRemoteService}.
 *
 * <p>Declared without a default body, so every implementation must provide it.
 *
 * @param action connector action being invoked
 * @param mlInput input of the predict request
 * @param parameters request parameters (contains the {@code "stream"} key when reached via the dispatcher)
 * @param payload request payload
 * @param executionContext execution context of this invocation
 * @param actionListener listener receiving a tuple of an integer and the model tensors
 *        (NOTE(review): the integer's meaning is not visible here — confirm against invokeRemoteService)
 */
void invokeRemoteServiceStream(
346+
String action,
347+
MLInput mlInput,
348+
Map<String, String> parameters,
349+
String payload,
350+
ExecutionContext executionContext,
351+
ActionListener<Tuple<Integer, ModelTensors>> actionListener
352+
);
353+
354+
/**
 * Receives the stream manager supplier. This default implementation does nothing with it;
 * implementations that stream must override this method to keep the supplier.
 *
 * @param streamManager supplier of the stream manager
 */
default void setStreamManager(Supplier<StreamManager> streamManager) {}
355+
356+
default Supplier<StreamManager> getStreamManager() {
357+
return null;
358+
};
340359
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
import java.util.Map;
1111
import java.util.concurrent.atomic.AtomicBoolean;
12+
import java.util.function.Supplier;
1213

14+
import org.opensearch.arrow.spi.StreamManager;
1315
import org.opensearch.cluster.service.ClusterService;
1416
import org.opensearch.common.util.TokenBucket;
1517
import org.opensearch.core.action.ActionListener;
@@ -98,7 +100,7 @@ public boolean isModelReady() {
98100
}
99101

100102
@Override
101-
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
103+
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor, Supplier<StreamManager> streamManager) {
102104
try {
103105
Connector connector = model.getConnector().cloneConnector();
104106
connector
@@ -112,6 +114,7 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
112114
this.connectorExecutor.setUserRateLimiterMap((Map<String, TokenBucket>) params.get(USER_RATE_LIMITER_MAP));
113115
this.connectorExecutor.setMlGuard((MLGuard) params.get(GUARDRAILS));
114116
this.connectorExecutor.setConnectorPrivateIpEnabled((AtomicBoolean) params.get(CONNECTOR_PRIVATE_IP_ENABLED));
117+
this.connectorExecutor.setStreamManager(streamManager);
115118
} catch (RuntimeException e) {
116119
log.error("Failed to init remote model.", e);
117120
throw e;

plugin/build.gradle

+23
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ dependencies {
5757
implementation project(':opensearch-ml-algorithms')
5858
implementation project(':opensearch-ml-search-processors')
5959
implementation project(':opensearch-ml-memory')
60+
61+
implementation "org.opensearch:opensearch-arrow-spi:${opensearchVersion}"
62+
implementation "org.apache.arrow:arrow-vector:${versions.arrow}"
63+
implementation "org.apache.arrow:arrow-format:${versions.arrow}"
64+
implementation "org.apache.arrow:arrow-memory-core:${versions.arrow}"
65+
runtimeOnly "org.apache.arrow:arrow-memory-netty:${versions.arrow}"
66+
runtimeOnly "org.apache.arrow:arrow-memory-netty-buffer-patch:${versions.arrow}"
67+
68+
implementation "io.netty:netty-buffer:${versions.netty}"
69+
implementation "io.netty:netty-common:${versions.netty}"
70+
compileOnly 'org.checkerframework:checker-qual:3.44.0'
71+
72+
6073
compileOnly "com.google.guava:guava:32.1.3-jre"
6174

6275
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
@@ -424,6 +437,16 @@ configurations.all {
424437
resolutionStrategy.force "jakarta.json:jakarta.json-api:2.1.3"
425438
resolutionStrategy.force "org.opensearch:opensearch:${opensearch_version}"
426439
resolutionStrategy.force "org.bouncycastle:bcprov-jdk18on:1.78.1"
440+
resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.9.10"
441+
resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.9.10"
442+
resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk7:1.9.10"
443+
resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-common:1.9.10"
444+
resolutionStrategy.force "org.checkerframework:checker-qual:3.44.0"
445+
resolutionStrategy.force 'io.netty:netty-buffer:4.1.118.Final'
446+
resolutionStrategy.force 'io.netty:netty-common:4.1.118.Final'
447+
resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-annotations:2.18.2'
448+
resolutionStrategy.force 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.18.2'
449+
resolutionStrategy.force 'com.google.flatbuffers:flatbuffers-java:24.3.25'
427450
}
428451

429452
apply plugin: 'com.netflix.nebula.ospackage'

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

+38-4
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.util.function.Supplier;
4141

4242
import org.opensearch.action.ActionRequest;
43+
import org.opensearch.arrow.spi.StreamManager;
4344
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
4445
import org.opensearch.cluster.node.DiscoveryNodes;
4546
import org.opensearch.cluster.service.ClusterService;
@@ -312,6 +313,7 @@
312313
import org.opensearch.plugins.Plugin;
313314
import org.opensearch.plugins.SearchPipelinePlugin;
314315
import org.opensearch.plugins.SearchPlugin;
316+
import org.opensearch.plugins.StreamManagerPlugin;
315317
import org.opensearch.plugins.SystemIndexPlugin;
316318
import org.opensearch.remote.metadata.client.SdkClient;
317319
import org.opensearch.remote.metadata.client.impl.SdkClientFactory;
@@ -343,7 +345,8 @@ public class MachineLearningPlugin extends Plugin
343345
SearchPipelinePlugin,
344346
ExtensiblePlugin,
345347
IngestPlugin,
346-
SystemIndexPlugin {
348+
SystemIndexPlugin,
349+
StreamManagerPlugin {
347350
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
348351
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
349352
public static final String SDK_CLIENT_THREAD_POOL = "opensearch_ml_sdkclient";
@@ -354,6 +357,7 @@ public class MachineLearningPlugin extends Plugin
354357
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest";
355358
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register";
356359
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy";
360+
public static final String STREAM_PREDICT_THREAD_POOL = "opensearch_ml_predict_stream";
357361
public static final String ML_BASE_URI = "/_plugins/_ml";
358362

359363
private MLStats mlStats;
@@ -398,6 +402,14 @@ public class MachineLearningPlugin extends Plugin
398402
private ScriptService scriptService;
399403
private Encryptor encryptor;
400404

405+
private StreamManager streamManager;
406+
407+
private StreamManager getStreamManagerRef() {
408+
return this.streamManager;
409+
}
410+
411+
private Supplier<StreamManager> streamManagerSupplier = () -> { return getStreamManagerRef(); };
412+
401413
public MachineLearningPlugin(Settings settings) {
402414
// Handle this here as this feature is tied to Search/Query API, not to a ml-common API
403415
// and as such, it can't be lazy-loaded when a ml-commons API is invoked.
@@ -523,7 +535,7 @@ public Collection<Object> createComponents(
523535

524536
encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);
525537

526-
mlEngine = new MLEngine(dataPath, encryptor);
538+
mlEngine = new MLEngine(dataPath, encryptor, streamManagerSupplier);
527539
nodeHelper = new DiscoveryNodeHelper(clusterService, settings);
528540
modelCacheHelper = new MLModelCacheHelper(clusterService, settings);
529541
cmHandler = new OpenSearchConversationalMemoryHandler(client, clusterService);
@@ -753,7 +765,8 @@ public Collection<Object> createComponents(
753765
mlCircuitBreakerService,
754766
mlModelAutoRedeployer,
755767
cmHandler,
756-
sdkClient
768+
sdkClient,
769+
streamManagerSupplier
757770
);
758771
}
759772

@@ -771,6 +784,11 @@ public List<RestHandler> getRestHandlers(
771784
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction();
772785
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction();
773786
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction(mlModelManager, mlFeatureEnabledSetting);
787+
RestMLPredictionStreamingAction restMLPredictionStreamingAction = new RestMLPredictionStreamingAction(
788+
mlModelManager,
789+
mlFeatureEnabledSetting,
790+
streamManagerSupplier
791+
);
774792
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting);
775793
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting);
776794
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting);
@@ -835,6 +853,7 @@ public List<RestHandler> getRestHandlers(
835853
restMLStatsAction,
836854
restMLTrainingAction,
837855
restMLPredictionAction,
856+
restMLPredictionStreamingAction,
838857
restMLExecuteAction,
839858
restMLTrainAndPredictAction,
840859
restMLGetModelAction,
@@ -964,6 +983,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
964983
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL,
965984
false
966985
);
986+
FixedExecutorBuilder streamPredictThreadPool = new FixedExecutorBuilder(
987+
settings,
988+
STREAM_PREDICT_THREAD_POOL,
989+
OpenSearchExecutors.allocatedProcessors(settings) * 2,
990+
10000,
991+
ML_THREAD_POOL_PREFIX + STREAM_PREDICT_THREAD_POOL,
992+
false
993+
);
967994

968995
return ImmutableList
969996
.of(
@@ -975,7 +1002,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
9751002
predictThreadPool,
9761003
remotePredictThreadPool,
9771004
batchIngestThreadPool,
978-
sdkClientThreadPool
1005+
sdkClientThreadPool,
1006+
streamPredictThreadPool
9791007
);
9801008
}
9811009

@@ -1174,4 +1202,10 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
11741202
systemIndexDescriptors.add(new SystemIndexDescriptor(ML_STOP_WORDS_INDEX, "ML Commons Stop Words Index"));
11751203
return systemIndexDescriptors;
11761204
}
1205+
1206+
/**
 * Stream manager initialisation hook (this class implements {@code StreamManagerPlugin}).
 * Dereferences the given supplier once and stores the result in the {@code streamManager}
 * field, which {@code streamManagerSupplier} reads on each call.
 *
 * NOTE(review): the supplier is resolved immediately; if it can yield {@code null} at the
 * time of this callback, that null is what gets cached — confirm against the
 * StreamManagerPlugin contract, or keep the supplier and resolve it lazily.
 *
 * @param streamManager supplier of the stream manager
 */
@Override
1207+
public void onStreamManagerInitialized(Supplier<StreamManager> streamManager) {
1208+
this.streamManager = streamManager.get();
1209+
}
1210+
11771211
}

0 commit comments

Comments
 (0)