|
| 1 | +import logging |
| 2 | +import os |
| 3 | +import argparse |
| 4 | + |
| 5 | +from torchx import specs |
| 6 | +from torchx.components.dist import ddp |
| 7 | +from torchx.runner import get_runner |
| 8 | + |
# Configure root logging for this launcher CLI.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # NOTE(review): unused below — output goes through print()

# TorchX scheduler backend used for every submit/list call in this script.
_SCHEDULER = "slurm"
| 13 | + |
def _make_app(replica_id: int, cli_args: argparse.Namespace):
    """Build the TorchX AppDef for one fault-tolerant data-parallel replica.

    Args:
        replica_id: Index of this replica within the torchft replica group.
        cli_args: Parsed CLI options (``nodes``, ``nproc_per_node``,
            ``replica_count``, ``process_group``, ...).

    Returns:
        A torchx AppDef that launches ``./torchtitan/train.py`` via ``ddp``,
        with the role renamed to ``ft_<replica_id>`` so ``monitor()`` can
        find it later.
    """
    args = [
        "--comm.trace_buf_size=0",
        "--comm.train_timeout_seconds=60",

        "--metrics.log_freq=1",
        "--profiling.enable_profiling",

        "--experimental.custom_args_module=torchtitan.components.ft.config",
        "--job.config_file=./torchtitan/models/llama3/train_configs/llama3_8b.toml",
        "--model.name=llama3_ft",

        "--training.dataset=c4",
        "--training.steps=10000",
        "--training.local_batch_size=2",

        # Shard across every rank owned by this replica (nodes * ranks/node).
        f"--parallelism.data_parallel_shard_degree={cli_args.nodes * cli_args.nproc_per_node}",

        "--fault_tolerance.enable",
        f"--fault_tolerance.replica_id={replica_id}",
        # BUG FIX: was ``cli_args.groups``, which the parser never defines.
        # The torchft group size is the number of data-parallel replicas.
        f"--fault_tolerance.group_size={cli_args.replica_count}",
        f"--fault_tolerance.process_group={cli_args.process_group}",
        f"--fault_tolerance.process_group_timeout_ms={600 * 1000}",
    ]

    # BUG FIX: was ``cli_args.enable_semi_sync`` — argparse stores
    # ``--enable-semi-sync-method`` under ``enable_semi_sync_method``, so the
    # old attribute raised AttributeError.
    if cli_args.enable_semi_sync_method:
        args += [
            f"--fault_tolerance.semi_sync_method={cli_args.semi_sync_method}",
        ]

        if cli_args.semi_sync_method == "diloco":
            args += [
                "--fault_tolerance.sync_steps=20",
                "--fault_tolerance.fragment_sync_delay=1",
                f"--fault_tolerance.num_fragments={cli_args.num_fragments}",
            ]

    # BUG FIX: was ``cli_args.replica_id``, which the parser never defines;
    # use the id passed to this function. Only replica 0 reports to wandb
    # and writes checkpoints.
    if replica_id == 0:
        args += [
            "--metrics.enable-wandb",
            "--checkpoint.interval=100",
        ]

    env = {}

    # use agent store in torchelastic to avoid TCPStore init race condition
    env["TORCH_SHARE_RDZV_TCP_STORE"] = "1"
    env["TORCH_CPP_LOG_LEVEL"] = "INFO"

    # BUG FIX: key was the malformed ``"TORCH_CUDA_SANITIZER=1"`` — the
    # ``=1`` belongs in the value, not the variable name.
    env["TORCH_CUDA_SANITIZER"] = "1"

    # NCCL envs for debugging
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"
    env["NCCL_PROTO"] = "Simple"

    # gloo
    env["GLOO_SOCKET_IFNAME"] = os.environ.get("GLOO_SOCKET_IFNAME", "ens3")

    # application log levels (TORCH_CPP_LOG_LEVEL is set once above; the
    # duplicate assignment was removed)
    env["LOGLEVEL"] = "INFO"
    env["RUST_LOGS"] = "INFO"

    # application timeouts
    env["TORCHFT_QUORUM_TIMEOUT_SEC"] = "900"
    env["TORCHFT_TIMEOUT_SEC"] = "600"
    env["TORCHFT_QUORUM_RETRIES"] = "0"

    env["TORCHFT_LIGHTHOUSE"] = os.environ.get(
        "TORCHFT_LIGHTHOUSE", "http://slurm-head-node-0:29510"
    )

    env["WANDB_PROJECT"] = "torchft"

    app = ddp(
        *args,
        name=f"ft_{replica_id}",
        env=env,
        script="./torchtitan/train.py",
        gpu=cli_args.nproc_per_node,
        j=f"{cli_args.nodes}x{cli_args.nproc_per_node}",
    )
    # monitor() matches running jobs by this name prefix, so mirror it on
    # the role as well.
    app.roles[0].name = app.name
    return app
| 98 | + |
| 99 | + |
| 100 | + |
def start_replica(runner, replica_id: int, args: argparse.Namespace):
    """Submit one replica's AppDef to the slurm scheduler.

    Args:
        runner: A torchx runner obtained from ``get_runner()``.
        replica_id: Index of the replica to launch.
        args: Parsed CLI options forwarded to ``_make_app``.

    Returns:
        The app handle of the submitted job.
    """
    return runner.run(
        _make_app(replica_id, args),
        scheduler=_SCHEDULER,
        cfg={"partition": "batch"},
    )
| 111 | + |
| 112 | + |
def monitor(runner, args: argparse.Namespace):
    """Report liveness of every expected replica and relaunch the dead ones.

    Lists RUNNING jobs on the scheduler, maps job names of the form
    ``ft_<id>-...`` back to replica ids, prints an ALIVE/DEAD line per
    expected replica, and resubmits each replica with no running job.
    """
    running = [
        job for job in runner.list(_SCHEDULER)
        if job.state == specs.AppState.RUNNING
    ]

    # Map replica id -> its running job, parsed from "ft_<id>[-suffix]".
    active_replicas = {}
    for job in running:
        if "ft_" not in job.name:
            continue
        base = job.name.partition("-")[0]
        active_replicas[int(base.partition("_")[2])] = job

    dead = []
    for replica_id in range(args.replica_count):
        job = active_replicas.get(replica_id)
        if job is not None:
            print(f" - {replica_id=:2d}: ALIVE {job.app_handle}")
        else:
            print(f" - {replica_id=:2d}: DEAD")
            dead.append(replica_id)

    for replica_id in dead:
        app_handle = start_replica(runner, replica_id, args)
        print(f"launched {replica_id=}: {app_handle=}")
| 145 | + |
| 146 | + |
def main():
    """Parse CLI options, cd into the torchtitan workspace, and run one
    monitor pass that relaunches any dead data-parallel replicas."""
    parser = argparse.ArgumentParser(
        # BUG FIX: typo "lauch" -> "launch".
        description="CLI tool to launch data parallel replicas on slurm"
    )

    parser.add_argument(
        "--workspace-dir",
        type=str,
        # BUG FIX: previously optional with no default, so a missing flag
        # crashed later in ``os.chdir(None)``; fail fast with a clear error.
        required=True,
        help="Location of torchtitan folder",
    )

    parser.add_argument(
        "--nodes",
        type=int,
        default=10,
        help="Number of nodes per replica",
    )

    parser.add_argument(
        "--nproc-per-node",
        type=int,
        default=10,
        help="Number of ranks per node",
    )

    parser.add_argument(
        "--replica-count",
        type=int,
        default=10,
        help="Number of data parallel replicas",
    )

    parser.add_argument(
        "--process-group",
        type=str,
        default="gloo",
        help="The process group to use for data parallel",
    )

    parser.add_argument(
        "--enable-semi-sync-method",
        # BUG FIX: ``type=bool`` treats any non-empty string (even "False")
        # as True. BooleanOptionalAction gives a real on/off switch
        # (--enable-semi-sync-method / --no-enable-semi-sync-method).
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to enable semi-sync method for data parallel",
    )

    parser.add_argument(
        "--semi-sync-method",
        type=str,
        default="diloco",
        help="The semi-sync method to use for data parallel. Options: diloco, local_sgd",
    )

    parser.add_argument(
        "--num-fragments",
        type=int,
        default=2,
        help="The number of fragments to use for data parallel. Only used for diloco semi-sync method",
    )

    args = parser.parse_args()

    # Relative paths inside the launched apps (config file, train script)
    # assume the torchtitan workspace is the current directory.
    os.chdir(args.workspace_dir)

    with get_runner() as runner:
        monitor(runner, args)


if __name__ == "__main__":
    main()
0 commit comments