40
40
import java .util .function .Supplier ;
41
41
42
42
import org .opensearch .action .ActionRequest ;
43
+ import org .opensearch .arrow .spi .StreamManager ;
43
44
import org .opensearch .cluster .metadata .IndexNameExpressionResolver ;
44
45
import org .opensearch .cluster .node .DiscoveryNodes ;
45
46
import org .opensearch .cluster .service .ClusterService ;
312
313
import org .opensearch .plugins .Plugin ;
313
314
import org .opensearch .plugins .SearchPipelinePlugin ;
314
315
import org .opensearch .plugins .SearchPlugin ;
316
+ import org .opensearch .plugins .StreamManagerPlugin ;
315
317
import org .opensearch .plugins .SystemIndexPlugin ;
316
318
import org .opensearch .remote .metadata .client .SdkClient ;
317
319
import org .opensearch .remote .metadata .client .impl .SdkClientFactory ;
@@ -343,7 +345,8 @@ public class MachineLearningPlugin extends Plugin
343
345
SearchPipelinePlugin ,
344
346
ExtensiblePlugin ,
345
347
IngestPlugin ,
346
- SystemIndexPlugin {
348
+ SystemIndexPlugin ,
349
+ StreamManagerPlugin {
347
350
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons." ;
348
351
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general" ;
349
352
public static final String SDK_CLIENT_THREAD_POOL = "opensearch_ml_sdkclient" ;
@@ -354,6 +357,7 @@ public class MachineLearningPlugin extends Plugin
354
357
public static final String INGEST_THREAD_POOL = "opensearch_ml_ingest" ;
355
358
public static final String REGISTER_THREAD_POOL = "opensearch_ml_register" ;
356
359
public static final String DEPLOY_THREAD_POOL = "opensearch_ml_deploy" ;
360
+ public static final String STREAM_PREDICT_THREAD_POOL = "opensearch_ml_predict_stream" ;
357
361
public static final String ML_BASE_URI = "/_plugins/_ml" ;
358
362
359
363
private MLStats mlStats ;
@@ -398,6 +402,14 @@ public class MachineLearningPlugin extends Plugin
398
402
private ScriptService scriptService ;
399
403
private Encryptor encryptor ;
400
404
405
+ private StreamManager streamManager ;
406
+
407
+ private StreamManager getStreamManagerRef () {
408
+ return this .streamManager ;
409
+ }
410
+
411
+ private Supplier <StreamManager > streamManagerSupplier = () -> { return getStreamManagerRef (); };
412
+
401
413
public MachineLearningPlugin (Settings settings ) {
402
414
// Handle this here as this feature is tied to Search/Query API, not to a ml-common API
403
415
// and as such, it can't be lazy-loaded when a ml-commons API is invoked.
@@ -523,7 +535,7 @@ public Collection<Object> createComponents(
523
535
524
536
encryptor = new EncryptorImpl (clusterService , client , sdkClient , mlIndicesHandler );
525
537
526
- mlEngine = new MLEngine (dataPath , encryptor );
538
+ mlEngine = new MLEngine (dataPath , encryptor , streamManagerSupplier );
527
539
nodeHelper = new DiscoveryNodeHelper (clusterService , settings );
528
540
modelCacheHelper = new MLModelCacheHelper (clusterService , settings );
529
541
cmHandler = new OpenSearchConversationalMemoryHandler (client , clusterService );
@@ -753,7 +765,8 @@ public Collection<Object> createComponents(
753
765
mlCircuitBreakerService ,
754
766
mlModelAutoRedeployer ,
755
767
cmHandler ,
756
- sdkClient
768
+ sdkClient ,
769
+ streamManagerSupplier
757
770
);
758
771
}
759
772
@@ -771,6 +784,11 @@ public List<RestHandler> getRestHandlers(
771
784
RestMLTrainingAction restMLTrainingAction = new RestMLTrainingAction ();
772
785
RestMLTrainAndPredictAction restMLTrainAndPredictAction = new RestMLTrainAndPredictAction ();
773
786
RestMLPredictionAction restMLPredictionAction = new RestMLPredictionAction (mlModelManager , mlFeatureEnabledSetting );
787
+ RestMLPredictionStreamingAction restMLPredictionStreamingAction = new RestMLPredictionStreamingAction (
788
+ mlModelManager ,
789
+ mlFeatureEnabledSetting ,
790
+ streamManagerSupplier
791
+ );
774
792
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction (mlFeatureEnabledSetting );
775
793
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction (mlFeatureEnabledSetting );
776
794
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction (mlFeatureEnabledSetting );
@@ -835,6 +853,7 @@ public List<RestHandler> getRestHandlers(
835
853
restMLStatsAction ,
836
854
restMLTrainingAction ,
837
855
restMLPredictionAction ,
856
+ restMLPredictionStreamingAction ,
838
857
restMLExecuteAction ,
839
858
restMLTrainAndPredictAction ,
840
859
restMLGetModelAction ,
@@ -964,6 +983,14 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
964
983
ML_THREAD_POOL_PREFIX + INGEST_THREAD_POOL ,
965
984
false
966
985
);
986
+ FixedExecutorBuilder streamPredictThreadPool = new FixedExecutorBuilder (
987
+ settings ,
988
+ STREAM_PREDICT_THREAD_POOL ,
989
+ OpenSearchExecutors .allocatedProcessors (settings ) * 2 ,
990
+ 10000 ,
991
+ ML_THREAD_POOL_PREFIX + STREAM_PREDICT_THREAD_POOL ,
992
+ false
993
+ );
967
994
968
995
return ImmutableList
969
996
.of (
@@ -975,7 +1002,8 @@ public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
975
1002
predictThreadPool ,
976
1003
remotePredictThreadPool ,
977
1004
batchIngestThreadPool ,
978
- sdkClientThreadPool
1005
+ sdkClientThreadPool ,
1006
+ streamPredictThreadPool
979
1007
);
980
1008
}
981
1009
@@ -1174,4 +1202,10 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
1174
1202
systemIndexDescriptors .add (new SystemIndexDescriptor (ML_STOP_WORDS_INDEX , "ML Commons Stop Words Index" ));
1175
1203
return systemIndexDescriptors ;
1176
1204
}
1205
+
1206
+ @ Override
1207
+ public void onStreamManagerInitialized (Supplier <StreamManager > streamManager ) {
1208
+ this .streamManager = streamManager .get ();
1209
+ }
1210
+
1177
1211
}
0 commit comments