Add SlurmRay launcher and transform API for launchers #159

Status: Open · wants to merge 1 commit into base: main
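As orientation before the diff, here is a minimal usage sketch of what this change enables. It is not part of the PR: the account/partition/node values are placeholders, and it assumes the executor's `launcher` field accepts the new launcher instance.

```python
import nemo_run as run

# Hypothetical cluster settings; account/partition/nodes are placeholders.
executor = run.SlurmExecutor(
    account="my_account",
    partition="gpu",
    nodes=2,
)
executor.launcher = run.SlurmRay(port=6379)  # launcher introduced by this PR

# With a SlurmRay launcher attached, SlurmExecutor.supports_launcher_transform()
# returns True, so packaging rewrites the job command into a Ray bootstrap script:
# node 0 starts the Ray head and runs the command, the other nodes join and wait.
```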
4 changes: 2 additions & 2 deletions src/nemo_run/__init__.py
@@ -19,12 +19,11 @@
from nemo_run.core.execution.base import (
Executor,
ExecutorMacros,
FaultTolerance,
Torchrun,
import_executor,
)
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.docker import DockerExecutor
from nemo_run.core.execution.launcher import FaultTolerance, SlurmRay, Torchrun
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.execution.skypilot import SkypilotExecutor
from nemo_run.core.execution.slurm import SlurmExecutor
@@ -69,6 +68,7 @@
"SlurmExecutor",
"SSHTunnel",
"Torchrun",
"SlurmRay",
]

try:
57 changes: 5 additions & 52 deletions src/nemo_run/core/execution/base.py
@@ -18,67 +18,17 @@
import os
from dataclasses import asdict, dataclass, field
from string import Template
from typing import Optional, Protocol, Type, Union, runtime_checkable
from typing import Optional, Protocol, Union, runtime_checkable

import fiddle as fdl
from torchx.specs import Role
from typing_extensions import Self

from nemo_run.config import NEMORUN_HOME, ConfigurableMixin
from nemo_run.core.execution.launcher import LAUNCHER_MAP, Launcher
from nemo_run.core.packaging.base import Packager


@dataclass(kw_only=True)
class Launcher(ConfigurableMixin):
nsys_profile: bool = False
nsys_folder: str = "nsys_profile"
nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"])

def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]:
"""Make a command prefix for nsys profiling"""
if self.nsys_profile:
profile_out_path = os.path.join(profile_dir, self.nsys_folder)
args = [
"profile",
"-s",
"none",
"-t",
",".join(self.nsys_trace),
"-o",
f"{profile_out_path}/profile_%p",
"--force-overwrite",
"true",
"--capture-range=cudaProfilerApi",
"--capture-range-end=stop",
"--cuda-graph-trace=node",
]
return args


@dataclass(kw_only=True)
class Torchrun(Launcher):
rdzv_backend: str = "c10d"
rdzv_port: int = 29500


@dataclass(kw_only=True)
class FaultTolerance(Launcher):
cfg_path: str = ""
finished_flag_file: str = ""
job_results_file: str = ""
rdzv_backend: str = "c10d"
rdzv_port: int = 29500
workload_check_interval: Optional[float] = None
initial_rank_heartbeat_timeout: Optional[float] = None
rank_heartbeat_timeout: Optional[float] = None
rank_termination_signal: Optional[str] = None
log_level: Optional[str] = None
max_restarts: Optional[int] = None


LAUNCHER_MAP: dict[str, Type[Launcher]] = {"torchrun": Torchrun, "ft": FaultTolerance}


@dataclass(kw_only=True)
class ExecutorMacros(ConfigurableMixin):
"""
@@ -215,6 +165,9 @@ def get_launcher_prefix(self) -> Optional[list[str]]:
os.makedirs(os.path.join(self.job_dir, launcher.nsys_folder), exist_ok=True)
return launcher.get_nsys_prefix(profile_dir=self.job_dir)

def supports_launcher_transform(self) -> bool:
return False

def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
filenames = []
basepath = os.path.join(self.job_dir, "configs")
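For context, a sketch of how an executor subclass is expected to opt into the new hook; `MyExecutor` is hypothetical, and `SlurmExecutor` in the diff below does the same for `SlurmRay`.

```python
from dataclasses import dataclass

from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.launcher import SlurmRay


@dataclass(kw_only=True)
class MyExecutor(Executor):
    """Hypothetical executor that supports launcher command transforms."""

    def supports_launcher_transform(self) -> bool:
        # Opt in only when the configured launcher can rewrite the command.
        return isinstance(self.get_launcher(), SlurmRay)
```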
128 changes: 128 additions & 0 deletions src/nemo_run/core/execution/launcher.py
@@ -0,0 +1,128 @@
import os
from dataclasses import dataclass, field
from typing import Optional, Type

from nemo_run.config import ConfigurableMixin, Script


@dataclass(kw_only=True)
class Launcher(ConfigurableMixin):
nsys_profile: bool = False
nsys_folder: str = "nsys_profile"
nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"])

def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]:
"""Make a command prefix for nsys profiling"""
# CodeQL notice: explicit returns mixed with implicit (fall-through) returns; implicit returns always return None.
if self.nsys_profile:
profile_out_path = os.path.join(profile_dir, self.nsys_folder)
args = [
"profile",
"-s",
"none",
"-t",
",".join(self.nsys_trace),
"-o",
f"{profile_out_path}/profile_%p",
"--force-overwrite",
"true",
"--capture-range=cudaProfilerApi",
"--capture-range-end=stop",
"--cuda-graph-trace=node",
]
return args

def transform(self, cmd: list[str]) -> Optional[Script]: ...

# CodeQL notice on the `...` stub above: statement has no effect.


@dataclass(kw_only=True)
class Torchrun(Launcher):
rdzv_backend: str = "c10d"
rdzv_port: int = 29500


@dataclass(kw_only=True)
class FaultTolerance(Launcher):
cfg_path: str = ""
finished_flag_file: str = ""
job_results_file: str = ""
rdzv_backend: str = "c10d"
rdzv_port: int = 29500
workload_check_interval: Optional[float] = None
initial_rank_heartbeat_timeout: Optional[float] = None
rank_heartbeat_timeout: Optional[float] = None
rank_termination_signal: Optional[str] = None
log_level: Optional[str] = None
max_restarts: Optional[int] = None


@dataclass(kw_only=True)
class SlurmRay(Launcher):
"""
Transforms a provided cmd into a Ray launcher bash script for SlurmExecutor.
The Ray launcher script sets up a Ray cluster on Slurm nodes, with the head node starting Ray head
and executing the provided command. Worker nodes start Ray and wait.
"""

port: int = 6379

def transform(self, cmd: list[str]) -> Optional[Script]:
"""
Transforms the provided cmd into a Ray launcher bash script for SlurmExecutor.
"""
cmd_to_run = " ".join(cmd)
# Build the Ray launcher bash script. Braces in shell variables are escaped as {{ and }}
ray_script = f"""
# Check that a command was provided.
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <command>"
exit 1
fi

# Function to start the Ray head node.
start_head() {{
echo "Starting Ray head node on ${{HEAD_IP}}"
ray start --head --node-ip-address=${{HEAD_IP}} --port={self.port}
export RAY_ADDRESS="${{HEAD_IP}}:{self.port}"
}}

# Function to start a Ray worker node.
start_worker() {{
# Obtain the head node's hostname from the SLURM_NODELIST.
echo "Starting Ray worker node. Connecting to head ${{HEAD_IP}}"
ray start --address=${{HEAD_IP}}:{self.port}
}}

# If this is the head node, start the head; otherwise, start a worker.
if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
start_head
else
start_worker
fi

# Only the head node executes the command.
if [ -z "$SLURM_NODEID" ] || [ "$SLURM_NODEID" == "0" ]; then
echo "Running command: {cmd_to_run}"
# Use eval so the given command is executed with its arguments.
eval "{cmd_to_run}"
echo "Command finished. Shutting down Ray on head node."
ray stop
# Optionally, you could touch a file to signal the worker nodes to shut down.
fi

# For worker nodes, simply wait so that Ray stays active.
if [ -n "$SLURM_NODEID" ] && [ "$SLURM_NODEID" != "0" ]; then
echo "Worker node running. Waiting for the Ray head to finish."
while true; do
sleep 15
done
fi
"""
# Return a new Script object with the inline content
return Script(inline=ray_script)


LAUNCHER_MAP: dict[str, Type[Launcher]] = {
"torchrun": Torchrun,
"ft": FaultTolerance,
"slurm_ray": SlurmRay,
}
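The transform can also be exercised on its own. A small sketch, not part of the diff (`train.py` and its arguments are placeholders), showing what `SlurmRay.transform` produces:

```python
from nemo_run.core.execution.launcher import SlurmRay

launcher = SlurmRay(port=6379)
script = launcher.transform(["python", "train.py", "--epochs", "10"])

# transform() returns a nemo_run Script whose inline body is the Ray bash launcher:
# SLURM_NODEID 0 starts the Ray head and then runs the command via eval,
# while the remaining nodes join the head and sleep until the job ends.
assert script is not None
print(script.inline)
```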
3 changes: 1 addition & 2 deletions src/nemo_run/core/execution/skypilot.py
@@ -26,9 +26,8 @@
from nemo_run.core.execution.base import (
Executor,
ExecutorMacros,
FaultTolerance,
Torchrun,
)
from nemo_run.core.execution.launcher import FaultTolerance, Torchrun
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

11 changes: 7 additions & 4 deletions src/nemo_run/core/execution/slurm.py
@@ -33,10 +33,8 @@
from nemo_run.core.execution.base import (
Executor,
ExecutorMacros,
FaultTolerance,
Launcher,
Torchrun,
)
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, SlurmRay, Torchrun
from nemo_run.core.execution.utils import fill_template
from nemo_run.core.frontend.console.api import CONSOLE
from nemo_run.core.packaging.base import Packager
@@ -544,6 +542,9 @@ def get_launcher_prefix(self) -> Optional[list[str]]:
if launcher.nsys_profile:
return launcher.get_nsys_prefix(profile_dir=f"/{RUNDIR_NAME}")

def supports_launcher_transform(self) -> bool:
return isinstance(self.get_launcher(), SlurmRay)

def package_configs(self, *cfgs: tuple[str, str]) -> list[str]:
filenames = []
basepath = os.path.join(self.job_dir, "configs")
@@ -825,7 +826,9 @@ def materialize(self) -> str:

sbatch_flags = []
if self.slurm_config.heterogeneous:
assert len(self.jobs) == len(self.slurm_config.resource_group)
assert (
len(self.jobs) == len(self.slurm_config.resource_group)
), f"Number of jobs {len(self.jobs)} must match number of resource group requests {len(self.slurm_config.resource_group)}.\nIf you are just submitting a single job, make sure that heterogeneous=False in the executor."
final_group_index = len(self.slurm_config.resource_group) - 1
if self.slurm_config.het_group_indices:
final_group_index = self.slurm_config.het_group_indices.index(
64 changes: 44 additions & 20 deletions src/nemo_run/run/torchx_backend/packaging.py
@@ -22,8 +22,9 @@
from torchx import specs

from nemo_run.config import SCRIPTS_DIR, Partial, Script
from nemo_run.core.execution.base import Executor, FaultTolerance, Torchrun
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
from nemo_run.core.execution.launcher import FaultTolerance, Torchrun
from nemo_run.core.execution.local import LocalExecutor
from nemo_run.core.serialization.yaml import YamlSerializer
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
@@ -51,7 +52,7 @@ def package(
env = env | executor.env_vars
mounts = mounts or []

if isinstance(fn_or_script, Partial):
def _get_details_from_partial(fn_or_script: Partial):
args = [
"-n",
name,
@@ -90,24 +91,30 @@ def package(
no_python = False
script = None
entrypoint = "python"
else:
try:
yaml_cfgs = [
(
f"{name}_executor.yaml",
_serialize(executor.to_config(), serializer_cls=YamlSerializer),
),
(
f"{name}_config.yaml",
_serialize(
fdl_dc.convert_dataclasses_to_configs(fn_or_script, allow_post_init=True),
serializer_cls=YamlSerializer,

return role_args, args, m, no_python, script, entrypoint

def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
if serialize_configs:
try:
yaml_cfgs = [
(
f"{name}_executor.yaml",
_serialize(executor.to_config(), serializer_cls=YamlSerializer),
),
),
]
executor.package_configs(*yaml_cfgs)
except Exception as e:
log.warning(f"Failed saving yaml configs due to: {e}")
(
f"{name}_config.yaml",
_serialize(
fdl_dc.convert_dataclasses_to_configs(
fn_or_script, allow_post_init=True
),
serializer_cls=YamlSerializer,
),
),
]
executor.package_configs(*yaml_cfgs)
except Exception as e:
log.warning(f"Failed saving yaml configs due to: {e}")

args = fn_or_script.args
role_args = fn_or_script.to_command(
@@ -117,10 +124,27 @@ def package(
m = fn_or_script.path if fn_or_script.m else None
no_python = fn_or_script.entrypoint != "python"
script = fn_or_script.path if not fn_or_script.m else None
env = env | fn_or_script.env
entrypoint = fn_or_script.entrypoint

return role_args, args, m, no_python, script, entrypoint

if isinstance(fn_or_script, Partial):
role_args, args, m, no_python, script, entrypoint = _get_details_from_partial(fn_or_script)
else:
role_args, args, m, no_python, script, entrypoint = _get_details_from_script(
fn_or_script, serialize_configs=True
)
env = env | fn_or_script.env

launcher = executor.get_launcher()
if executor.supports_launcher_transform():
cmd = [entrypoint] + role_args
transformed_script = launcher.transform(cmd)
if transformed_script:
role_args, args, m, no_python, script, entrypoint = _get_details_from_script(
transformed_script, serialize_configs=False
)

if launcher and isinstance(launcher, Torchrun):
app_def = torchrun.torchrun(
*args,
7 changes: 3 additions & 4 deletions test/core/execution/test_base.py
@@ -16,16 +16,15 @@

import fiddle as fdl
import pytest
from torchx.specs import Role

from nemo_run.config import Config
from nemo_run.core.execution.base import (
Executor,
ExecutorMacros,
FaultTolerance,
Launcher,
Torchrun,
)
from nemo_run.core.execution.launcher import FaultTolerance, Launcher, Torchrun
from nemo_run.core.execution.slurm import SlurmExecutor
from torchx.specs import Role


class TestExecutorMacros:
Expand Down