Skip to content

Commit 363c564

Browse files
committed May 19, 2020
Clean up resources after worker thread is terminated
1 parent 70e8f1f commit 363c564

File tree

5 files changed

+88
-34
lines changed

5 files changed

+88
-34
lines changed
 

‎ci/buildspec.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ phases:
1010
- pip install pip -U
1111
- pip install future
1212
- pip install Pillow
13-
- pip install pytest
13+
- pip install pytest==4.0.0
1414
- pip install wheel
1515
- pip install twine
1616
- pip install pytest-mock -U

‎frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkLoadManager.java

+30-24
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,9 @@ public CompletableFuture<HttpResponseStatus> modelChanged(Model model) {
113113
if (minWorker == 0) {
114114
threads = workers.remove(model.getModelName());
115115
if (threads == null) {
116+
if (maxWorker == 0) {
117+
return shutdownServerThread(model, future);
118+
}
116119
future.complete(HttpResponseStatus.OK);
117120
return future;
118121
}
@@ -129,37 +132,40 @@ public CompletableFuture<HttpResponseStatus> modelChanged(Model model) {
129132
thread.shutdown();
130133
}
131134
if (maxWorker == 0) {
132-
model.getServerThread().shutdown();
133-
WorkerLifeCycle lifecycle = model.getServerThread().getLifeCycle();
134-
Process workerProcess = lifecycle.getProcess();
135-
if (workerProcess.isAlive()) {
136-
boolean workerDestroyed = false;
137-
workerProcess.destroyForcibly();
138-
try {
139-
workerDestroyed =
140-
workerProcess.waitFor(
141-
configManager.getUnregisterModelTimeout(),
142-
TimeUnit.SECONDS);
143-
} catch (InterruptedException e) {
144-
logger.warn(
145-
"WorkerThread interrupted during waitFor, possible asynch resource cleanup.");
146-
future.complete(HttpResponseStatus.INTERNAL_SERVER_ERROR);
147-
return future;
148-
}
149-
if (!workerDestroyed) {
150-
logger.warn(
151-
"WorkerThread timed out while cleaning, please resend request.");
152-
future.complete(HttpResponseStatus.REQUEST_TIMEOUT);
153-
return future;
154-
}
155-
}
135+
return shutdownServerThread(model, future);
156136
}
157137
future.complete(HttpResponseStatus.OK);
158138
}
159139
return future;
160140
}
161141
}
162142

143+
private CompletableFuture<HttpResponseStatus> shutdownServerThread(
144+
Model model, CompletableFuture<HttpResponseStatus> future) {
145+
model.getServerThread().shutdown();
146+
WorkerLifeCycle lifecycle = model.getServerThread().getLifeCycle();
147+
Process workerProcess = lifecycle.getProcess();
148+
if (workerProcess.isAlive()) {
149+
boolean workerDestroyed = false;
150+
workerProcess.destroyForcibly();
151+
try {
152+
workerDestroyed =
153+
workerProcess.waitFor(
154+
configManager.getUnregisterModelTimeout(), TimeUnit.SECONDS);
155+
} catch (InterruptedException e) {
156+
logger.warn(
157+
"WorkerThread interrupted during waitFor, possible asynch resource cleanup.");
158+
future.complete(HttpResponseStatus.INTERNAL_SERVER_ERROR);
159+
}
160+
if (!workerDestroyed) {
161+
logger.warn("WorkerThread timed out while cleaning, please resend request.");
162+
future.complete(HttpResponseStatus.REQUEST_TIMEOUT);
163+
}
164+
}
165+
future.complete(HttpResponseStatus.OK);
166+
return future;
167+
}
168+
163169
public void addServerThread(Model model, CompletableFuture<HttpResponseStatus> future)
164170
throws InterruptedException, ExecutionException, TimeoutException {
165171
WorkerStateListener listener = new WorkerStateListener(future, 1);

‎frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerLifeCycle.java

+34-5
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,8 @@ public class WorkerLifeCycle {
4141
private CountDownLatch latch;
4242
private boolean success;
4343
private Connector connector;
44+
private ReaderThread errReader;
45+
private ReaderThread outReader;
4446

4547
public WorkerLifeCycle(ConfigManager configManager, Model model) {
4648
this.configManager = configManager;
@@ -84,10 +86,24 @@ private String[] getEnvString(String cwd, String modelPath, String handler) {
8486
return envList.toArray(new String[0]); // NOPMD
8587
}
8688

87-
public void attachIOStreams(String threadName, InputStream outStream, InputStream errStream) {
89+
public synchronized void attachIOStreams(
90+
String threadName, InputStream outStream, InputStream errStream) {
8891
logger.warn("attachIOStreams() threadName={}", threadName);
89-
new ReaderThread(threadName, errStream, true, this).start();
90-
new ReaderThread(threadName, outStream, false, this).start();
92+
errReader = new ReaderThread(threadName, errStream, true, this);
93+
outReader = new ReaderThread(threadName, outStream, false, this);
94+
errReader.start();
95+
outReader.start();
96+
}
97+
98+
public synchronized void terminateIOStreams() {
99+
if (errReader != null) {
100+
logger.warn("terminateIOStreams() threadName={}", errReader.getName());
101+
errReader.terminate();
102+
}
103+
if (outReader != null) {
104+
logger.warn("terminateIOStreams() threadName={}", outReader.getName());
105+
outReader.terminate();
106+
}
91107
}
92108

93109
public void startBackendServer(int port)
@@ -164,6 +180,7 @@ public synchronized void exit() {
164180
if (process != null) {
165181
process.destroyForcibly();
166182
connector.clean();
183+
terminateIOStreams();
167184
}
168185
}
169186

@@ -200,6 +217,7 @@ private static final class ReaderThread extends Thread {
200217
private InputStream is;
201218
private boolean error;
202219
private WorkerLifeCycle lifeCycle;
220+
// Cleared by terminate() from a different thread and polled by run()'s read loop;
// volatile so the reader thread is guaranteed to observe the update.
private volatile boolean isRunning = true;
203221
static final org.apache.log4j.Logger loggerModelMetrics =
204222
org.apache.log4j.Logger.getLogger(ConfigManager.MODEL_METRICS_LOGGER);
205223

@@ -210,10 +228,14 @@ public ReaderThread(String name, InputStream is, boolean error, WorkerLifeCycle
210228
this.lifeCycle = lifeCycle;
211229
}
212230

231+
public void terminate() {
232+
isRunning = false;
233+
}
234+
213235
@Override
214236
public void run() {
215237
try (Scanner scanner = new Scanner(is, StandardCharsets.UTF_8.name())) {
216-
while (scanner.hasNext()) {
238+
while (isRunning && scanner.hasNext()) {
217239
String result = scanner.nextLine();
218240
if (result == null) {
219241
break;
@@ -234,9 +256,16 @@ public void run() {
234256
logger.info(result);
235257
}
236258
}
259+
} catch (Exception e) {
260+
logger.error("Couldn't create scanner - {}", getName(), e);
237261
} finally {
238-
logger.error("Couldn't create scanner - {}", getName());
262+
logger.info("Stopped Scanner - {}", getName());
239263
lifeCycle.setSuccess(false);
264+
try {
265+
is.close();
266+
} catch (IOException e) {
267+
logger.error("Failed to close stream for thread {}", this.getName(), e);
268+
}
240269
}
241270
}
242271
}

‎frontend/server/src/main/java/com/amazonaws/ml/mms/wlm/WorkerThread.java

+22-3
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,9 @@ public class WorkerThread implements Runnable {
8787

8888
private WorkerLifeCycle lifeCycle;
8989
private boolean serverThread;
90+
private RandomAccessFile out;
91+
private RandomAccessFile err;
92+
private Connector connector;
9093

9194
public WorkerState getState() {
9295
return state;
@@ -161,11 +164,11 @@ private void runWorker()
161164
case LOAD:
162165
String message = reply.getMessage();
163166
String tmpdir = System.getProperty("java.io.tmpdir");
164-
RandomAccessFile out =
167+
out =
165168
new RandomAccessFile(
166169
tmpdir + '/' + backendChannel.id().asLongText() + "-stdout",
167170
"rw");
168-
RandomAccessFile err =
171+
err =
169172
new RandomAccessFile(
170173
tmpdir + '/' + backendChannel.id().asLongText() + "-stderr",
171174
"rw");
@@ -286,7 +289,7 @@ private void connect()
286289

287290
final int responseBufferSize = ConfigManager.getInstance().getMaxResponseSize();
288291
try {
289-
Connector connector = new Connector(port);
292+
connector = new Connector(port);
290293
Bootstrap b = new Bootstrap();
291294
b.group(backendEventGroup)
292295
.channel(connector.getClientChannel())
@@ -381,6 +384,22 @@ public void shutdown() {
381384
model.removeJobQueue(backendChannel.id().asLongText());
382385
backendChannel.close();
383386
}
387+
if (this.serverThread && this.connector != null) {
388+
logger.debug("Cleaning connector socket");
389+
this.connector.clean();
390+
}
391+
logger.debug("Terminating IOStreams for worker thread shutdown");
392+
lifeCycle.terminateIOStreams();
393+
try {
394+
if (out != null) {
395+
out.close();
396+
}
397+
if (err != null) {
398+
err.close();
399+
}
400+
} catch (IOException e) {
401+
logger.error("Failed to close IO file handles", e);
402+
}
384403
Thread thread = currentThread.getAndSet(null);
385404
if (thread != null) {
386405
thread.interrupt();

‎frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -1312,7 +1312,7 @@ private void testLoggingUnload(Channel inferChannel, Channel mgmtChannel)
13121312
Scanner logscanner = new Scanner(logfile, "UTF-8");
13131313
while (logscanner.hasNextLine()) {
13141314
String line = logscanner.nextLine();
1315-
if (line.contains("LoggingService exit")) {
1315+
if (line.contains("Model logging unregistered")) {
13161316
count = count + 1;
13171317
}
13181318
}

0 commit comments

Comments
 (0)
Please sign in to comment.