15
15
import com .amazonaws .ml .mms .archive .ModelArchive ;
16
16
import com .amazonaws .ml .mms .archive .ModelException ;
17
17
import com .amazonaws .ml .mms .metrics .MetricManager ;
18
+ import com .amazonaws .ml .mms .servingsdk_impl .PluginLoader ;
18
19
import com .amazonaws .ml .mms .util .ConfigManager ;
19
20
import com .amazonaws .ml .mms .util .Connector ;
20
21
import com .amazonaws .ml .mms .util .ServerGroups ;
31
32
import io .netty .util .internal .logging .Slf4JLoggerFactory ;
32
33
import java .io .File ;
33
34
import java .io .IOException ;
35
+ import java .lang .annotation .Annotation ;
34
36
import java .security .GeneralSecurityException ;
35
37
import java .util .ArrayList ;
38
+ import java .util .HashMap ;
36
39
import java .util .InvalidPropertiesFormatException ;
37
40
import java .util .List ;
41
+ import java .util .ServiceLoader ;
38
42
import java .util .Set ;
39
43
import java .util .concurrent .ExecutionException ;
40
44
import java .util .concurrent .atomic .AtomicBoolean ;
45
49
import org .apache .commons .cli .ParseException ;
46
50
import org .slf4j .Logger ;
47
51
import org .slf4j .LoggerFactory ;
52
+ import software .amazon .ai .mms .servingsdk .ModelServerEndpoint ;
53
+ import software .amazon .ai .mms .servingsdk .annotations .Endpoint ;
54
+ import software .amazon .ai .mms .servingsdk .annotations .helpers .EndpointTypes ;
48
55
49
56
public class ModelServer {
50
57
@@ -53,6 +60,8 @@ public class ModelServer {
53
60
private ServerGroups serverGroups ;
54
61
private List <ChannelFuture > futures = new ArrayList <>(2 );
55
62
private AtomicBoolean stopped = new AtomicBoolean (false );
63
+ private HashMap <String , ModelServerEndpoint > infEps ;
64
+ private HashMap <String , ModelServerEndpoint > mgmtEps ;
56
65
57
66
private ConfigManager configManager ;
58
67
@@ -207,10 +216,8 @@ public ChannelFuture initializeServer(
207
216
Connector connector , EventLoopGroup serverGroup , EventLoopGroup workerGroup )
208
217
throws InterruptedException , IOException , GeneralSecurityException {
209
218
final String purpose = connector .getPurpose ();
210
-
211
219
Class <? extends ServerChannel > channelClass = connector .getServerChannel ();
212
220
logger .info ("Initialize {} server with: {}." , purpose , channelClass .getSimpleName ());
213
-
214
221
ServerBootstrap b = new ServerBootstrap ();
215
222
b .option (ChannelOption .SO_BACKLOG , 1024 )
216
223
.channel (channelClass )
@@ -223,7 +230,7 @@ public ChannelFuture initializeServer(
223
230
if (connector .isSsl ()) {
224
231
sslCtx = configManager .getSslContext ();
225
232
}
226
- b .childHandler (new ServerInitializer (sslCtx , connector .isManagement ()));
233
+ b .childHandler (new ServerInitializer (sslCtx , connector .isManagement (), infEps , mgmtEps ));
227
234
228
235
ChannelFuture future ;
229
236
try {
@@ -276,6 +283,9 @@ public List<ChannelFuture> start()
276
283
277
284
initModelStore ();
278
285
286
+ infEps = PluginLoader .getInstance ().getAllInferenceServingEndpoints ();
287
+ mgmtEps = PluginLoader .getInstance ().getAllManagementServingEndpoints ();
288
+
279
289
Connector inferenceConnector = configManager .getListener (false );
280
290
Connector managementConnector = configManager .getListener (true );
281
291
if (inferenceConnector .equals (managementConnector )) {
@@ -295,6 +305,27 @@ public List<ChannelFuture> start()
295
305
return futures ;
296
306
}
297
307
308
+ private boolean validEndpoint (Annotation a , EndpointTypes type ) {
309
+ return a instanceof Endpoint
310
+ && !((Endpoint ) a ).urlPattern ().isEmpty ()
311
+ && ((Endpoint ) a ).endpointType ().equals (type );
312
+ }
313
+
314
+ private HashMap <String , ModelServerEndpoint > registerEndpoints (EndpointTypes type ) {
315
+ ServiceLoader <ModelServerEndpoint > loader = ServiceLoader .load (ModelServerEndpoint .class );
316
+ HashMap <String , ModelServerEndpoint > ep = new HashMap <>();
317
+ for (ModelServerEndpoint mep : loader ) {
318
+ Class <? extends ModelServerEndpoint > modelServerEndpointClassObj = mep .getClass ();
319
+ Annotation [] annotations = modelServerEndpointClassObj .getAnnotations ();
320
+ for (Annotation a : annotations ) {
321
+ if (validEndpoint (a , type )) {
322
+ ep .put (((Endpoint ) a ).urlPattern (), mep );
323
+ }
324
+ }
325
+ }
326
+ return ep ;
327
+ }
328
+
298
329
public boolean isRunning () {
299
330
return !stopped .get ();
300
331
}
0 commit comments