From 504050b0345d6b6f87c4c05b5869c508fbd72359 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Tue, 2 Jul 2024 00:58:27 -0700
Subject: [PATCH 01/11] add param

---
 open_lm/distributed.py | 2 ++
 open_lm/params.py      | 6 ++++++
 2 files changed, 8 insertions(+)

diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 8c07d663..34cf899a 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -3,6 +3,7 @@ import logging

 import torch
 import torch.distributed as dist
+import datetime


 def is_global_master(args):
@@ -79,6 +80,7 @@ def init_distributed_device(args):
                 init_method=args.dist_url,
                 world_size=args.world_size,
                 rank=args.rank,
+                timeout=datetime.timedelta(seconds=args.backend_timeout)
             )
         else:
             # DDP via torchrun, torch.distributed.launch
diff --git a/open_lm/params.py b/open_lm/params.py
index 0a7a3f64..3e15d7cb 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -787,6 +787,12 @@ def parse_args(args):
         default=0,
         help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
     )
+    parser.add_argument(
+        "--backend-timeout",
+        type=int,
+        default=None,
+        help="This the number of seconds passed into the timeout arg for torch.distributed.init_process_group."
+    )

     add_model_args(parser)

From 4368279574f3494b96043e4cd32bc6c39fd60cba Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Tue, 2 Jul 2024 02:02:12 -0700
Subject: [PATCH 02/11] add timeout to both

---
 open_lm/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 34cf899a..3a510384 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -87,7 +87,7 @@ def init_distributed_device(args):
             # Note that this currently assumes that the world size is all gpus in a node.
             assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
             args.local_rank, _, _ = world_info_from_env()
-            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
+            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=args.backend_timeout)
             args.world_size = torch.distributed.get_world_size()
             args.rank = torch.distributed.get_rank()
         args.distributed = True

From 14d7210a6d231947baee41e3c18a68379fc78b99 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Tue, 2 Jul 2024 16:40:56 -0700
Subject: [PATCH 03/11] fix type to datetime.timedelta

---
 open_lm/distributed.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 3a510384..bd3c6667 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -60,6 +60,9 @@ def init_distributed_device(args):
     args.local_rank = 0
     # For testing, allow forcing distributed mode to test distributed code path even on one gpu.
     if is_using_distributed() or args.force_distributed:
+
+        timeout = datetime.timedelta(seconds=args.timeout) if args.timeout else None
+
         if "SLURM_PROCID" in os.environ:
             # DDP via SLURM
             args.local_rank, args.rank, env_world_size = world_info_from_env()
@@ -80,14 +83,14 @@ def init_distributed_device(args):
                 init_method=args.dist_url,
                 world_size=args.world_size,
                 rank=args.rank,
-                timeout=datetime.timedelta(seconds=args.backend_timeout)
+                timeout=timeout,
             )
         else:
             # DDP via torchrun, torch.distributed.launch
             # Note that this currently assumes that the world size is all gpus in a node.
             assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
             args.local_rank, _, _ = world_info_from_env()
-            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=args.backend_timeout)
+            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
             args.world_size = torch.distributed.get_world_size()
             args.rank = torch.distributed.get_rank()
         args.distributed = True

From 22e9968d7d3c14dbacc37b467f13735cf871208a Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Tue, 2 Jul 2024 16:59:58 -0700
Subject: [PATCH 04/11] fix typo

---
 open_lm/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index bd3c6667..2dccb71d 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -61,7 +61,7 @@ def init_distributed_device(args):
     # For testing, allow forcing distributed mode to test distributed code path even on one gpu.
     if is_using_distributed() or args.force_distributed:

-        timeout = datetime.timedelta(seconds=args.timeout) if args.timeout else None
+        timeout = datetime.timedelta(seconds=args.backend_timeout) if args.timeout else None

         if "SLURM_PROCID" in os.environ:
             # DDP via SLURM

From 7ef1afe28c904cbb11ca0aeb4d6ebd0b8149847f Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Tue, 2 Jul 2024 17:20:34 -0700
Subject: [PATCH 05/11] fix typo in 2nd location

---
 open_lm/distributed.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 2dccb71d..f49e564f 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -61,7 +61,7 @@ def init_distributed_device(args):
     # For testing, allow forcing distributed mode to test distributed code path even on one gpu.
     if is_using_distributed() or args.force_distributed:

-        timeout = datetime.timedelta(seconds=args.backend_timeout) if args.timeout else None
+        timeout = datetime.timedelta(seconds=args.backend_timeout) if args.backend_timeout else None

         if "SLURM_PROCID" in os.environ:
             # DDP via SLURM

From d0f05b122b069988020abe7c58087c0cee9ac13d Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Wed, 3 Jul 2024 15:23:15 -0700
Subject: [PATCH 06/11] do not sleep before first attempt at remote sync

---
 open_lm/file_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py
index f91919b2..80709ce5 100644
--- a/open_lm/file_utils.py
+++ b/open_lm/file_utils.py
@@ -72,11 +72,10 @@ def remote_sync(local_dir, remote_dir, protocol):

 def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol, max_retries=6):
     for i in range(max_retries):
-        time.sleep(sync_every * 2**i)
         success = remote_sync(local_dir, remote_dir, protocol)
         if success:
             return True
-
+        time.sleep(sync_every * 2**i)
     return False

From 933a8adbf60b72ae40e34d33b6b237e3f13dcc91 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Thu, 4 Jul 2024 00:43:04 -0700
Subject: [PATCH 07/11] patch args.log_local

---
 open_lm/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/open_lm/main.py b/open_lm/main.py
index 7c80f558..5843a18c 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -572,7 +572,7 @@ def main(args):
     if args.resume is not None and averagers is not None:
         load_avg_models(args, averagers)

-    if is_master(args):
+    if is_master(args, local=args.log_local):
         logging.info(f"Model (has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters):")
         logging.info(f"{str(model)}")
         logging.info("Params:")
@@ -717,7 +717,7 @@ def main(args):
         raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.")

     # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
-    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args, local=args.log_local)
     writer = None
     if args.save_logs and args.tensorboard:
         assert tensorboard is not None, "Please install tensorboard."

From 39e8aa8928f9a69565032df25c77ee6ebedf90ca Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Thu, 4 Jul 2024 00:53:17 -0700
Subject: [PATCH 08/11] change rank0_only behavior when log_local

---
 open_lm/main.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/open_lm/main.py b/open_lm/main.py
index 5843a18c..97cda7b5 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -211,7 +211,8 @@ def save_checkpoint(
 ):
     cpu_state, optim_state = None, None
     if args.logs and args.logs.lower() != "none" and args.fsdp:
-        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        rank0_only = not args.log_local
+        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
             cpu_state = model.state_dict()
         optim_state = FSDP.optim_state_dict(model, optimizer)

From d8ca03cb7ade352484783a2867511760169c0cdc Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Thu, 4 Jul 2024 02:34:17 -0700
Subject: [PATCH 09/11] fix sleep behavior only for last iteration, also address the parent dir issue

---
 open_lm/file_utils.py | 2 +-
 open_lm/main.py       | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py
index 80709ce5..12b6ea88 100644
--- a/open_lm/file_utils.py
+++ b/open_lm/file_utils.py
@@ -72,10 +72,10 @@ def remote_sync(local_dir, remote_dir, protocol):

 def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol, max_retries=6):
     for i in range(max_retries):
+        time.sleep(sync_every * 2**i)
         success = remote_sync(local_dir, remote_dir, protocol)
         if success:
             return True
-        time.sleep(sync_every * 2**i)
     return False

diff --git a/open_lm/main.py b/open_lm/main.py
index 97cda7b5..530114a1 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -381,7 +381,7 @@ def main(args):
     args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
     args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
     args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed")
-    if is_master(args):
+    if is_master(args, local=args.log_local):
         args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ""
         for dirname in [args.tensorboard_path, args.checkpoint_path, args.failed_checkpoint_path]:
             if dirname:
@@ -932,8 +932,9 @@ def main(args):
     if remote_sync_process is not None:
         logging.info("Final remote sync.")
         terminate_sync_process(remote_sync_process)
+        # Can just pass in sync_every=0 for last sync, otherwise will unecessarily sleep.
         result = remote_sync_with_expon_backoff(
-            args.remote_sync_frequency,
+            0,
             os.path.join(args.logs, args.name),
             os.path.join(args.remote_sync, args.name),
             args.remote_sync_protocol,

From a857a41f43566d4a2f841bfe35fdaa07ed5a2a51 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Thu, 4 Jul 2024 09:36:14 -0700
Subject: [PATCH 10/11] 0 sleep for initial sync

---
 open_lm/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/open_lm/main.py b/open_lm/main.py
index 530114a1..a3f77ce6 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -425,9 +425,9 @@ def main(args):
     # start the sync proces if remote-sync is not None
     remote_sync_process = None
     if is_master(args) and args.remote_sync is not None:
-        # first make sure it works
+        # first make sure it works: here, remote_sync_frequency is set to 0 for this initial test
         result = remote_sync_with_expon_backoff(
-            args.remote_sync_frequency,
+            0,
             os.path.join(args.logs, args.name),
             os.path.join(args.remote_sync, args.name),
             args.remote_sync_protocol,

From 6f8fbab40acebf9078aaefbe608e4bca314f6bc3 Mon Sep 17 00:00:00 2001
From: Jeffrey
Date: Thu, 4 Jul 2024 12:31:51 -0700
Subject: [PATCH 11/11] linting

---
 open_lm/params.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/open_lm/params.py b/open_lm/params.py
index 3e15d7cb..4c66c8ee 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -791,7 +791,7 @@ def parse_args(args):
         "--backend-timeout",
         type=int,
         default=None,
-        help="This the number of seconds passed into the timeout arg for torch.distributed.init_process_group."
+        help="This the number of seconds passed into the timeout arg for torch.distributed.init_process_group.",
     )

     add_model_args(parser)
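
Taken together, patches 01-05 add a --backend-timeout argument (an integer number of seconds, default None) and have init_distributed_device() convert it to a datetime.timedelta before handing it to torch.distributed.init_process_group in both the SLURM and torchrun code paths. A minimal sketch of that final behavior follows; the concrete timeout value is illustrative and not part of the patches:

    import datetime

    # Mirrors the patched init_distributed_device(): a value in seconds becomes a
    # timedelta; None falls back to PyTorch's default process-group timeout.
    backend_timeout = 7200  # hypothetical value supplied via --backend-timeout
    timeout = datetime.timedelta(seconds=backend_timeout) if backend_timeout else None

    # The patched call sites then forward it, e.g.:
    # torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=timeout)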
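Patches 06, 09, and 10 settle on keeping the exponential-backoff sleep before each sync attempt and instead passing sync_every=0 for the one-off initial and final syncs, so those calls retry without waiting. A short sketch of that retry pattern, with a hypothetical do_sync callable standing in for remote_sync:

    import time


    def sync_with_backoff(sync_every, do_sync, max_retries=6):
        # Wait sync_every * 2**i before attempt i; with sync_every == 0 every
        # sleep is a no-op, so one-off syncs never pause before or between tries.
        for i in range(max_retries):
            time.sleep(sync_every * 2**i)
            if do_sync():
                return True
        return False


    # Usage: sync_with_backoff(0, lambda: remote_sync(local_dir, remote_dir, protocol))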