
Commit eebdf3a

add runner script for slurm (#263)
Summary:
- add a script to launch replicas using torchtitan on slurm
- add a script to randomly kill replicas to test fault tolerance
1 parent 855bcad commit eebdf3a

File tree

4 files changed: +339 -1 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ dev = [
     "parameterized",
     "expecttest",
     "numpy",
-    "torchx",
+    "torchx-nightly",
     "lintrunner",
     "lintrunner-adapters",
 ]

torchft/examples/slurm/README.md

Lines changed: 34 additions & 0 deletions
## Launch lighthouse

Run this command to launch the lighthouse on a node that all other slurm nodes can reach:

```bash
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```

## Launch training

First, go to your local torchtitan folder and run:

```bash
$ pip install -r requirements.txt
$ pip install .
```

Then run the following commands to launch the replicas with torchtitan on slurm:

```bash
$ pip install torchx-nightly
$ # Set the address of the lighthouse server, e.g.
$ export TORCHFT_LIGHTHOUSE=http://slurm-head-node-0:29510
$ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=1 --nproc-per-node=8 --replica-count=2
```

## Test fault tolerance

To inject failures, run the following command:

```bash
$ python punisher.py kill_loop --num-failures=10 --mtbf-secs=300
```
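
Each replica runs as a slurm job named `ft_<replica_id>` (see `runner.py`), so you can check which replicas are currently alive with standard slurm tooling, for example:

```bash
$ squeue -u $USER
```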

torchft/examples/slurm/punisher.py

Lines changed: 89 additions & 0 deletions
import argparse
import logging
import random
import time

from torchx import specs
from torchx.runner import Runner, get_runner

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

_SCHEDULER = "slurm"


def kill_all(runner: Runner) -> None:
    """Cancel every running ft_* replica job."""
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
    for job in jobs:
        if "ft_" not in job.name:
            continue
        print(f"killing {job.app_handle}")
        runner.cancel(job.app_handle)


def kill_one(runner: Runner) -> None:
    """Cancel one randomly chosen running replica, never replica 0."""
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
    candidates = []
    for job in jobs:
        if "ft_" not in job.name:
            continue
        if "ft_0" in job.name:
            continue
        candidates.append(job.app_handle)
    choice = random.choice(candidates)
    print(f"killing {choice=} {candidates=}")
    runner.cancel(choice)


def kill_loop(runner: Runner, args: argparse.Namespace) -> None:
    """Kill one replica at a time, sleeping a random duration between failures."""
    for _ in range(args.num_failures):
        kill_one(runner)
        # uniform in [0, 2 * mtbf_secs] so the mean time between failures is mtbf_secs
        dur = random.random() * (2 * args.mtbf_secs)
        print(f"sleeping for {dur=} {args.mtbf_secs=}")
        time.sleep(dur)


def main() -> None:
    parser = argparse.ArgumentParser(description="CLI tool to inject failures on slurm")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # kill_loop subcommand
    kill_loop_parser = subparsers.add_parser("kill_loop", help="Kill jobs in a loop")
    kill_loop_parser.add_argument(
        "--mtbf-secs",
        type=float,
        default=5,
        help="Mean time between failures",
    )
    kill_loop_parser.add_argument(
        "--num-failures",
        type=int,
        default=1,
        help="Number of failures to inject",
    )

    # kill_one subcommand
    subparsers.add_parser("kill_one", help="Kill a single job")

    # kill_all subcommand
    subparsers.add_parser("kill_all", help="Kill all jobs")

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    with get_runner() as runner:
        if args.command == "kill_loop":
            kill_loop(runner, args)
        elif args.command == "kill_one":
            kill_one(runner)
        elif args.command == "kill_all":
            kill_all(runner)


if __name__ == "__main__":
    main()
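
Besides the `kill_loop` mode shown in the README, the script exposes `kill_one` and `kill_all` subcommands; for example:

```bash
$ # cancel one randomly chosen replica (replica 0 is never selected)
$ python punisher.py kill_one
$ # cancel every running ft_* replica
$ python punisher.py kill_all
```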

torchft/examples/slurm/runner.py

Lines changed: 215 additions & 0 deletions
import argparse
import logging
import os
import time

from torchx import specs
from torchx.components.dist import ddp
from torchx.runner import Runner, get_runner

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

_SCHEDULER = "slurm"


def _make_app(replica_id: int, cli_args: argparse.Namespace) -> specs.AppDef:
    args = [
        "--comm.trace_buf_size=0",
        "--comm.train_timeout_seconds=60",
        "--metrics.log_freq=1",
        "--profiling.enable_profiling",
        "--experimental.custom_args_module=torchtitan.components.ft.config",
        "--job.config_file=./torchtitan/models/llama3/train_configs/llama3_8b.toml",
        "--model.name=llama3_ft",
        "--training.dataset=c4",
        "--training.steps=10000",
        "--training.local_batch_size=2",
        f"--parallelism.data_parallel_shard_degree={cli_args.nodes * cli_args.nproc_per_node}",
        "--fault_tolerance.enable",
        f"--fault_tolerance.replica_id={replica_id}",
        f"--fault_tolerance.group_size={cli_args.replica_count}",
        f"--fault_tolerance.process_group={cli_args.process_group}",
        f"--fault_tolerance.process_group_timeout_ms={600 * 1000}",
    ]

    if cli_args.enable_semi_sync:
        args += [
            f"--fault_tolerance.semi_sync_method={cli_args.semi_sync_method}",
        ]

    if cli_args.semi_sync_method == "diloco":
        args += [
            "--fault_tolerance.sync_steps=20",
            "--fault_tolerance.fragment_sync_delay=1",
            f"--fault_tolerance.num_fragments={cli_args.num_fragments}",
        ]

    if replica_id == 0:
        args += [
            "--metrics.enable-wandb",
            "--checkpoint.interval=100",
        ]

    env = {}

    # use agent store in torchelastic to avoid TCPStore init race condition
    env["TORCH_SHARE_RDZV_TCP_STORE"] = "1"
    env["TORCH_CPP_LOG_LEVEL"] = "INFO"

    env["TORCH_CUDA_SANITIZER"] = "1"

    # NCCL envs for debugging
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"
    env["NCCL_PROTO"] = "Simple"

    # gloo
    if os.environ.get("GLOO_SOCKET_IFNAME") is not None:
        env["GLOO_SOCKET_IFNAME"] = os.environ["GLOO_SOCKET_IFNAME"]

    # application log levels
    env["LOGLEVEL"] = "INFO"
    env["RUST_LOGS"] = "INFO"
    env["TORCH_CPP_LOG_LEVEL"] = "INFO"

    # application timeouts
    env["TORCHFT_QUORUM_TIMEOUT_SEC"] = "900"
    env["TORCHFT_TIMEOUT_SEC"] = "600"
    env["TORCHFT_QUORUM_RETRIES"] = "0"

    env["TORCHFT_LIGHTHOUSE"] = os.environ.get(
        "TORCHFT_LIGHTHOUSE", "http://slurm-head-node-0:29510"
    )

    env["WANDB_PROJECT"] = "torchft"

    app = ddp(
        *args,
        name=f"ft_{replica_id}",
        env=env,
        script="./torchtitan/train.py",
        gpu=cli_args.nproc_per_node,
        j=f"{cli_args.nodes}x{cli_args.nproc_per_node}",
    )
    app.roles[0].name = app.name
    return app


def start_replica(
    runner: Runner, replica_id: int, args: argparse.Namespace
) -> specs.AppHandle:
    app = _make_app(replica_id, args)

    app_handle = runner.run(
        app,
        scheduler=_SCHEDULER,
    )

    return app_handle


def monitor(runner: Runner, args: argparse.Namespace) -> None:
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]

    active_replicas = {}

    for job in jobs:
        if "ft_" not in job.name:
            continue
        # job names look like "ft_<replica_id>-<suffix>"
        name, _, _ = job.name.partition("-")
        _, _, replica_id_str = name.partition("_")
        replica_id = int(replica_id_str)
        active_replicas[replica_id] = job

    to_launch = set()
    for replica_id in range(args.replica_count):
        alive = replica_id in active_replicas

        if alive:
            job = active_replicas[replica_id]
            print(f" - {replica_id=:2d}: ALIVE {job.app_handle}")
        else:
            print(f" - {replica_id=:2d}: DEAD")
            to_launch.add(replica_id)

    for replica_id in to_launch:
        app_handle = start_replica(
            runner,
            replica_id,
            args,
        )
        print(f"launched {replica_id=}: {app_handle=}")


def main() -> None:
    parser = argparse.ArgumentParser(
        description="CLI tool to launch data parallel replicas on slurm"
    )

    parser.add_argument(
        "--workspace-dir", type=str, help="Location of torchtitan folder"
    )

    parser.add_argument(
        "--nodes",
        type=int,
        default=10,
        help="Number of nodes per replica",
    )

    parser.add_argument(
        "--nproc-per-node",
        type=int,
        default=10,
        help="Number of ranks per node",
    )

    parser.add_argument(
        "--replica-count",
        type=int,
        default=10,
        help="Number of data parallel replicas",
    )

    parser.add_argument(
        "--process-group",
        type=str,
        default="gloo",
        help="The process group to use for data parallel",
    )

    parser.add_argument(
        "--enable-semi-sync",
        # argparse's type=bool treats any non-empty string (including "False") as True,
        # so use a real boolean flag instead
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to enable semi-sync method for data parallel",
    )

    parser.add_argument(
        "--semi-sync-method",
        type=str,
        default="diloco",
        help="The semi-sync method to use for data parallel. Options: diloco, local_sgd",
    )

    parser.add_argument(
        "--num-fragments",
        type=int,
        default=2,
        help="The number of fragments to use for data parallel. Only used for diloco semi-sync method",
    )

    args = parser.parse_args()

    os.chdir(args.workspace_dir)

    with get_runner() as runner:
        # supervision loop: relaunch any replica that is no longer running
        while True:
            monitor(runner, args)
            time.sleep(10)


if __name__ == "__main__":
    main()
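
As a usage sketch built from the flags above (workspace path as in the README):

```bash
$ # DiLoCo semi-sync (the default): 2 replicas, each 1 node x 8 GPUs
$ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=1 --nproc-per-node=8 --replica-count=2
$ # use LocalSGD instead of DiLoCo
$ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=1 --nproc-per-node=8 --replica-count=2 --semi-sync-method=local_sgd
```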
