
Commit 97fef50

add runner script for slurm
1 parent 855bcad commit 97fef50

3 files changed: +320 -0 lines changed

torchft/examples/slurm/README.md

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
## Launch training

Run the following command to launch the torchft replica groups (running torchtitan) on Slurm:

```bash
$ python runner.py
```

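`runner.py` assumes a torchft lighthouse is already reachable at the address in `TORCHFT_LIGHTHOUSE` (it defaults to `http://slurm-head-node-0:29510`). A minimal sketch of a full launch, where the lighthouse flag is an assumption about the torchft CLI (verify with `torchft_lighthouse --help`) and the runner values are purely illustrative:

```bash
# on the head node: start the lighthouse (flag is an assumption, check --help)
$ torchft_lighthouse --min_replicas 1

# launch the replicas; the flags below are defined in runner.py, the values are examples
$ python runner.py --workspace-dir ./torchtitan --nodes 2 --nproc-per-node 8 --replica-count 4
```
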
## Test fault tolerance

To inject failures, run the following command:

```bash
$ python punisher.py kill_loop
```
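
`kill_loop` picks a random replica (never replica 0, which owns checkpointing and wandb logging), kills it, sleeps, and repeats. Its knobs and the one-shot subcommands are defined in `punisher.py`; the values below are only examples:

```bash
# kill a random non-zero replica roughly every 60s on average, five times
$ python punisher.py kill_loop --mtbf-secs 60 --num-failures 5

# kill a single random replica, or all running replicas
$ python punisher.py kill_one
$ python punisher.py kill_all
```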

torchft/examples/slurm/punisher.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import argparse
import logging
import random
import time

from torchx import specs
from torchx.runner import get_runner

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_SCHEDULER = "slurm"


def kill_all(runner):
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
    for job in jobs:
        if "ft_" not in job.name:
            continue
        print(f"killing {job.app_handle}")
        runner.cancel(job.app_handle)


def kill_one(runner):
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
    candidates = []
    for job in jobs:
        if "ft_" not in job.name:
            continue
        # never kill replica 0, which handles checkpointing and wandb logging
        if "ft_0" in job.name:
            continue
        candidates.append(job.app_handle)
    choice = random.choice(candidates)
    print(f"killing {choice=} {candidates=}")
    runner.cancel(choice)


def kill_loop(runner, args: argparse.Namespace):
    for _ in range(args.num_failures):
        kill_one(runner)
        # sleep a uniformly random duration with mean mtbf_secs
        dur = random.random() * (2 * args.mtbf_secs)
        print(f"sleeping for {dur=} {args.mtbf_secs=}")
        time.sleep(dur)


def main():
    parser = argparse.ArgumentParser(description="CLI tool to inject failures on slurm")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    # kill_loop subcommand
    kill_loop_parser = subparsers.add_parser("kill_loop", help="Kill jobs in a loop")
    kill_loop_parser.add_argument(
        "--mtbf-secs",
        type=float,
        default=5,
        help="Mean time between failures",
    )
    kill_loop_parser.add_argument(
        "--num-failures",
        type=int,
        default=1,
        help="Number of failures to inject",
    )

    # kill_one subcommand
    subparsers.add_parser("kill_one", help="Kill a single job")

    # kill_all subcommand
    subparsers.add_parser("kill_all", help="Kill all jobs")

    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        return

    with get_runner() as runner:
        if args.command == "kill_loop":
            kill_loop(runner, args)
        elif args.command == "kill_one":
            kill_one(runner)
        elif args.command == "kill_all":
            kill_all(runner)


if __name__ == "__main__":
    main()
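
Both scripts act on torchx app handles for the `slurm` scheduler. To see what is actually running on the cluster, standard tooling works alongside them; a quick sketch (the app handle is whatever `runner.py`/`punisher.py` printed):

```bash
# Slurm's own view of your jobs
$ squeue -u $USER

# torchx's view of a specific replica job
$ torchx status <app_handle>
```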

torchft/examples/slurm/runner.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import argparse
import logging
import os

from torchx import specs
from torchx.components.dist import ddp
from torchx.runner import get_runner

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

_SCHEDULER = "slurm"


def _make_app(replica_id: int, cli_args: argparse.Namespace):
    args = [
        "--comm.trace_buf_size=0",
        "--comm.train_timeout_seconds=60",

        "--metrics.log_freq=1",
        "--profiling.enable_profiling",

        "--experimental.custom_args_module=torchtitan.components.ft.config",
        "--job.config_file=./torchtitan/models/llama3/train_configs/llama3_8b.toml",
        "--model.name=llama3_ft",

        "--training.dataset=c4",
        "--training.steps=10000",
        "--training.local_batch_size=2",

        f"--parallelism.data_parallel_shard_degree={cli_args.nodes * cli_args.nproc_per_node}",

        "--fault_tolerance.enable",
        f"--fault_tolerance.replica_id={replica_id}",
        f"--fault_tolerance.group_size={cli_args.replica_count}",
        f"--fault_tolerance.process_group={cli_args.process_group}",
        f"--fault_tolerance.process_group_timeout_ms={600 * 1000}",
    ]

    if cli_args.enable_semi_sync_method:
        args += [
            f"--fault_tolerance.semi_sync_method={cli_args.semi_sync_method}",
        ]

        if cli_args.semi_sync_method == "diloco":
            args += [
                "--fault_tolerance.sync_steps=20",
                "--fault_tolerance.fragment_sync_delay=1",
                f"--fault_tolerance.num_fragments={cli_args.num_fragments}",
            ]

    if replica_id == 0:
        # replica 0 is responsible for wandb logging and checkpointing
        args += [
            "--metrics.enable-wandb",
            "--checkpoint.interval=100",
        ]

    env = {}

    # use agent store in torchelastic to avoid TCPStore init race condition
    env["TORCH_SHARE_RDZV_TCP_STORE"] = "1"

    env["TORCH_CUDA_SANITIZER"] = "1"

    # NCCL envs for debugging
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"
    env["NCCL_PROTO"] = "Simple"

    # gloo
    env["GLOO_SOCKET_IFNAME"] = os.environ.get("GLOO_SOCKET_IFNAME", "ens3")

    # application log levels
    env["LOGLEVEL"] = "INFO"
    env["RUST_LOGS"] = "INFO"
    env["TORCH_CPP_LOG_LEVEL"] = "INFO"

    # application timeouts
    env["TORCHFT_QUORUM_TIMEOUT_SEC"] = "900"
    env["TORCHFT_TIMEOUT_SEC"] = "600"
    env["TORCHFT_QUORUM_RETRIES"] = "0"

    env["TORCHFT_LIGHTHOUSE"] = os.environ.get(
        "TORCHFT_LIGHTHOUSE", "http://slurm-head-node-0:29510"
    )

    env["WANDB_PROJECT"] = "torchft"

    app = ddp(
        *args,
        name=f"ft_{replica_id}",
        env=env,
        script="./torchtitan/train.py",
        gpu=cli_args.nproc_per_node,
        j=f"{cli_args.nodes}x{cli_args.nproc_per_node}",
    )
    app.roles[0].name = app.name
    return app


def start_replica(runner, replica_id: int, args: argparse.Namespace):
    app = _make_app(replica_id, args)

    app_handle = runner.run(
        app,
        scheduler=_SCHEDULER,
        cfg={"partition": "batch"},
    )

    return app_handle


def monitor(runner, args: argparse.Namespace):
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]

    active_replicas = {}

    for job in jobs:
        if "ft_" not in job.name:
            continue
        # parse the replica id out of job names of the form "ft_<replica_id>[-<suffix>]"
        name, _, _ = job.name.partition("-")
        _, _, replica_id_str = name.partition("_")
        replica_id = int(replica_id_str)
        active_replicas[replica_id] = job

    to_launch = set()
    for replica_id in range(args.replica_count):
        alive = replica_id in active_replicas

        if alive:
            job = active_replicas[replica_id]
            print(f" - {replica_id=:2d}: ALIVE {job.app_handle}")
        else:
            print(f" - {replica_id=:2d}: DEAD")
            to_launch.add(replica_id)

    for replica_id in to_launch:
        app_handle = start_replica(
            runner,
            replica_id,
            args,
        )
        print(f"launched {replica_id=}: {app_handle=}")


def main():
    parser = argparse.ArgumentParser(
        description="CLI tool to launch data parallel replicas on slurm"
    )

    parser.add_argument(
        "--workspace-dir",
        type=str,
        default=".",
        help="Location of the torchtitan folder",
    )

    parser.add_argument(
        "--nodes",
        type=int,
        default=10,
        help="Number of nodes per replica",
    )

    parser.add_argument(
        "--nproc-per-node",
        type=int,
        default=10,
        help="Number of ranks per node",
    )

    parser.add_argument(
        "--replica-count",
        type=int,
        default=10,
        help="Number of data parallel replicas",
    )

    parser.add_argument(
        "--process-group",
        type=str,
        default="gloo",
        help="The process group to use for data parallel",
    )

    parser.add_argument(
        "--enable-semi-sync-method",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to enable a semi-sync method for data parallel",
    )

    parser.add_argument(
        "--semi-sync-method",
        type=str,
        default="diloco",
        help="The semi-sync method to use for data parallel. Options: diloco, local_sgd",
    )

    parser.add_argument(
        "--num-fragments",
        type=int,
        default=2,
        help="The number of fragments to use for data parallel. Only used for the diloco semi-sync method",
    )

    args = parser.parse_args()

    os.chdir(args.workspace_dir)

    with get_runner() as runner:
        monitor(runner, args)


if __name__ == "__main__":
    main()
