diff --git a/docs/proposals/horovod-on-tony.md b/docs/proposals/horovod-on-tony.md new file mode 100644 index 00000000..e173829c --- /dev/null +++ b/docs/proposals/horovod-on-tony.md @@ -0,0 +1,235 @@ +- [Motivation](#motivation) +- [Goals](#goals) +- [API](#api) +- [Design](#design) +- [Usage](#usage) + +## Motivation +The purpose of Horovod is to make it easy to take a single-GPU training script and successfully scale it to train across many GPUs in parallel. This has two aspects: +1. Easy to use +2. Run faster in distributed mode + +However, because of the limitation of SSH mechanism (hard to support SSH on Yarn), TonY don't support horovod with **MPI controller**. With the help of **Gloo controller**, we are expecting to support horovod. + +## Goals +1. Support Horovod on TonY (limited static topology support) +2. Elastic Horovod will be supported later. + +## API +There are no API changes to other machine learning frameworks. + +## Design +Attention: It's gloo controller that makes TonY support horovod. + +From the perspective of compatibility and maintainability, it's better to directly use the [gloo_runner](https://github.com/horovod/horovod/blob/master/horovod/runner/gloo_run.py) on TonY. But after reading through Horovod's code, I find it difficult to reuse gloo_runner's code on TonY, because the existence of some unrelated codes will lead driver to start the worker through the SSH command, which is hard to be supportted on Yarn. + +After having a deep understanding of Horovod code and [communicating with developers](https://github.com/horovod/horovod/discussions/2785), i know that the Gloo controller uses a rendezvous server to assign each worker role, and provides HTTP API for workers to obtain cluster information. So each worker can build a training cluster and start training at the same time. + +Horovod is served as two roles, worker and driver. Driver is responsible for starting the rendezvous server and will not participate in training (no GPU required, lightweight). Before starting, driver need to know all workers' hostnames in advance. The worker is only responsible for training. According to TonY's architecture (**Application master** and **task executor**), the design can be as follows. + +### Horovod Driver +__How to start rendezvous server__ +Reusing Horovod rendezvous server code, we introduce tony-horovod driver launcher to offer a python script +```python +# Init the horovod rendezous server +global_rendezv = RendezvousServer(verbose=1) +# Output server port, which will be used horovod worker to connect server. +global_rendezv_port = global_rendezv.start() +print("Rendezvous server started, port: " + str(global_rendezv_port)) + +hosts = parse_hosts(worker_list) +# Output the host plan, it will output local_rank, rank and so on. +host_alloc_plan = get_host_assignments(hosts, 1) + +# Start the server. +global_rendezv.init(host_alloc_plan) +``` + +__When to start driver__ +After all workers' resource have be assigned and TonY's Application master could get all workers' registry info. + +__Where to start driver__ +Two options +1. On TonY application master. +This will save resources(no extra resources to start driver), and the amount of code changes will be small. But by injecting relevant Horovod's driver code into AM, it is not elegant. +2. On TonY task executor. +Additional customization of the driver configuration is required and the startup of driver will be covered on TonY automatically. And it is necessary to coordinate the startup sequence between the driver and other workers, because driver should start before worker. + +Second option will be adopted in this PR. +In order to unify different machine framework startup, we supposed to create `FrameworkRuntime` interface, it will expose methods as follows +```java + /** For AM, getting cluster spec and return to task exectuor **/ + String constructClusterSpec(String taskId) throws IOException; + + /** For AM, when app finished, it need to call it to release resource **/ + void destroy(); + + /** For AM, init the tony session **/ + void setTonySession(final TonySession session); + + /** For AM, it ensures that each task executor start sequence. like Horovod driver should start before workers **/ + boolean canStartTask(TonyConfigurationKeys.DistributedMode distributedMode, String taskId); + + /** For AM, it will pre-check tony conf and inject some params. like horovod runtime will inject driver config into it. **/ + boolean validateAndUpdateConfig(Configuration tonyConf); + + /** + * For AM, it will receive some callback info from task executor. + * This method will be called when Application Master accepting task executors' callback info. + * This method is suitable for the task executors that have a dependency of startup sequence, + * and the start of downstream tasks needs to rely on the info after the start of the upstream task. + */ + boolean receiveTaskCallbackInfo(String taskId, String callbackInfo); + + /** For TaskExecutor, execute task process **/ + int run(TaskExecutor executor) throws Exception; +``` + +So, we need to create `HorovodRuntime` to support it. Besides, TF/PyTorch/MXNet will also be supported in independent runtime, like `TFRuntime`. + +As stated in the design above, Horovod driver should be started on one task executor and before other workers. So in `HorovodRuntime`, we can use `canStartTask` method to coordinate task executor startup sequence. + +Besides, how to start Horovod driver? I think we can create `HorovodDriver` class to do it. Its methods as follows. +```java +public class HorovodDriver { + public final Process taskProcess; + public final int port; + public final List slotInfoList; + + // For TaskExecutor to start horovod driver, it will start rendezvous server + public synchronized static HorovodDriver create(String workerList) throws Exception { + return startRendezvousServer(workerList); + } + + private static HorovodDriver startRendezvousServer(String workerlist) throws Exception { + ... + } + + public void close() { + if (taskProcess != null) { + killProcess(taskProcess); + } + } + + // For TaskExecutor to wait process finish, it will hang until python Process exit. + public int waitFor(long timeout) throws InterruptedException { + this.taskProcess.waitFor(timeout, TimeUnit.MICROSECONDS); + return this.taskProcess.exitValue(); + } +} +``` + +How to extend `HorovodRuntime`. pseudo code as follows +```java +public class HorovodRuntime implements MLFrameworkRuntime { + private volatile boolean isDriverReady = false; + + private List workerSlotMetaInfo; + private String rendezvServerPort; + private String rendezvServerHost; + + @Override + public String constructClusterSpec(String taskId) throws IOException { + // when task is Driver, it will return worker list to driver, and make it start rendezvous server + + // when task is worker, it will return rendezouvs server's slot info to worker + } + + @Override + public boolean receiveTaskCallbackInfo(String taskId, String callbackInfo) { + // when role is driver, AM will accept driver's callback info, which is slot info including horovod + // host plan. It will be recorded in runtime and give to workers. + } + + @Override + public boolean canStartTask(TonyConfigurationKeys.DistributedMode distributedMode, String taskId) { + // coordinate startup sequence + } + + @Override + public boolean preCheck(Configuration tonyConf) { + // inject driver conf and make it untracked. + tonyConf.set("tony.driver.instances", "1"); + tonyConf.set("tony.driver.vcores", "1"); + tonyConf.set("tony.application.untracked.jobtypes", "driver"); + return true; + } + + // ===================For task executor======================= + + public void buildTaskEnv(TaskExecutor executor) throws Exception { + // set env for worker, like HOROVOD_CONTROLLER, HOST + } + + @Override + public int run(TaskExecutor executor) throws Exception { + buildTaskEnv(executor); + // if it is driver, it will launcher horovod driver and register info to AM + if (DRIVER.equals(executor.getJobName())) { + HorovodDriver driver = HorovodDriver.create(executor.getClusterSpec()); + String callBackInfo = driver.getCallbackInfo(); + log.info("Horovod driver call back to AM: \n" + callBackInfo); + executor.registerCallbackInfo(callBackInfo); + int exitCode = driver.waitFor(); + return exitCode; + } + + // if it is worker, it will execute training script directly. + return this.executorPythonShell(executor); + } +} +``` + +### Horovod Worker +__where to start worker__ +Only on TonY's task executor. + +__How to start worker__ +Just like start tensorflow task. +But some envs should be injected before starting worker. + +``` +HOROVOD_CONTROLLER=gloo +HOROVOD_CPU_OPERATIONS=gloo +HOROVOD_GLOO_TIMEOUT_SECONDS=2000 +HOROVOD_GLOO_RENDEZVOUS_PORT=9999 +HOROVOD_GLOO_RENDEZVOUS_ADDR=localhost + +HOROVOD_CROSS_RANK=0 +HOROVOD_CROSS_SIZE=1 +HOROVOD_LOCAL_RANK=0 +HOROVOD_LOCAL_SIZE=1 +HOROVOD_SIZE=1 +HOROVOD_RANK=0 +HOROVOD_HOSTNAME=0.0.0.0 +``` +__How to get these horovod params?__ +Acutally, these params are from **host_alloc_plan**(mentioned in previous python code). The python script should output these params and AM will get them and assign to task executor. + +## Usage +tony-test.xml is as follows, more details are shown on tony-examples module. + +``` + + + tony.worker.instances + 4 + + + tony.worker.memory + 3g + + + tony.docker.enabled + true + + + tony.docker.containers.image + YOUR_DOCKER_IMAGE_ADDRESS + + + tony.application.framework + horovod + + +``` \ No newline at end of file diff --git a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java index 72cd208d..e3472cc3 100644 --- a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java +++ b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java @@ -725,7 +725,6 @@ private void stop() { } frameworkRuntime.destroy(); - nmClientAsync.stop(); amRMClient.stop(); // Poll until TonyClient signals we should exit @@ -868,12 +867,6 @@ public Set getTaskInfos() { return Collections.emptySet(); } - @Override - public String getClusterSpec() throws IOException { - ObjectMapper objectMapper = new ObjectMapper(); - return objectMapper.writeValueAsString(session.getClusterSpec()); - } - @Override public void taskExecutorHeartbeat(String taskId) { TonyTask task = session.getTask(taskId); @@ -885,6 +878,12 @@ public void taskExecutorHeartbeat(String taskId) { } } + @Override + public String getClusterSpec() throws IOException { + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.writeValueAsString(session.getClusterSpec()); + } + @Override public String registerWorkerSpec(String taskId, String spec) throws IOException { TonyTask task = session.getTask(taskId); @@ -901,13 +900,9 @@ public String registerWorkerSpec(String taskId, String spec) throws IOException killChiefWorkerIfTesting(taskId); } - // two distributed mode (default is GANG) cases: - // 1. In FCFS mode, task will be allowed to run when AM accept worker registered spec, - // 2. In GANG mode, it will start until all tasks have registered. if (frameworkRuntime.canStartTask(distributedMode, taskId)) { return frameworkRuntime.constructClusterSpec(taskId); } - return null; } diff --git a/tony-core/src/main/java/com/linkedin/tony/FrameworkRuntime.java b/tony-core/src/main/java/com/linkedin/tony/FrameworkRuntime.java index 0bcfe0b9..6e6e43d6 100644 --- a/tony-core/src/main/java/com/linkedin/tony/FrameworkRuntime.java +++ b/tony-core/src/main/java/com/linkedin/tony/FrameworkRuntime.java @@ -19,6 +19,7 @@ import org.apache.hadoop.conf.Configuration; +import com.linkedin.tony.runtime.HorovodRuntime; import com.linkedin.tony.runtime.MXNetRuntime; import com.linkedin.tony.runtime.PyTorchRuntime; import com.linkedin.tony.runtime.StandaloneRuntime; @@ -35,6 +36,8 @@ static FrameworkRuntime get(TonyConfigurationKeys.FrameworkType frameworkType) { return new PyTorchRuntime(); case MXNET: return new MXNetRuntime(); + case HOROVOD: + return new HorovodRuntime(); case STANDALONE: return new StandaloneRuntime(); default: diff --git a/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java b/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java index b75fe52c..c5d9750e 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java +++ b/tony-core/src/main/java/com/linkedin/tony/TaskExecutor.java @@ -168,6 +168,7 @@ private static TaskExecutor createExecutor() throws Exception { TimeUnit.MILLISECONDS); executor.setupPorts(); + executor.clusterSpec = executor.registerAndGetClusterSpec(); if (executor.clusterSpec == null) { diff --git a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java index be761bbb..a55565dc 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java @@ -299,4 +299,10 @@ public static String getContainerDockerMountKey() { // Configurations that can take multiple values. public static final List MULTI_VALUE_CONF = Collections.unmodifiableList( Arrays.asList(CONTAINER_LAUNCH_ENV, EXECUTION_ENV, getContainerResourcesKey())); + + // Local testing horovod driver + public static final String TEST_HOROVOD_FAIL_ENABLE_KEY = TONY_APPLICATION_PREFIX + "test.horovod-driver-fail-enable"; + public static final boolean DEFAULT_TEST_HOROVOD_FAIL = false; + public static final String IN_TEST_HOROVOD_MODE = TONY_APPLICATION_PREFIX + "test.horovod-test-mode-enable"; + public static final boolean DEFAULT_IN_TEST_HOROVOD_MODE = false; } diff --git a/tony-core/src/main/java/com/linkedin/tony/horovod/DriverCallbackInfo.java b/tony-core/src/main/java/com/linkedin/tony/horovod/DriverCallbackInfo.java new file mode 100644 index 00000000..0084ef0b --- /dev/null +++ b/tony-core/src/main/java/com/linkedin/tony/horovod/DriverCallbackInfo.java @@ -0,0 +1,47 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.horovod; + +import java.util.List; + +public class DriverCallbackInfo { + private String port; + private String host; + private List slotInfos; + + public DriverCallbackInfo() { + // ignore + } + + public DriverCallbackInfo(String port, String host, List slotInfos) { + this.port = port; + this.host = host; + this.slotInfos = slotInfos; + } + + public String getPort() { + return port; + } + + public void setPort(String port) { + this.port = port; + } + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public List getSlotInfos() { + return slotInfos; + } + + public void setSlotInfos(List slotInfos) { + this.slotInfos = slotInfos; + } +} diff --git a/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodClusterSpec.java b/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodClusterSpec.java new file mode 100644 index 00000000..53fae716 --- /dev/null +++ b/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodClusterSpec.java @@ -0,0 +1,57 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.horovod; + +import java.util.List; + +public class HorovodClusterSpec { + private List slotInfos; + private String port; + private String amHost; + private List sameHostTaskIndexList; + + public HorovodClusterSpec() { + // ignore + } + + public HorovodClusterSpec(List slotInfos, String port, String amHost, List sameHostTaskIndexList) { + this.slotInfos = slotInfos; + this.port = port; + this.amHost = amHost; + this.sameHostTaskIndexList = sameHostTaskIndexList; + } + + public List getSlotInfos() { + return slotInfos; + } + + public void setSlotInfos(List slotInfos) { + this.slotInfos = slotInfos; + } + + public String getPort() { + return port; + } + + public void setPort(String port) { + this.port = port; + } + + public String getAmHost() { + return amHost; + } + + public void setAmHost(String amHost) { + this.amHost = amHost; + } + + public List getSameHostTaskIndexList() { + return sameHostTaskIndexList; + } + + public void setSameHostTaskIndexList(List sameHostTaskIndexList) { + this.sameHostTaskIndexList = sameHostTaskIndexList; + } +} diff --git a/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodDriver.java b/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodDriver.java new file mode 100644 index 00000000..5c14f4a7 --- /dev/null +++ b/tony-core/src/main/java/com/linkedin/tony/horovod/HorovodDriver.java @@ -0,0 +1,283 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.horovod; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.codehaus.jackson.map.ObjectMapper; + +import com.google.common.annotations.VisibleForTesting; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.linkedin.tony.util.Utils; + +import static java.util.Objects.requireNonNull; + +/** + * Introduce HorovodDriver class to be a Horovod driver, which is similar to the role of Horovod launcher(gloo_run). + * The responsibilities of the HorovodDriver are as follows + * 1. Start the rendezvous server by using the built-in python script (horovod_driver.py in resource folder). + * 2. Get the python script's process output and parse it, and pass it to Horovod runtime + * 3. Monitor rendezvous server's python process + */ +public class HorovodDriver { + private static final Log LOG = LogFactory.getLog(HorovodDriver.class); + private static final Path DRIVER_SCRIPT_PATH = requireNonNull(createDriverScripPath()); + private static final String FAKE_SERVER_PORT = "9999"; + private static final String DRIVER_PYTHON_SCRIPT_NAME = "horovod_driver.py"; + private static final String DRIVER_TMP_FOLDER_NAME = "horovod_driver"; + public static final String PORT_FILE_NAME_SUFFIX = "____HOROVOD_RENDEZVOUS_SERVER____"; + private static boolean inTestMode = false; + private static boolean failInTestMode = false; + + // TODO: 4/10/21 Monitor task process exit. Once exit, it should throw exception. + private final Process taskProcess; + private final int port; + private final List slotInfoList; + + private HorovodDriver(Process taskProcess, int port, List slotInfos) { + this.taskProcess = taskProcess; + this.port = port; + this.slotInfoList = slotInfos; + } + + public List getSlotInfoList() { + return slotInfoList; + } + + public int getPort() { + return port; + } + + @VisibleForTesting + protected static Path createDriverScripPath() { + ClassLoader loader = Thread.currentThread().getContextClassLoader(); + final String driverScript = DRIVER_PYTHON_SCRIPT_NAME; + try { + Path tempDir = Files.createTempDirectory(DRIVER_TMP_FOLDER_NAME); + tempDir.toFile().deleteOnExit(); + try (InputStream stream = loader.getResourceAsStream(driverScript)) { + Files.copy(stream, Paths.get(tempDir.toAbsolutePath().toString(), driverScript)); + } + return Paths.get(tempDir.toAbsolutePath().toString(), driverScript); + } catch (Exception e) { + LOG.info(e); + return null; + } + } + + public synchronized static HorovodDriver create(String workerList) throws Exception { + reset(); + return startRendezvousServer(workerList); + } + + @VisibleForTesting + protected static void reset() throws IOException { + // remove existed port files. + assert DRIVER_SCRIPT_PATH != null; + Path parentPath = DRIVER_SCRIPT_PATH.getParent(); + assert parentPath != null; + + if (!existPortFile(parentPath)) { + return; + } + + File[] files = parentPath.toFile().listFiles(); + if (files == null) { + return; + } + + Arrays.stream(files).filter(file -> file.getName().endsWith(PORT_FILE_NAME_SUFFIX)) + .forEach(file -> file.delete()); + } + + /** + * @return Pair, left is Rendezvous server port, right is SlotInfo. + * @throws IOException + * @param taskProcess + */ + private static Pair> waitTillServerStarted(final Process taskProcess) throws Exception { + assert DRIVER_SCRIPT_PATH != null; + Path parentPath = DRIVER_SCRIPT_PATH.getParent(); + assert parentPath != null; + + int checkCount = 0; + int maxCheckCount = 5; + Duration checkInterval = Duration.ofSeconds(2); + + while (!existPortFile(parentPath) && (checkCount++) < maxCheckCount) { + if (taskProcess != null && !taskProcess.isAlive()) { + throw new Exception("Horovod Driver python process has finished, exit code: " + taskProcess.exitValue()); + } + + try { + LOG.info("Rendezvous server don't start, sleep for " + checkInterval.getSeconds() + " secs"); + Thread.sleep(checkInterval.toMillis()); + } catch (Exception e) { + LOG.warn(e); + } + } + + if (checkCount > maxCheckCount) { + LOG.error("Timeout of starting horovod driver."); + throw new Exception("Errors on starting horovod driver within the fixed time."); + } + + if (taskProcess != null && !taskProcess.isAlive()) { + String msg = "Driver python process has ended abnormally, exit code: " + taskProcess.exitValue(); + LOG.error(msg); + throw new Exception(msg); + } + return getServerInfo(parentPath); + } + + @VisibleForTesting + protected static Pair> getServerInfo(Path parentPath) throws IOException { + int port = -1; + File parentFile = parentPath.toFile(); + requireNonNull(parentFile); + File[] files = parentFile.listFiles(); + if (files == null) { + return Pair.of(port, null); + } + + for (File file : files) { + String fileName = file.getName(); + if (fileName.endsWith(PORT_FILE_NAME_SUFFIX)) { + int tempIndex = fileName.indexOf(PORT_FILE_NAME_SUFFIX); + port = Integer.parseInt(fileName.substring(0, tempIndex)); + String fileContent = FileUtils.readFileToString(file); + // TODO: 4/10/21 fast fail when file content is empty. + LOG.info("Horovod rendezvous server slot info: \n" + fileContent); + List slotInfoList = new Gson().fromJson(fileContent, + new TypeToken>() { }.getType()); + return Pair.of(port, slotInfoList); + } + } + LOG.info("Still no starting horovod rendezvous server."); + return Pair.of(port, null); + } + + private static boolean existPortFile(Path parentPath) throws IOException { + return getServerInfo(parentPath).getLeft() != -1 ? true : false; + } + + private static HorovodDriver startRendezvousServer(String workerlist) throws Exception { + // todo: Precheck python version >= 3.6 (required by Horovod) + String driverProcessCommand = String.format("python %s -w %s", DRIVER_SCRIPT_PATH, workerlist); + if (inTestMode) { + driverProcessCommand += " -t " + " -p " + FAKE_SERVER_PORT; + + if (failInTestMode) { + driverProcessCommand += " -f"; + } + } + + ProcessBuilder taskProcessBuilder = new ProcessBuilder("bash", "-c", driverProcessCommand); + taskProcessBuilder.redirectError(ProcessBuilder.Redirect.INHERIT); + taskProcessBuilder.redirectOutput(ProcessBuilder.Redirect.INHERIT); + + LOG.info("Starting python's Horovod driver cmd: " + driverProcessCommand); + Process taskProcess = taskProcessBuilder.start(); + Pair> serverInfo = waitTillServerStarted(taskProcess); + return new HorovodDriver(taskProcess, serverInfo.getLeft(), serverInfo.getRight()); + } + + public void close() { + if (taskProcess != null) { + killProcess(taskProcess); + } + + try { + reset(); + } catch (IOException e) { + LOG.error("Errors on cleaning up driver tmp files.", e); + } + } + + public int waitFor(long timeout) throws InterruptedException { + if (timeout <= 0) { + this.taskProcess.waitFor(); + } else { + this.taskProcess.waitFor(timeout, TimeUnit.MICROSECONDS); + } + + return this.taskProcess.exitValue(); + } + + public int waitFor() throws InterruptedException { + return waitFor(-1); + } + + private static void killProcess(Process taskProcess) { + if (!taskProcess.isAlive()) { + return; + } + + LOG.info("Killing the Horovod driver python process.."); + taskProcess.destroy(); + + int checkCount = 0; + int maxCheckCount = 10; + while (taskProcess.isAlive() && (checkCount++) < maxCheckCount) { + try { + Thread.sleep(Duration.ofSeconds(1).toMillis()); + } catch (InterruptedException e) { + LOG.info(e); + } + } + + if (taskProcess.isAlive()) { + LOG.info("Killing the Horovod driver python process forcibly..."); + taskProcess.destroyForcibly(); + } + + LOG.info("Successfully killed the Horovod driver python process"); + } + + public static void setInTest() { + HorovodDriver.inTestMode = true; + } + + public static String getFakeServerPort() { + return FAKE_SERVER_PORT; + } + + public static void setTaskFailInTestMode() { + HorovodDriver.failInTestMode = true; + } + + public static void removeTaskFailInTestMode() { + HorovodDriver.failInTestMode = false; + } + + public String getCallbackInfo() throws IOException { + DriverCallbackInfo callbackInfo = new DriverCallbackInfo(String.valueOf(port), + Utils.getCurrentHostName(), slotInfoList); + return new ObjectMapper().writeValueAsString(callbackInfo); + } + + public int getExitCode() { + if (!taskProcess.isAlive()) { + return taskProcess.exitValue(); + } + + LOG.error("Task process is still alive."); + return -1; + } +} diff --git a/tony-core/src/main/java/com/linkedin/tony/horovod/SlotInfo.java b/tony-core/src/main/java/com/linkedin/tony/horovod/SlotInfo.java new file mode 100644 index 00000000..2e85e2f5 --- /dev/null +++ b/tony-core/src/main/java/com/linkedin/tony/horovod/SlotInfo.java @@ -0,0 +1,98 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.horovod; + +/** + * Introduce SlotInfo class to wrap consensus info, which is needed by Horovod workers. + * SlotInfo is provided by Horovod Driver, which collect built-in python script process's output. + * SlotInfo including info is like as follows, more details can be found on horovod-on-tony proposal + * + * hostname localhost:98 + * rank 0 + * localRank 0 + * crossRank 0 + * size 4 + * localSize 2 + * crossSize 2 + * + */ +public class SlotInfo { + private String hostname; + private int rank; + private int localRank; + private int crossRank; + private int size; + private int localSize; + private int crossSize; + + public SlotInfo() { + } + + public SlotInfo(String hostname, int rank, int localRank, int crossRank, int size, int localSize, int crossSize) { + this.hostname = hostname; + this.rank = rank; + this.localRank = localRank; + this.crossRank = crossRank; + this.size = size; + this.localSize = localSize; + this.crossSize = crossSize; + } + + public String getHostname() { + return hostname; + } + + public void setHostname(String hostname) { + this.hostname = hostname; + } + + public int getRank() { + return rank; + } + + public void setRank(int rank) { + this.rank = rank; + } + + public int getLocalRank() { + return localRank; + } + + public void setLocalRank(int localRank) { + this.localRank = localRank; + } + + public int getCrossRank() { + return crossRank; + } + + public void setCrossRank(int crossRank) { + this.crossRank = crossRank; + } + + public int getSize() { + return size; + } + + public void setSize(int size) { + this.size = size; + } + + public int getLocalSize() { + return localSize; + } + + public void setLocalSize(int localSize) { + this.localSize = localSize; + } + + public int getCrossSize() { + return crossSize; + } + + public void setCrossSize(int crossSize) { + this.crossSize = crossSize; + } +} diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/RegisterCallbackInfoRequest.java b/tony-core/src/main/java/com/linkedin/tony/rpc/RegisterCallbackInfoRequest.java index bbf962b1..70bf9d93 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/RegisterCallbackInfoRequest.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/RegisterCallbackInfoRequest.java @@ -20,4 +20,4 @@ public interface RegisterCallbackInfoRequest { void setTaskId(String taskId); String getCallbackInfo(); void setCallbackInfo(String callbackInfo); -} \ No newline at end of file +} diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/TensorFlowCluster.java b/tony-core/src/main/java/com/linkedin/tony/rpc/TensorFlowCluster.java index f64bb4da..9805b32d 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/TensorFlowCluster.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/TensorFlowCluster.java @@ -34,5 +34,4 @@ RegisterTensorBoardUrlResponse registerTensorBoardUrl(RegisterTensorBoardUrlRequ HeartbeatResponse taskExecutorHeartbeat(HeartbeatRequest request) throws YarnException, IOException; Empty registerCallbackInfo(RegisterCallbackInfoRequest request) throws YarnException, IOException; - } diff --git a/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java b/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java new file mode 100644 index 00000000..a0159be6 --- /dev/null +++ b/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java @@ -0,0 +1,306 @@ +/* + * Copyright 2021 LinkedIn Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package com.linkedin.tony.runtime; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.commons.lang.StringUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.annotations.VisibleForTesting; +import com.google.gson.Gson; +import com.linkedin.tony.Constants; +import com.linkedin.tony.TaskExecutor; +import com.linkedin.tony.TonyConfigurationKeys; +import com.linkedin.tony.horovod.DriverCallbackInfo; +import com.linkedin.tony.horovod.HorovodClusterSpec; +import com.linkedin.tony.horovod.HorovodDriver; +import com.linkedin.tony.horovod.SlotInfo; +import com.linkedin.tony.tensorflow.TonySession; +import com.linkedin.tony.util.Utils; + +import static com.linkedin.tony.TonyConfigurationKeys.DEFAULT_IN_TEST_HOROVOD_MODE; +import static com.linkedin.tony.TonyConfigurationKeys.DEFAULT_TEST_HOROVOD_FAIL; +import static com.linkedin.tony.TonyConfigurationKeys.DistributedMode.GANG; +import static com.linkedin.tony.TonyConfigurationKeys.IN_TEST_HOROVOD_MODE; +import static com.linkedin.tony.TonyConfigurationKeys.TEST_HOROVOD_FAIL_ENABLE_KEY; + +public class HorovodRuntime extends MLGenericRuntime { + private static final String DRIVER = "driver"; + private static final String WORKER = "worker"; + + private volatile boolean isDriverReady = false; + + private List workerSlotMetaInfo; + private String rendezvServerPort; + private String rendezvServerHost; + + private boolean isTestMode = false; + + @Override + public String constructClusterSpec(String taskId) throws IOException { + assert session != null; + + TonySession.TonyTask tonyTask = session.getTask(taskId); + String taskHost = tonyTask.getHost(); + + List sameHostIndexCollection = new ArrayList<>(); + String workerList = buildWorkerList(session, taskHost, sameHostIndexCollection); + + log.info("Horovod Worker host list: " + workerList); + Collections.sort(sameHostIndexCollection); + log.info("Same host name task index collection: " + sameHostIndexCollection); + + if (isDriverRole(taskId)) { + log.info("starting Horovod Driver, worker list: " + workerList); + return workerList; + } + + if (!isDriverReady) { + log.error("Horovod driver is not ready, it shouldn't return cluster spec to worker."); + return null; + } + + log.info("Starting Horovod worker, task id: " + taskId); + // when task role is worker, it will return horovod cluster spec. + HorovodClusterSpec spec = new HorovodClusterSpec( + workerSlotMetaInfo, + rendezvServerPort, + rendezvServerHost, + sameHostIndexCollection + ); + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.writeValueAsString(spec); + } + + private boolean isDriverRole(String taskId) { + assert session != null; + + if (DRIVER.equals(session.getTask(taskId).getJobName())) { + return true; + } + + return false; + } + + @VisibleForTesting + public String buildWorkerList(TonySession session, String taskHost, List sameHostIndexCollection) { + assert session != null; + + Map hostNumProcMap = new HashMap<>(); + session.getTonyTasks().values().stream() + .flatMap(tasks -> Arrays.stream(tasks)) + .filter(task -> task != null && !DRIVER.equals(task.getJobName())) + .forEach(task -> { + int numProc = hostNumProcMap.getOrDefault(task.getHost(), 0); + hostNumProcMap.put(task.getHost(), ++numProc); + + if (task.getHost().equals(taskHost)) { + sameHostIndexCollection.add(Integer.parseInt(task.getTaskIndex())); + } + }); + + String workerList = StringUtils.join( + hostNumProcMap.entrySet() + .stream() + .map(entry -> entry.getKey() + ":" + entry.getValue()) + .collect(Collectors.toList()), + "," + ); + return workerList; + } + + + @Override + public boolean receiveTaskCallbackInfo(String taskId, String callbackInfo) { + assert session != null; + log.info("Receiving horovod driver call back info..."); + + if (!isDriverRole(taskId)) { + log.error("Accept call back info from not driver task executor."); + return false; + } + + DriverCallbackInfo driverCallbackInfo = new Gson().fromJson(callbackInfo, DriverCallbackInfo.class); + this.workerSlotMetaInfo = driverCallbackInfo.getSlotInfos(); + this.rendezvServerPort = driverCallbackInfo.getPort(); + this.rendezvServerHost = driverCallbackInfo.getHost(); + + this.isDriverReady = true; + + return true; + } + + @Override + public boolean canStartTask(TonyConfigurationKeys.DistributedMode distributedMode, String taskId) { + assert session != null; + + if (GANG != distributedMode) { + setAppFailed("Horovod don't support " + distributedMode + " distributed mode."); + return false; + } + + int numExpectedTasks = session.getNumExpectedTasks(); + + if (session.getNumRegisteredTasks() != numExpectedTasks) { + printTasksPeriodically(); + return false; + } + + if (isDriverRole(taskId)) { + return true; + } + + // check driver is ready? + if (!isDriverReady) { + log.info("Horovod driver is not ready."); + return false; + } + + return true; + } + + @Override + public boolean validateAndUpdateConfig(Configuration tonyConf) { + // inject driver conf and make it untracked. + tonyConf.set("tony.driver.instances", "1"); + tonyConf.set("tony.driver.vcores", "1"); + tonyConf.set("tony.application.untracked.jobtypes", "driver"); + return true; + } + + private void setHorovodRunEnv(TaskExecutor executor, HorovodClusterSpec horovodClusterSpec, + int taskIndex, String currentHostName) { + if (isTestMode) { + currentHostName = "localhost"; + } + String rendezvPort = horovodClusterSpec.getPort(); + String rendezvHost = horovodClusterSpec.getAmHost(); + log.info("Horovod rendezvous server host: " + rendezvHost + ", port: " + rendezvPort); + + executor.getShellEnv().put("HOROVOD_CONTROLLER", "gloo"); + executor.getShellEnv().put("HOROVOD_CPU_OPERATIONS", "gloo"); + executor.getShellEnv().put("HOROVOD_GLOO_TIMEOUT_SECONDS", "2000"); + executor.getShellEnv().put("HOROVOD_GLOO_RENDEZVOUS_PORT", String.valueOf(rendezvPort)); + executor.getShellEnv().put("HOROVOD_GLOO_RENDEZVOUS_ADDR", rendezvHost); + + List localRankSortList = new ArrayList<>(); + for (SlotInfo slotInfo : horovodClusterSpec.getSlotInfos()) { + String hostName = slotInfo.getHostname(); + if (!hostName.equals(currentHostName)) { + continue; + } + + localRankSortList.add(slotInfo); + Collections.sort(localRankSortList, Comparator.comparingInt(SlotInfo::getLocalRank)); + } + + int seqIndex = horovodClusterSpec.getSameHostTaskIndexList().indexOf(taskIndex); + SlotInfo assignSlotInfo = localRankSortList.get(seqIndex); + + log.info("TaskIndex: " + taskIndex + ", host: " + currentHostName + ", horovod local rank: " + + assignSlotInfo.getLocalRank()); + + log.info("Setting Horovod runtime env..."); + executor.getShellEnv().put("HOROVOD_CROSS_RANK", String.valueOf(assignSlotInfo.getCrossRank())); + executor.getShellEnv().put("HOROVOD_CROSS_SIZE", String.valueOf(assignSlotInfo.getCrossSize())); + executor.getShellEnv().put("HOROVOD_LOCAL_RANK", String.valueOf(assignSlotInfo.getLocalRank())); + executor.getShellEnv().put("HOROVOD_LOCAL_SIZE", String.valueOf(assignSlotInfo.getLocalSize())); + executor.getShellEnv().put("HOROVOD_SIZE", String.valueOf(assignSlotInfo.getSize())); + executor.getShellEnv().put("HOROVOD_RANK", String.valueOf(assignSlotInfo.getRank())); + + executor.getShellEnv().put("HOROVOD_HOSTNAME", assignSlotInfo.getHostname()); + } + + private void setAppFailed(String errorMsg) { + session.setFinalStatus(FinalApplicationStatus.FAILED, errorMsg); + session.setTrainingFinished(); + } + + // ===================For task executor======================= + + protected void buildTaskEnv(TaskExecutor executor) throws Exception { + assert session == null; + + log.info("Setting TonY task executor basic env..."); + Map executorShellEnv = executor.getShellEnv(); + executorShellEnv.put(Constants.JOB_NAME, String.valueOf(executor.getJobName())); + executorShellEnv.put(Constants.TASK_INDEX, String.valueOf(executor.getTaskIndex())); + executorShellEnv.put(Constants.TASK_NUM, String.valueOf(executor.getNumTasks())); + executorShellEnv.put(Constants.DISTRUBUTED_MODE_NAME, executor.getDistributedMode().name()); + + if (DRIVER.equals(executor.getJobName())) { + log.info("Task is Horovod driver, no need to set extra env."); + return; + } + log.info("Setting up Horovod worker..."); + + // cluster spec like: h1:1,h2:2,h3:1 + HorovodClusterSpec horovodClusterSpec = + Utils.parseClusterSpecForHorovod(executor.getClusterSpec()); + setHorovodRunEnv(executor, horovodClusterSpec, executor.getTaskIndex(), + Utils.getCurrentHostName()); + } + + @Override + public int run(TaskExecutor executor) throws Exception { + assert session == null; + + setInTestMode(executor); + buildTaskEnv(executor); + if (DRIVER.equals(executor.getJobName())) { + HorovodDriver driver = HorovodDriver.create(executor.getClusterSpec()); + String callBackInfo = driver.getCallbackInfo(); + log.info("Horovod driver call back to AM: \n" + callBackInfo); + String taskId = executor.getJobName() + ":" + executor.getTaskIndex(); + executor.callbackInfoToAM(taskId, callBackInfo); + + log.info("Horovod driver has started. It will end when all workers finished."); + int exitCode = driver.waitFor(); + return exitCode; + } + + return this.executorPythonShell(executor); + } + + private void setInTestMode(TaskExecutor executor) { + assert session == null; + + Configuration tonyConf = executor.getTonyConf(); + boolean isInTestMode = tonyConf.getBoolean(IN_TEST_HOROVOD_MODE, DEFAULT_IN_TEST_HOROVOD_MODE); + if (isInTestMode) { + HorovodDriver.setInTest(); + isTestMode = true; + } + + boolean setFailedInTest = tonyConf.getBoolean(TEST_HOROVOD_FAIL_ENABLE_KEY, DEFAULT_TEST_HOROVOD_FAIL); + if (setFailedInTest) { + HorovodDriver.setInTest(); + HorovodDriver.setTaskFailInTestMode(); + isTestMode = true; + } + } +} \ No newline at end of file diff --git a/tony-core/src/main/java/com/linkedin/tony/runtime/PyTorchRuntime.java b/tony-core/src/main/java/com/linkedin/tony/runtime/PyTorchRuntime.java index f9f7b7fc..f007bbb0 100644 --- a/tony-core/src/main/java/com/linkedin/tony/runtime/PyTorchRuntime.java +++ b/tony-core/src/main/java/com/linkedin/tony/runtime/PyTorchRuntime.java @@ -31,7 +31,6 @@ public void buildTaskEnv(TaskExecutor executor) throws Exception { throw new RuntimeException("Errors on getting init method."); } log.info("Init method is: " + initMethod); - Map executorShellEnv = executor.getShellEnv(); executorShellEnv.put(Constants.INIT_METHOD, initMethod); executorShellEnv.put(Constants.RANK, String.valueOf(executor.getTaskIndex())); diff --git a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java index 9b605155..56769c47 100644 --- a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java +++ b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java @@ -4,6 +4,7 @@ */ package com.linkedin.tony.tensorflow; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.linkedin.tony.Constants; import com.linkedin.tony.TonyClient; @@ -284,6 +285,10 @@ public void onTaskCompleted(String jobName, String jobIndex, int exitCode, Strin } } + public void setTrainingFinished() { + this.trainingFinished = true; + } + /** * Update the status of a session and set exit code if a session is completed. */ @@ -418,6 +423,13 @@ public Builder setTonyConf(Configuration tonyConf) { } } + @VisibleForTesting + public TonyTask buildTonyTask(String jobName, String taskIndex, String host) { + TonyTask task = new TonyTask(jobName, taskIndex, -1, -1); + task.setHostPort(host + ":9999"); + return task; + } + /** * A TonyTask represents a task job executed in the workers. */ @@ -573,6 +585,10 @@ public TonyTask getTask(String taskId) { } } + public Configuration getTonyConf() { + return tonyConf; + } + public void addRegisteredTask(String taskId) { registeredTasks.add(taskId); } diff --git a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java index b7c5efbb..f0c50d71 100644 --- a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java +++ b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java @@ -69,6 +69,7 @@ import com.linkedin.tony.LocalizableResource; import com.linkedin.tony.TFConfig; import com.linkedin.tony.TonyConfigurationKeys; +import com.linkedin.tony.horovod.HorovodClusterSpec; import com.linkedin.tony.rpc.TaskInfo; import com.linkedin.tony.tensorflow.JobContainerRequest; @@ -822,5 +823,12 @@ public static Map linksToBeDisplayedOnPage(String jobId) { return titleAndLinks; } + public static HorovodClusterSpec parseClusterSpecForHorovod(String clusterSpec) throws IOException { + ObjectMapper objectMapper = new ObjectMapper(); + HorovodClusterSpec spec = + objectMapper.readValue(clusterSpec, new TypeReference() { }); + return spec; + } + private Utils() { } } diff --git a/tony-core/src/main/resources/horovod_driver.py b/tony-core/src/main/resources/horovod_driver.py new file mode 100644 index 00000000..6aa33dd6 --- /dev/null +++ b/tony-core/src/main/resources/horovod_driver.py @@ -0,0 +1,188 @@ +# +# Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the +# BSD-2 Clause license. See LICENSE in the project root for license information. +# +import os +import logging +import time + +from optparse import OptionParser +import sys +import signal +import json + +try: + import horovod.tensorflow as hvd + from horovod.runner import gloo_run + from horovod.runner.http.http_server import RendezvousServer + from horovod.runner.common.util.hosts import get_host_assignments, parse_hosts + from horovod.runner.elastic import discovery + from horovod.runner.elastic.rendezvous import create_rendezvous_handler + from horovod.runner.elastic.driver import ElasticDriver +except Exception as e: + logging.error("Horovod is not installed. See README for instructions to install it") + pass + +PORT_FILE_NAME_SUFFIX = "____HOROVOD_RENDEZVOUS_SERVER____" + +def elastic_driver_fn(): + pass + + +def static_driver_fn(): + global_rendezv = RendezvousServer(verbose=1) + global_rendezv_port = global_rendezv.start() + print("Rendezvous server started, port: " + str(global_rendezv_port)) + + # worker_list = "localhost:1" + hosts = parse_hosts(worker_list) + host_alloc_plan = get_host_assignments(hosts, 1) + + global_rendezv.init(host_alloc_plan) + return (global_rendezv_port, host_alloc_plan) + +def _build_fake_host_plan(): + return [ + { + "hostname": "localhost", + "rank": "0", + "localRank": "0", + "crossRank": "0", + "size": "2", + "localSize": "2", + "crossSize": "1" + }, + { + "hostname": "localhost", + "rank": "1", + "localRank": "1", + "crossRank": "1", + "size": "2", + "localSize": "2", + "crossSize": "1" + } + ] + +def _get_host_plan_json(host_alloc_plan): + if host_alloc_plan == None: + return json.dumps(_build_fake_host_plan()) + + hosts = [] + for plan in host_alloc_plan: + hosts.append({ + "hostname": plan.hostname, + "rank": plan.rank, + "localRank": plan.local_rank, + "crossRank": plan.cross_rank, + "size": plan.size, + "localSize": plan.local_size, + "crossSize": plan.cross_size + }) + print("Host alloc plan: \n" + json.dumps(hosts)) + return json.dumps(hosts) + + +def set_option(): + parser = OptionParser() + parser.add_option( + "-a", "--num_proc", dest="num_process", type="str", help="number process of training", default="1") + parser.add_option( + "-w", "--worker_list", dest="worker_list", type="str", help="worker list" + ) + parser.add_option( + "-e", action="store_true", help="enable elastic training.", dest="enable_elastic", default=False + ) + parser.add_option( + "-t", action="store_true", help="is in test mode", dest="is_in_test_mode", default=False + ) + parser.add_option( + "-p", "--fake_port", dest="fake_port", type="str", help="fake server port for TonY unit test" + ) + parser.add_option( + "-f", action="store_true", help="fast fail in test mode for TonY unit test", dest="is_fast_fail", default=False + ) + (options, args) = parser.parse_args(sys.argv) + + global worker_list + worker_list = options.worker_list + + global enable_elastic + enable_elastic = options.enable_elastic + print("Enable elastic: " + str(enable_elastic)) + + global is_in_test_mode + is_in_test_mode = options.is_in_test_mode + global fake_server_port + global is_fast_fail + is_fast_fail = False + if is_in_test_mode: + fake_server_port = options.fake_port + is_fast_fail = options.is_fast_fail + + +def __port_file_path(port): + path_dir = os.path.dirname(os.path.abspath(__file__)) + port_file_path = os.path.join(path_dir, str(port) + PORT_FILE_NAME_SUFFIX) + return port_file_path + + +def create_port_file(port, host_alloc_plan): + port_file = __port_file_path(port) + logging.info("Creating port file %s", port_file) + with open(__port_file_path(port), 'w') as fo: + fo.write(_get_host_plan_json(host_alloc_plan)) + logging.info("Port file for %s created", port_file) + pass + + +def delete_port_file(port): + port_file = __port_file_path(port) + logging.info("Deleting port file %s", port_file) + try: + os.remove(__port_file_path(port)) + logging.info("Port file %s deleted", port_file) + except OSError: + pass + + +def handle_exit(*args): + try: + logging.info("Closing rendezvous server...") + # todo: Close rendezvous server. + logging.info("Closed rendezvous server") + + delete_port_file(port) + except: + logging.exception("Failed to close rendezvous server") + + sys.exit(0) + + +if __name__ == '__main__': + set_option() + + # Just for Unit Test + if is_fast_fail: + sys.exit(1) + + try: + global port + if enable_elastic: + elastic_driver_fn() + else: + if is_in_test_mode: + print("In unit test mode. fake port: " + fake_server_port) + (port, host_alloc_plan) = (fake_server_port, None) + else: + (port, host_alloc_plan) = static_driver_fn() + create_port_file(port, host_alloc_plan) + except: + logging.exception("Errors on starting horovod rendezvous server.") + handle_exit() + + signal.signal(signal.SIGTERM, handle_exit) + signal.signal(signal.SIGINT, handle_exit) + signal.signal(signal.SIGILL, handle_exit) + while True: + time.sleep(10) + diff --git a/tony-core/src/main/resources/tony-default.xml b/tony-core/src/main/resources/tony-default.xml index f6db69bd..bedc1689 100644 --- a/tony-core/src/main/resources/tony-default.xml +++ b/tony-core/src/main/resources/tony-default.xml @@ -392,4 +392,14 @@ tony.secret.key changeme + + + tony.application.test.horovod-driver-fail-enable + false + + + + tony.application.test.horovod-test-mode-enable + false + diff --git a/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java b/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java index 382f0602..7594ef82 100644 --- a/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java +++ b/tony-core/src/test/java/com/linkedin/tony/TestTonyE2E.java @@ -480,6 +480,42 @@ public void testTonyPSCrashShouldFailAndStopAM() throws IOException, ParseExcept Assert.assertNotNull(handler.getAppId()); } + @Test + public void testTonyHorovodDriverCrashShouldFailAndStopAM() throws ParseException, IOException { + client.init(new String[]{ + "--src_dir", "tony-core/src/test/resources/scripts", + "--hdfs_classpath", libPath, + "--container_env", Constants.SKIP_HADOOP_PATH + "=true", + "--python_venv", "tony-core/src/test/resources/test.zip", + "--conf", "tony.worker.instances=1", + "--conf", "tony.worker.command=python sleep_30.py", + "--conf", "tony.application.test.horovod-driver-fail-enable=true", + "--conf", "tony.application.framework=horovod" + }); + client.addListener(handler); + int exitCode = client.start(); + Assert.assertEquals(exitCode, -1); + client.removeListener(handler); + } + + @Test + public void testTonyHorovodShouldPass() throws ParseException, IOException { + client.init(new String[]{ + "--src_dir", "tony-core/src/test/resources/scripts", + "--hdfs_classpath", libPath, + "--container_env", Constants.SKIP_HADOOP_PATH + "=true", + "--python_venv", "tony-core/src/test/resources/test.zip", + "--executes", "python check_horovod_env.py", + "--conf", "tony.worker.instances=2", + "--conf", "tony.application.test.horovod-test-mode-enable=true", + "--conf", "tony.application.framework=horovod" + }); + client.addListener(handler); + int exitCode = client.start(); + Assert.assertEquals(exitCode, 0); + client.removeListener(handler); + } + /** * Since we are switching from passing arguments to ApplicationMaster & TaskExecutor * to passing tony configuration file. It is critical to make sure all fields in diff --git a/tony-core/src/test/java/com/linkedin/tony/horovod/TestHorovodDriver.java b/tony-core/src/test/java/com/linkedin/tony/horovod/TestHorovodDriver.java new file mode 100644 index 00000000..8460cfc3 --- /dev/null +++ b/tony-core/src/test/java/com/linkedin/tony/horovod/TestHorovodDriver.java @@ -0,0 +1,142 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.horovod; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.codehaus.jackson.map.ObjectMapper; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class TestHorovodDriver { + + @BeforeClass + public void beforeTest() { + HorovodDriver.setInTest(); + } + + /** + * It should start horovod driver successfully. + * @throws Exception + */ + @Test + public void testHorovodDriver() throws Exception { + // Fake worker list is useless, when in test mode, + // python code will return the "localhost:2" host plan + String fakeWorkerList = "localhost:2"; + + HorovodDriver driver = HorovodDriver.create(fakeWorkerList); + Assert.assertNotEquals(HorovodDriver.getFakeServerPort(), driver.getPort()); + + List slotInfoList = driver.getSlotInfoList(); + Assert.assertNotNull(slotInfoList); + Assert.assertEquals(2, slotInfoList.size()); + Assert.assertEquals(0, slotInfoList.get(0).getCrossRank()); + Assert.assertEquals(1, slotInfoList.get(0).getCrossSize()); + Assert.assertEquals(0, slotInfoList.get(0).getLocalRank()); + Assert.assertEquals(2, slotInfoList.get(0).getLocalSize()); + Assert.assertEquals(0, slotInfoList.get(0).getRank()); + Assert.assertEquals(2, slotInfoList.get(0).getSize()); + + driver.close(); + } + + /** + * When python process exit, it will throw exception. + */ + @Test + public void testHorovodDriverWhenFailed() { + try { + HorovodDriver.setTaskFailInTestMode(); + String fakeWorkerList = "localhost:2"; + HorovodDriver.create(fakeWorkerList); + Assert.fail("Should throw exception on starting driver."); + } catch (Exception e) { + // ignore. + } finally { + HorovodDriver.removeTaskFailInTestMode(); + } + } + + @Test + public void testCreateDriverScripPath() { + Path driverPath = HorovodDriver.createDriverScripPath(); + File driverFile = driverPath.toFile(); + Assert.assertNotNull(driverFile); + Assert.assertTrue(driverFile.isFile()); + + cleanupTmpFile(driverFile.getParentFile()); + } + + /** + * Test get rendezvous server info by reading specific file. + */ + @Test + public void testGetServerInfo() throws IOException { + Path driverPath = HorovodDriver.createDriverScripPath(); + Path parentPath = driverPath.getParent(); + + Pair> infoPair = HorovodDriver.getServerInfo(parentPath); + Assert.assertNotNull(infoPair); + Assert.assertEquals(-1, infoPair.getLeft().intValue()); + Assert.assertNull(infoPair.getRight()); + + // inject server info into files. + int port = 10000; + List slotInfos = buildFakeSlotInfo(); + ObjectMapper mapper = new ObjectMapper(); + String slotJson = mapper.writeValueAsString(slotInfos); + // create tmp port file. + createPortTmpFile(parentPath, port, slotJson); + + infoPair = HorovodDriver.getServerInfo(parentPath); + Assert.assertNotNull(infoPair); + Assert.assertEquals(port, infoPair.getLeft().intValue()); + List metaSlotInfos = infoPair.getRight(); + Assert.assertNotNull(metaSlotInfos); + + String metaInfoJson = mapper.writeValueAsString(metaSlotInfos); + Assert.assertNotNull(metaInfoJson); + Assert.assertEquals(slotJson, metaInfoJson); + + // clear up all folders. + cleanupTmpFile(parentPath.toFile()); + } + + private List buildFakeSlotInfo() { + List slotInfos = new ArrayList<>(); + slotInfos.add( + new SlotInfo("localhost", 0, 0, 0, 2, 2, 1) + ); + slotInfos.add( + new SlotInfo("localhost", 1, 1, 1, 2, 2, 1) + ); + return slotInfos; + } + + private void createPortTmpFile(Path parentPath, int port, String slotJson) throws IOException { + String portFileName = String.format("%d%s", port, HorovodDriver.PORT_FILE_NAME_SUFFIX); + File tmpFile = Paths.get(parentPath.toAbsolutePath().toString(), portFileName).toFile(); + FileUtils.writeStringToFile(tmpFile, slotJson); + } + + private void cleanupTmpFile(File file) { + if (file.isDirectory()) { + File[] childfiles = file.listFiles(); + Arrays.stream(childfiles).forEach(this::cleanupTmpFile); + } + boolean ok = file.delete(); + System.out.println(ok); + } +} diff --git a/tony-core/src/test/java/com/linkedin/tony/runtime/TestHorovodRuntime.java b/tony-core/src/test/java/com/linkedin/tony/runtime/TestHorovodRuntime.java new file mode 100644 index 00000000..c61b753a --- /dev/null +++ b/tony-core/src/test/java/com/linkedin/tony/runtime/TestHorovodRuntime.java @@ -0,0 +1,55 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.tony.runtime; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.testng.Assert; +import org.testng.annotations.BeforeTest; +import org.testng.annotations.Test; + +import com.linkedin.tony.tensorflow.TonySession; + +public class TestHorovodRuntime { + private TonySession session = new TonySession(); + private HorovodRuntime runtime; + + @BeforeTest + public void before() { + Map taskMaps = session.getTonyTasks(); + + List taskList = Arrays.asList( + session.buildTonyTask("worker", "0", "localhost1"), + session.buildTonyTask("worker", "1", "localhost1"), + session.buildTonyTask("worker", "2", "localhost2"), + session.buildTonyTask("worker", "3", "localhost3"), + session.buildTonyTask("driver", "0", "localhost4") + ); + + taskMaps.put("worker", taskList.toArray(new TonySession.TonyTask[0])); + + runtime = new HorovodRuntime(); + } + + @Test + public void testBuildWorkerList() { + List sameHostIndexCollection = new ArrayList<>(); + String currenthost = "localhost1"; + String workerList = runtime.buildWorkerList(session, currenthost, sameHostIndexCollection); + Assert.assertEquals("localhost3:1,localhost2:1,localhost1:2", workerList); + Assert.assertEquals(2, sameHostIndexCollection.size()); + Assert.assertEquals("[0, 1]", sameHostIndexCollection.toString()); + + sameHostIndexCollection = new ArrayList<>(); + currenthost = "localhost2"; + workerList = runtime.buildWorkerList(session, currenthost, sameHostIndexCollection); + Assert.assertEquals("localhost3:1,localhost2:1,localhost1:2", workerList); + Assert.assertEquals(1, sameHostIndexCollection.size()); + Assert.assertEquals("[2]", sameHostIndexCollection.toString()); + } +} diff --git a/tony-core/src/test/resources/scripts/check_horovod_env.py b/tony-core/src/test/resources/scripts/check_horovod_env.py new file mode 100644 index 00000000..d76d3e9c --- /dev/null +++ b/tony-core/src/test/resources/scripts/check_horovod_env.py @@ -0,0 +1,40 @@ +# +# Copyright 2021 LinkedIn Corporation. All rights reserved. Licensed under the BSD-2 Clause license. +# See LICENSE in the project root for license information. +# +import os + +controller = os.environ['HOROVOD_CONTROLLER'] +operators = os.environ['HOROVOD_CPU_OPERATIONS'] +timeout = os.environ['HOROVOD_GLOO_TIMEOUT_SECONDS'] +rendez_port = os.environ['HOROVOD_GLOO_RENDEZVOUS_PORT'] +rendez_addr = os.environ['HOROVOD_GLOO_RENDEZVOUS_ADDR'] +cross_rank = os.environ['HOROVOD_CROSS_RANK'] +cross_size = os.environ['HOROVOD_CROSS_SIZE'] +local_rank = os.environ['HOROVOD_LOCAL_RANK'] +local_size = os.environ['HOROVOD_LOCAL_SIZE'] +size = os.environ['HOROVOD_SIZE'] +rank = os.environ['HOROVOD_RANK'] +hostname = os.environ['HOROVOD_HOSTNAME'] + +job_name = os.environ['JOB_NAME'] + +print('JOB_NAME is ' + job_name) + +print('Horovod envs are as follows:') +print('controller: ' + controller) +print('operators: ' + operators) +print('timeout: ' + timeout) +print('rendez_port: ' + rendez_port) +print('rendez_addr: ' + rendez_addr) +print('cross_rank: ' + cross_rank) +print('cross_size: ' + cross_size) +print('local_rank: ' + local_rank) +print('local_size: ' + local_size) +print('size: ' + size) +print('rank: ' + rank) +print('hostname: ' + hostname) + + +if not (controller and job_name and operators and timeout and rendez_addr and rendez_port and cross_rank and cross_size and local_rank and local_size and size and rank): + raise ValueError \ No newline at end of file diff --git a/tony-examples/horovod-on-tony/README.md b/tony-examples/horovod-on-tony/README.md new file mode 100644 index 00000000..1b5c46cd --- /dev/null +++ b/tony-examples/horovod-on-tony/README.md @@ -0,0 +1,70 @@ +### Running Examples +This example shows how to run a simple Horovod program on TonY. +Requirements: +1. Build a Docker runtime container(required Hadoop configurations) with TF2.x installed and Horovod 0.21.3+ +2. Install Hadoop 3.1.1+ + +If you don't have security enabled, you'll also need to provide a custom config file with security turned off. + +### Build a Docker runtime container +1. Prepare Dockerfile +``` +FROM ${YOUR_BASIC_HADOOP_CONTAINER_IMAGE} + +RUN pip3 install tensorflow==2.4.1 \ + && HOROVOD_WITH_GLOO=1 HOROVOD_WITH_TENSORFLOW=1 pip3 install horovod[tensorflow] +``` +2. Build image +``` +docker build -t docker.io/bigdata/horovod-test-1:v1 . +``` +3. Push to docker registry +``` +docker push docker.io/bigdata/horovod-test-1:v1 +``` + +For the instructions below, we assume this docker image has been pushed to docker registry which can be access by Hadoop nodemanager, and this image is named __docker.io/bigdata/horovod-test-1:v1__ + +### Install Hadoop 3.1.1+ +TonY only requires YARN, not HDFS. Please see the open-source documentation on how to set YARN up. + +### Config TonY job for Horovod +If your Hadoop cluster is not running with security enabled (e.g.: for local testing), you need to disable the security check. Here is a sample of the config: +``` + + + tony.worker.instances + 4 + + + tony.worker.memory + 3g + + + tony.docker.enabled + true + + + tony.docker.containers.image + docker.io/bigdata/horovod-test-1:v1 + + + tony.application.framework + horovod + + +``` + +For the instructions below, we assume this file is named __tony-test.xml__ + +### Running an example +``` +gradlew :tony-cli:build + +java -cp `hadoop classpath`:/path/to/TonY/tony-cli/build/libs/tony-cli-x.x.x-all.jar com.linkedin.tony.cli.ClusterSubmitter \ +--src_dir=/path/to/TonY/tony-examples/horovod-on-tony \ +--executes=tensorflow2_mnist.py \ +--conf_file=/path/to/tony-test.xml \ +--python_binary_path=python3 +``` + diff --git a/tony-examples/horovod-on-tony/tensorflow2_keras_mnist.py b/tony-examples/horovod-on-tony/tensorflow2_keras_mnist.py new file mode 100644 index 00000000..cb052aea --- /dev/null +++ b/tony-examples/horovod-on-tony/tensorflow2_keras_mnist.py @@ -0,0 +1,93 @@ + +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import tensorflow as tf +import horovod.tensorflow.keras as hvd + +# Horovod: initialize Horovod. +hvd.init() + +# Horovod: pin GPU to be used to process local rank (one GPU per process) +gpus = tf.config.experimental.list_physical_devices('GPU') +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') + +(mnist_images, mnist_labels), _ = \ + tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank()) + +dataset = tf.data.Dataset.from_tensor_slices( + (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), + tf.cast(mnist_labels, tf.int64)) +) +dataset = dataset.repeat().shuffle(10000).batch(128) + +mnist_model = tf.keras.Sequential([ + tf.keras.layers.Conv2D(32, [3, 3], activation='relu'), + tf.keras.layers.Conv2D(64, [3, 3], activation='relu'), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Dropout(0.25), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(10, activation='softmax') +]) + +# Horovod: adjust learning rate based on number of GPUs. +scaled_lr = 0.001 * hvd.size() +opt = tf.optimizers.Adam(scaled_lr) + +# Horovod: add Horovod DistributedOptimizer. +opt = hvd.DistributedOptimizer( + opt, backward_passes_per_step=1, average_aggregated_gradients=True) + +# Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow +# uses hvd.DistributedOptimizer() to compute gradients. +mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(), + optimizer=opt, + metrics=['accuracy'], + experimental_run_tf_function=False) + +callbacks = [ + # Horovod: broadcast initial variable states from rank 0 to all other processes. + # This is necessary to ensure consistent initialization of all workers when + # training is started with random weights or restored from a checkpoint. + hvd.callbacks.BroadcastGlobalVariablesCallback(0), + + # Horovod: average metrics among workers at the end of every epoch. + # + # Note: This callback must be in the list before the ReduceLROnPlateau, + # TensorBoard or other metrics-based callbacks. + hvd.callbacks.MetricAverageCallback(), + + # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final + # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during + # the first three epochs. See https://arxiv.org/abs/1706.02677 for details. + hvd.callbacks.LearningRateWarmupCallback(initial_lr=scaled_lr, warmup_epochs=3, verbose=1), +] + +# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them. +if hvd.rank() == 0: + callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5')) + +# Horovod: write logs on worker 0. +verbose = 1 if hvd.rank() == 0 else 0 + +# Train the model. +# Horovod: adjust number of steps based on number of GPUs. +mnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(), callbacks=callbacks, epochs=24, verbose=verbose) + diff --git a/tony-examples/horovod-on-tony/tensorflow2_mnist.py b/tony-examples/horovod-on-tony/tensorflow2_mnist.py new file mode 100644 index 00000000..f9e2f12d --- /dev/null +++ b/tony-examples/horovod-on-tony/tensorflow2_mnist.py @@ -0,0 +1,102 @@ +# Copyright 2019 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +"""A deep MNIST classifier with Tensorflow 2.x + +This example was adapted from +https://github.com/horovod/horovod/blob/master/examples/tensorflow2/tensorflow2_mnist.py + +""" + + +import tensorflow as tf +import horovod.tensorflow as hvd + +# Horovod: initialize Horovod. +hvd.init() + +# Horovod: pin GPU to be used to process local rank (one GPU per process) +gpus = tf.config.experimental.list_physical_devices('GPU') +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) +if gpus: + tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU') + +(mnist_images, mnist_labels), _ = \ + tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank()) + +dataset = tf.data.Dataset.from_tensor_slices( + (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32), + tf.cast(mnist_labels, tf.int64)) +) +dataset = dataset.repeat().shuffle(10000).batch(128) + +mnist_model = tf.keras.Sequential([ + tf.keras.layers.Conv2D(32, [3, 3], activation='relu'), + tf.keras.layers.Conv2D(64, [3, 3], activation='relu'), + tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), + tf.keras.layers.Dropout(0.25), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(10, activation='softmax') +]) +loss = tf.losses.SparseCategoricalCrossentropy() + +# Horovod: adjust learning rate based on number of GPUs. +opt = tf.optimizers.Adam(0.001 * hvd.size()) + +checkpoint_dir = './checkpoints' +checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt) + + +@tf.function +def training_step(images, labels, first_batch): + with tf.GradientTape() as tape: + probs = mnist_model(images, training=True) + loss_value = loss(labels, probs) + + # Horovod: add Horovod Distributed GradientTape. + tape = hvd.DistributedGradientTape(tape) + + grads = tape.gradient(loss_value, mnist_model.trainable_variables) + opt.apply_gradients(zip(grads, mnist_model.trainable_variables)) + + # Horovod: broadcast initial variable states from rank 0 to all other processes. + # This is necessary to ensure consistent initialization of all workers when + # training is started with random weights or restored from a checkpoint. + # + # Note: broadcast should be done after the first gradient step to ensure optimizer + # initialization. + if first_batch: + hvd.broadcast_variables(mnist_model.variables, root_rank=0) + hvd.broadcast_variables(opt.variables(), root_rank=0) + + return loss_value + + +# Horovod: adjust number of steps based on number of GPUs. +for batch, (images, labels) in enumerate(dataset.take(10000 // hvd.size())): + loss_value = training_step(images, labels, batch == 0) + + if batch % 10 == 0 and hvd.local_rank() == 0: + print('Step #%d\tLoss: %.6f' % (batch, loss_value)) + +# Horovod: save checkpoints only on worker 0 to prevent other workers from +# corrupting it. +if hvd.rank() == 0: + checkpoint.save(checkpoint_dir) +