Skip to content

Commit 2fddb40

Browse files
committedJun 20, 2019
Changes to load plugins and implement the SDK interfaces
1 parent b036771 commit 2fddb40

15 files changed

+331
-31
lines changed
 

‎.gitignore

+7
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,10 @@ ENV/
120120

121121
# Prop file in benchmark
122122
benchmarks/*.properties
123+
124+
# intellij files
125+
*.iml
126+
127+
# MMS files
128+
mms/frontend
129+
mms/plugins

‎frontend/gradle.properties

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ slf4j_api_version=1.7.25
55
slf4j_log4j12_version=1.7.25
66
gson_version=2.8.5
77
commons_cli_version=1.3.1
8-
testng_version=6.8.1
8+
testng_version=6.8.1
9+
mms_server_sdk_version=1.0.0

‎frontend/server/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ dependencies {
22
compile "io.netty:netty-all:${netty_version}"
33
compile project(":modelarchive")
44
compile "commons-cli:commons-cli:${commons_cli_version}"
5-
5+
compile "software.amazon.ai:mms-plugins-sdk:${mms_server_sdk_version}"
66
testCompile "org.testng:testng:${testng_version}"
77
}
88

‎frontend/server/src/main/java/com/amazonaws/ml/mms/ModelServer.java

+34-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import com.amazonaws.ml.mms.archive.ModelArchive;
1616
import com.amazonaws.ml.mms.archive.ModelException;
1717
import com.amazonaws.ml.mms.metrics.MetricManager;
18+
import com.amazonaws.ml.mms.servingsdk_impl.PluginLoader;
1819
import com.amazonaws.ml.mms.util.ConfigManager;
1920
import com.amazonaws.ml.mms.util.Connector;
2021
import com.amazonaws.ml.mms.util.ServerGroups;
@@ -31,10 +32,13 @@
3132
import io.netty.util.internal.logging.Slf4JLoggerFactory;
3233
import java.io.File;
3334
import java.io.IOException;
35+
import java.lang.annotation.Annotation;
3436
import java.security.GeneralSecurityException;
3537
import java.util.ArrayList;
38+
import java.util.HashMap;
3639
import java.util.InvalidPropertiesFormatException;
3740
import java.util.List;
41+
import java.util.ServiceLoader;
3842
import java.util.Set;
3943
import java.util.concurrent.ExecutionException;
4044
import java.util.concurrent.atomic.AtomicBoolean;
@@ -45,6 +49,9 @@
4549
import org.apache.commons.cli.ParseException;
4650
import org.slf4j.Logger;
4751
import org.slf4j.LoggerFactory;
52+
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
53+
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
54+
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;
4855

4956
public class ModelServer {
5057

@@ -53,6 +60,8 @@ public class ModelServer {
5360
private ServerGroups serverGroups;
5461
private List<ChannelFuture> futures = new ArrayList<>(2);
5562
private AtomicBoolean stopped = new AtomicBoolean(false);
63+
private HashMap<String, ModelServerEndpoint> infEps;
64+
private HashMap<String, ModelServerEndpoint> mgmtEps;
5665

5766
private ConfigManager configManager;
5867

@@ -207,10 +216,8 @@ public ChannelFuture initializeServer(
207216
Connector connector, EventLoopGroup serverGroup, EventLoopGroup workerGroup)
208217
throws InterruptedException, IOException, GeneralSecurityException {
209218
final String purpose = connector.getPurpose();
210-
211219
Class<? extends ServerChannel> channelClass = connector.getServerChannel();
212220
logger.info("Initialize {} server with: {}.", purpose, channelClass.getSimpleName());
213-
214221
ServerBootstrap b = new ServerBootstrap();
215222
b.option(ChannelOption.SO_BACKLOG, 1024)
216223
.channel(channelClass)
@@ -223,7 +230,7 @@ public ChannelFuture initializeServer(
223230
if (connector.isSsl()) {
224231
sslCtx = configManager.getSslContext();
225232
}
226-
b.childHandler(new ServerInitializer(sslCtx, connector.isManagement()));
233+
b.childHandler(new ServerInitializer(sslCtx, connector.isManagement(), infEps, mgmtEps));
227234

228235
ChannelFuture future;
229236
try {
@@ -276,6 +283,9 @@ public List<ChannelFuture> start()
276283

277284
initModelStore();
278285

286+
infEps = PluginLoader.getInstance().getAllInferenceServingEndpoints();
287+
mgmtEps = PluginLoader.getInstance().getAllManagementServingEndpoints();
288+
279289
Connector inferenceConnector = configManager.getListener(false);
280290
Connector managementConnector = configManager.getListener(true);
281291
if (inferenceConnector.equals(managementConnector)) {
@@ -295,6 +305,27 @@ public List<ChannelFuture> start()
295305
return futures;
296306
}
297307

308+
private boolean validEndpoint(Annotation a, EndpointTypes type) {
309+
return a instanceof Endpoint
310+
&& !((Endpoint) a).urlPattern().isEmpty()
311+
&& ((Endpoint) a).endpointType().equals(type);
312+
}
313+
314+
private HashMap<String, ModelServerEndpoint> registerEndpoints(EndpointTypes type) {
315+
ServiceLoader<ModelServerEndpoint> loader = ServiceLoader.load(ModelServerEndpoint.class);
316+
HashMap<String, ModelServerEndpoint> ep = new HashMap<>();
317+
for (ModelServerEndpoint mep : loader) {
318+
Class<? extends ModelServerEndpoint> modelServerEndpointClassObj = mep.getClass();
319+
Annotation[] annotations = modelServerEndpointClassObj.getAnnotations();
320+
for (Annotation a : annotations) {
321+
if (validEndpoint(a, type)) {
322+
ep.put(((Endpoint) a).urlPattern(), mep);
323+
}
324+
}
325+
}
326+
return ep;
327+
}
328+
298329
public boolean isRunning() {
299330
return !stopped.get();
300331
}

‎frontend/server/src/main/java/com/amazonaws/ml/mms/ServerInitializer.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import io.netty.handler.codec.http.HttpObjectAggregator;
2222
import io.netty.handler.codec.http.HttpServerCodec;
2323
import io.netty.handler.ssl.SslContext;
24+
import java.util.Map;
25+
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
2426

2527
/**
2628
* A special {@link io.netty.channel.ChannelInboundHandler} which offers an easy way to initialize a
@@ -31,16 +33,24 @@ public class ServerInitializer extends ChannelInitializer<Channel> {
3133

3234
private final boolean managementServer;
3335
private SslContext sslCtx;
36+
private final Map<String, ModelServerEndpoint> inferenceEndpoints;
37+
private final Map<String, ModelServerEndpoint> mgmtEndpoints;
3438

3539
/**
3640
* Creates a new {@code HttpRequestHandler} instance.
3741
*
3842
* @param sslCtx null if SSL is not enabled
3943
* @param managementServer true to initialize a management server instead of an API Server
4044
*/
41-
public ServerInitializer(SslContext sslCtx, boolean managementServer) {
45+
public ServerInitializer(
46+
SslContext sslCtx,
47+
boolean managementServer,
48+
Map<String, ModelServerEndpoint> infEp,
49+
Map<String, ModelServerEndpoint> mgmtEp) {
4250
this.sslCtx = sslCtx;
4351
this.managementServer = managementServer;
52+
inferenceEndpoints = infEp;
53+
mgmtEndpoints = mgmtEp;
4454
}
4555

4656
/** {@inheritDoc} */
@@ -54,9 +64,9 @@ public void initChannel(Channel ch) {
5464
pipeline.addLast("http", new HttpServerCodec());
5565
pipeline.addLast("aggregator", new HttpObjectAggregator(maxRequestSize));
5666
if (managementServer) {
57-
pipeline.addLast("handler", new ManagementRequestHandler());
67+
pipeline.addLast("handler", new ManagementRequestHandler(mgmtEndpoints));
5868
} else {
59-
pipeline.addLast("handler", new InferenceRequestHandler());
69+
pipeline.addLast("handler", new InferenceRequestHandler(inferenceEndpoints));
6070
}
6171
}
6272
}

‎frontend/server/src/main/java/com/amazonaws/ml/mms/http/InferenceRequestHandler.java

+50-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import com.amazonaws.ml.mms.archive.ModelNotFoundException;
1616
import com.amazonaws.ml.mms.openapi.OpenApiUtils;
17+
import com.amazonaws.ml.mms.servingsdk_impl.ModelServerContext;
18+
import com.amazonaws.ml.mms.servingsdk_impl.ModelServerRequest;
19+
import com.amazonaws.ml.mms.servingsdk_impl.ModelServerResponse;
1720
import com.amazonaws.ml.mms.util.NettyUtils;
1821
import com.amazonaws.ml.mms.util.messages.InputParameter;
1922
import com.amazonaws.ml.mms.util.messages.RequestInput;
@@ -22,18 +25,25 @@
2225
import com.amazonaws.ml.mms.wlm.Model;
2326
import com.amazonaws.ml.mms.wlm.ModelManager;
2427
import io.netty.channel.ChannelHandlerContext;
28+
import io.netty.handler.codec.http.DefaultFullHttpResponse;
2529
import io.netty.handler.codec.http.FullHttpRequest;
30+
import io.netty.handler.codec.http.FullHttpResponse;
2631
import io.netty.handler.codec.http.HttpHeaderValues;
2732
import io.netty.handler.codec.http.HttpMethod;
33+
import io.netty.handler.codec.http.HttpResponseStatus;
2834
import io.netty.handler.codec.http.HttpUtil;
35+
import io.netty.handler.codec.http.HttpVersion;
2936
import io.netty.handler.codec.http.QueryStringDecoder;
3037
import io.netty.handler.codec.http.multipart.DefaultHttpDataFactory;
3138
import io.netty.handler.codec.http.multipart.HttpDataFactory;
3239
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder;
40+
import java.io.ByteArrayOutputStream;
41+
import java.io.PrintStream;
3342
import java.util.List;
3443
import java.util.Map;
3544
import org.slf4j.Logger;
3645
import org.slf4j.LoggerFactory;
46+
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
3747

3848
/**
3949
* A class handling inbound HTTP requests to the management API.
@@ -43,9 +53,12 @@
4353
public class InferenceRequestHandler extends HttpRequestHandler {
4454

4555
private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
56+
Map<String, ModelServerEndpoint> inferEndpointMap;
4657

4758
/** Creates a new {@code InferenceRequestHandler} instance. */
48-
public InferenceRequestHandler() {}
59+
public InferenceRequestHandler(Map<String, ModelServerEndpoint> endpointMap) {
60+
inferEndpointMap = endpointMap;
61+
}
4962

5063
@Override
5164
protected void handleRequest(
@@ -68,11 +81,46 @@ protected void handleRequest(
6881
handlePredictions(ctx, req, segments);
6982
break;
7083
default:
71-
handleLegacyPredict(ctx, req, decoder, segments);
84+
if (inferEndpointMap.getOrDefault(segments[1], null) != null) {
85+
handleCustomEndpoint(ctx, req, segments, decoder);
86+
} else {
87+
handleLegacyPredict(ctx, req, decoder, segments);
88+
}
7289
break;
7390
}
7491
}
7592

93+
private void handleCustomEndpoint(
94+
ChannelHandlerContext ctx,
95+
FullHttpRequest req,
96+
String[] segments,
97+
QueryStringDecoder decoder) {
98+
if (HttpMethod.GET.equals(req.method())) {
99+
ModelServerEndpoint endpoint = inferEndpointMap.get(segments[1]);
100+
Runnable r =
101+
() -> {
102+
FullHttpResponse rsp =
103+
new DefaultFullHttpResponse(
104+
HttpVersion.HTTP_1_1, HttpResponseStatus.OK, false);
105+
try {
106+
endpoint.doGet(
107+
new ModelServerRequest(req, decoder),
108+
new ModelServerResponse(rsp),
109+
new ModelServerContext());
110+
NettyUtils.sendHttpResponse(ctx, rsp, true);
111+
} catch (Exception e) {
112+
ByteArrayOutputStream ps = new ByteArrayOutputStream();
113+
e.printStackTrace(new PrintStream(ps));
114+
logger.error("Unknown exception", e);
115+
NettyUtils.sendError(ctx, HttpResponseStatus.NOT_IMPLEMENTED, e);
116+
}
117+
};
118+
ModelManager.getInstance().submitTask(r);
119+
} else {
120+
throw new ServiceUnavailableException("Unknown HTTP method called.");
121+
}
122+
}
123+
76124
@Override
77125
protected void handleApiDescription(ChannelHandlerContext ctx) {
78126
NettyUtils.sendJsonResponse(ctx, OpenApiUtils.listInferenceApis());

‎frontend/server/src/main/java/com/amazonaws/ml/mms/http/ManagementRequestHandler.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import java.util.Map;
3535
import java.util.concurrent.CompletableFuture;
3636
import java.util.function.Function;
37+
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
3738

3839
/**
3940
* A class handling inbound HTTP requests to the management API.
@@ -42,8 +43,11 @@
4243
*/
4344
public class ManagementRequestHandler extends HttpRequestHandler {
4445

46+
Map<String, ModelServerEndpoint> mgmtEndpointMap;
4547
/** Creates a new {@code ManagementRequestHandler} instance. */
46-
public ManagementRequestHandler() {}
48+
public ManagementRequestHandler(Map<String, ModelServerEndpoint> endpointMap) {
49+
mgmtEndpointMap = endpointMap;
50+
}
4751

4852
@Override
4953
protected void handleRequest(
@@ -52,7 +56,8 @@ protected void handleRequest(
5256
QueryStringDecoder decoder,
5357
String[] segments)
5458
throws ModelException {
55-
if (!"models".equals(segments[1])) {
59+
if (!"models".equals(segments[1])
60+
&& mgmtEndpointMap.getOrDefault(segments[1], null) == null) {
5661
throw new ResourceNotFoundException();
5762
}
5863

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package com.amazonaws.ml.mms.servingsdk_impl;
2+
3+
import com.amazonaws.ml.mms.util.ConfigManager;
4+
import java.util.Properties;
5+
import software.amazon.ai.mms.servingsdk.Context;
6+
7+
public class ModelServerContext implements Context {
8+
@Override
9+
public Properties getConfig() {
10+
return ConfigManager.getInstance().getConfiguration();
11+
}
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package com.amazonaws.ml.mms.servingsdk_impl;
2+
3+
import io.netty.handler.codec.http.FullHttpRequest;
4+
import io.netty.handler.codec.http.HttpUtil;
5+
import io.netty.handler.codec.http.QueryStringDecoder;
6+
import java.io.ByteArrayInputStream;
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.Map;
10+
import software.amazon.ai.mms.servingsdk.http.Request;
11+
12+
public class ModelServerRequest implements Request {
13+
private FullHttpRequest req;
14+
private QueryStringDecoder decoder;
15+
16+
public ModelServerRequest(FullHttpRequest r, QueryStringDecoder d) {
17+
req = r;
18+
decoder = d;
19+
}
20+
21+
@Override
22+
public List<String> getHeaderNames() {
23+
return new ArrayList<>(req.headers().names());
24+
}
25+
26+
@Override
27+
public String getRequestURI() {
28+
return req.uri();
29+
}
30+
31+
@Override
32+
public Map<String, List<String>> getParameterMap() {
33+
return decoder.parameters();
34+
}
35+
36+
@Override
37+
public List<String> getParameter(String k) {
38+
return decoder.parameters().get(k);
39+
}
40+
41+
@Override
42+
public String getContentType() {
43+
return HttpUtil.getMimeType(req).toString();
44+
}
45+
46+
@Override
47+
public ByteArrayInputStream getInputStream() {
48+
return new ByteArrayInputStream(req.content().array());
49+
}
50+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.amazonaws.ml.mms.servingsdk_impl;
2+
3+
import io.netty.buffer.ByteBufOutputStream;
4+
import io.netty.handler.codec.http.FullHttpResponse;
5+
import io.netty.handler.codec.http.HttpHeaderNames;
6+
import io.netty.handler.codec.http.HttpResponseStatus;
7+
import java.io.OutputStream;
8+
import software.amazon.ai.mms.servingsdk.http.Response;
9+
10+
public class ModelServerResponse implements Response {
11+
12+
private FullHttpResponse response;
13+
14+
public ModelServerResponse(FullHttpResponse rsp) {
15+
response = rsp;
16+
}
17+
18+
@Override
19+
public void setStatus(int i) {
20+
response.setStatus(HttpResponseStatus.valueOf(i));
21+
}
22+
23+
@Override
24+
public void setStatus(int i, String s) {
25+
response.setStatus(HttpResponseStatus.valueOf(i, s));
26+
}
27+
28+
@Override
29+
public void setHeader(String k, String v) {
30+
response.headers().set(k, v);
31+
}
32+
33+
@Override
34+
public void addHeader(String k, String v) {
35+
response.headers().add(k, v);
36+
}
37+
38+
@Override
39+
public void setContentType(String contentType) {
40+
response.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType);
41+
}
42+
43+
@Override
44+
public OutputStream getOutputStream() {
45+
return new ByteBufOutputStream(response.content());
46+
}
47+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.amazonaws.ml.mms.servingsdk_impl;
2+
3+
import java.lang.annotation.Annotation;
4+
import java.util.HashMap;
5+
import java.util.ServiceLoader;
6+
import software.amazon.ai.mms.servingsdk.ModelServerEndpoint;
7+
import software.amazon.ai.mms.servingsdk.annotations.Endpoint;
8+
import software.amazon.ai.mms.servingsdk.annotations.helpers.EndpointTypes;
9+
10+
public final class PluginLoader {
11+
12+
private static PluginLoader instance = new PluginLoader();
13+
14+
private PluginLoader() {}
15+
16+
public static PluginLoader getInstance() {
17+
return instance;
18+
}
19+
20+
private boolean validEndpoint(Annotation a, EndpointTypes type) {
21+
return a instanceof Endpoint
22+
&& !((Endpoint) a).urlPattern().isEmpty()
23+
&& ((Endpoint) a).endpointType().equals(type);
24+
}
25+
26+
private HashMap<String, ModelServerEndpoint> getEndpoints(EndpointTypes type) {
27+
ServiceLoader<ModelServerEndpoint> loader = ServiceLoader.load(ModelServerEndpoint.class);
28+
HashMap<String, ModelServerEndpoint> ep = new HashMap<>();
29+
for (ModelServerEndpoint mep : loader) {
30+
Class<? extends ModelServerEndpoint> modelServerEndpointClassObj = mep.getClass();
31+
Annotation[] annotations = modelServerEndpointClassObj.getAnnotations();
32+
for (Annotation a : annotations) {
33+
if (validEndpoint(a, type)) {
34+
ep.put(((Endpoint) a).urlPattern(), mep);
35+
}
36+
}
37+
}
38+
return ep;
39+
}
40+
41+
public HashMap<String, ModelServerEndpoint> getAllInferenceServingEndpoints() {
42+
return getEndpoints(EndpointTypes.INFERENCE);
43+
}
44+
45+
public HashMap<String, ModelServerEndpoint> getAllManagementServingEndpoints() {
46+
return getEndpoints(EndpointTypes.MANAGEMENT);
47+
}
48+
}

‎frontend/server/src/main/java/com/amazonaws/ml/mms/util/ConfigManager.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -242,18 +242,27 @@ public String getMmsDefaultServiceHandler() {
242242
return getProperty(MMS_DEFAULT_SERVICE_HANDLER, null);
243243
}
244244

245+
public Properties getConfiguration() {
246+
return new Properties(prop);
247+
}
248+
245249
public int getDefaultWorkers() {
246250
if (isDebug()) {
247251
return 1;
248252
}
249-
250253
int workers = getIntProperty(MMS_DEFAULT_WORKERS_PER_MODEL, 0);
254+
255+
if ((workers == 0) && (prop.getProperty("NUM_WORKERS", null) != null)) {
256+
workers = getIntProperty("NUM_WORKERS", 0);
257+
}
258+
251259
if (workers == 0) {
252260
workers = getNumberOfGpu();
253261
}
254262
if (workers == 0) {
255263
workers = Runtime.getRuntime().availableProcessors();
256264
}
265+
setProperty("NUM_WORKERS", Integer.toString(workers));
257266
return workers;
258267
}
259268

‎frontend/server/src/test/resources/config.properties

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ load_models=noop-v0.1,noop-v1.0
99
# netty_client_threads=0
1010
# default_workers_per_model=0
1111
# job_queue_size=100
12+
# plugins_path=/tmp/plugins
1213
async_logging=true
1314
default_response_timeout=120
1415
# number_of_gpu=1

‎mms/model_server.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def start():
7979
else:
8080
mms_conf_file = "config.properties"
8181

82+
class_path = \
83+
".:{}".format(os.path.join(mms_home, "mms/frontend/*"))
84+
8285
if os.path.isfile(mms_conf_file):
8386
props = load_properties(mms_conf_file)
8487
vm_args = props.get("vmargs")
@@ -89,9 +92,14 @@ def start():
8992
if word.startswith("-Dlog4j.configuration="):
9093
arg_list.remove(word)
9194
cmd.extend(arg_list)
95+
plugins = props.get("plugins_path", None)
96+
if plugins:
97+
class_path += ":" + plugins + "/*" if "*" not in plugins else ":" + plugins
98+
99+
cmd.append("-cp")
100+
cmd.append(class_path)
92101

93-
cmd.append("-jar")
94-
cmd.append("{}/mms/frontend/model-server.jar".format(mms_home))
102+
cmd.append("com.amazonaws.ml.mms.ModelServer")
95103

96104
# model-server.jar command line parameters
97105
cmd.append("--python")

‎setup.py

+39-16
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,14 @@
3333
import subprocess
3434
import sys
3535
from datetime import date
36-
from shutil import copyfile, rmtree
36+
from shutil import copy2, rmtree
3737

3838
import setuptools.command.build_py
3939
from setuptools import setup, find_packages, Command
4040

4141
import mms
4242

4343
pkgs = find_packages()
44-
source_server_file = os.path.abspath('frontend/server/build/libs/server-1.0.jar')
45-
dest_file_name = os.path.abspath('mms/frontend/model-server.jar')
46-
4744

4845
def pypi_description():
4946
"""
@@ -62,18 +59,13 @@ def detect_model_server_version():
6259
return mms.__version__.strip() + 'b' + str(date.today()).replace('-', '')
6360

6461

65-
class BuildFrontEnd(Command):
62+
class BuildFrontEnd(setuptools.command.build_py.build_py):
6663
"""
6764
Class defined to run custom commands.
6865
"""
6966
description = 'Build Model Server Frontend'
70-
user_options = []
71-
72-
def initialize_options(self):
73-
pass
74-
75-
def finalize_options(self):
76-
pass
67+
source_server_file = os.path.abspath('frontend/server/build/libs/server-1.0.jar')
68+
dest_file_name = os.path.abspath('mms/frontend/model-server.jar')
7769

7870
# noinspection PyMethodMayBeStatic
7971
def run(self):
@@ -90,18 +82,18 @@ def run(self):
9082
else:
9183
raise
9284

93-
if os.path.exists(source_server_file):
94-
os.remove(source_server_file)
85+
if os.path.exists(self.source_server_file):
86+
os.remove(self.source_server_file)
9587

9688
# Remove build/lib directory.
9789
if os.path.exists('build/lib/'):
9890
rmtree('build/lib/')
9991

10092
try:
101-
subprocess.check_call('frontend/gradlew -p frontend build', shell=True)
93+
subprocess.check_call('frontend/gradlew -p frontend clean build', shell=True)
10294
except OSError:
10395
assert 0, "build failed"
104-
copyfile(source_server_file, dest_file_name)
96+
copy2(self.source_server_file, self.dest_file_name)
10597

10698

10799
class BuildPy(setuptools.command.build_py.build_py):
@@ -115,6 +107,36 @@ def run(self):
115107
setuptools.command.build_py.build_py.run(self)
116108

117109

110+
class BuildPlugins(Command):
111+
description = 'Build Model Server Plugins'
112+
user_options = [('plugins=', 'p', 'Plugins installed')]
113+
source_plugin_dir = \
114+
os.path.abspath('plugins/build/plugins')
115+
116+
def initialize_options(self):
117+
self.plugins = None
118+
119+
def finalize_options(self):
120+
if self.plugins is None:
121+
print("No plugin option provided. Defaulting to 'default'")
122+
self.plugins = "default"
123+
124+
# noinspection PyMethodMayBeStatic
125+
def run(self):
126+
if os.path.isdir(self.source_plugin_dir):
127+
rmtree(self.source_plugin_dir)
128+
129+
try:
130+
if self.plugins == "sagemaker":
131+
subprocess.check_call('plugins/gradlew -p plugins clean bS', shell=True)
132+
else:
133+
raise OSError("No such rule exists")
134+
except OSError:
135+
assert 0, "build failed"
136+
137+
self.run_command('build_py')
138+
139+
118140
if __name__ == '__main__':
119141
version = detect_model_server_version()
120142

@@ -132,6 +154,7 @@ def run(self):
132154
packages=pkgs,
133155
cmdclass={
134156
'build_frontend': BuildFrontEnd,
157+
'build_plugins': BuildPlugins,
135158
'build_py': BuildPy,
136159
},
137160
install_requires=requirements,

0 commit comments

Comments
 (0)
Please sign in to comment.