Various cleanups around distributed setup
ghstack-source-id: 60e9e1a1282fcae3b334ce3ffa8db23c1848e098
Pull Request resolved: pytorch#645
awgu committed Oct 22, 2024
1 parent e10cb94 commit e211236
Showing 2 changed files with 7 additions and 13 deletions.
6 changes: 0 additions & 6 deletions torchtitan/profiling.py
@@ -45,11 +45,6 @@ def trace_handler(prof):
logger.info(
f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds"
)
-# Profiling is a heavy operation which could cost very different amount of time
-# across all ranks. Insert a barrier to make sure all ranks have finished profiling
-# before moving on.
-# TODO: Can we find a cleaner way?
-torch.distributed.barrier()

logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

@@ -119,7 +114,6 @@ def step(self, exit_ctx: bool = False):
logger.info(
f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
)
-torch.distributed.barrier()

logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
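For reference, the trace handler edited above follows the standard torch.profiler pattern of dumping a Chrome trace per rank from an on_trace_ready callback. Below is a minimal sketch of that pattern, not the torchtitan implementation; trace_dir, the RANK lookup, and the schedule values are assumptions for illustration.

```python
import os
import time

import torch

# Assumed values; torchtitan derives these from its job config and rank info.
trace_dir = "./outputs/profile_trace"
rank = int(os.environ.get("RANK", "0"))


def trace_handler(prof: torch.profiler.profile) -> None:
    # Each rank dumps its own Chrome trace. Dumping can take a very different
    # amount of time on different ranks, which is why a barrier used to follow
    # this step before this commit removed it.
    curr_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
    os.makedirs(curr_dir, exist_ok=True)
    begin = time.monotonic()
    prof.export_chrome_trace(os.path.join(curr_dir, f"rank{rank}_trace.json"))
    print(f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds")


profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=8, warmup=1, active=1),  # assumed schedule
    on_trace_ready=trace_handler,
)
```

In the training loop the profiler would be entered as a context manager, with profiler.step() called once per iteration so the schedule advances.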
14 changes: 7 additions & 7 deletions torchtitan/utils.py
@@ -58,10 +58,10 @@ def set_pg_timeouts(timeout, world_mesh):
"""
Sets the timeout for all PGs in the provided mesh, and the default (world) group.
-Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase
+Note: synchronizes via a barrier, before changing the timeouts. This is important, because
otherwise you may face a race where the slow rank has not reached the timeout reduction point
yet due to slow operations permitted under the old timeout value, but other faster ranks may
-start issueing collectives under the new shorter timeout and then immediately timeout.
+start issuing collectives under the new shorter timeout and then immediately timeout.
"""
logger.info(
f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}"
@@ -70,7 +70,7 @@ def set_pg_timeouts(timeout, world_mesh):
# otherwise, some ranks may issue collectives with the new/shorter timeout and
# those may time out, before other ranks have finished with initialization done
# under the old/slow timeout.
-torch.distributed.barrier()
+torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
torch.cuda.synchronize()

groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
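The two hunks above belong to set_pg_timeouts, whose docstring explains why the barrier must happen before the timeouts are shortened. As an illustration of that ordering only, here is a hedged sketch; it assumes PyTorch's private distributed_c10d._set_pg_timeout helper is what ultimately applies the new timeout to each group, since the rest of the function is not shown in this diff.

```python
from datetime import timedelta

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def set_pg_timeouts_sketch(timeout: timedelta, world_mesh: DeviceMesh) -> None:
    # Synchronize all ranks first: otherwise a fast rank could start issuing
    # collectives under the new, shorter timeout while a slow rank is still
    # finishing work permitted under the old, longer timeout.
    dist.barrier(device_ids=[torch.cuda.current_device()])
    torch.cuda.synchronize()

    groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)]
    groups.append(None)  # None selects the default (world) process group
    for group in groups:
        # Private PyTorch helper; assumed here and subject to change.
        dist.distributed_c10d._set_pg_timeout(timeout, group)
```

Passing device_ids to the barrier, as the hunk above now does, tells the NCCL backend explicitly which CUDA device each rank is using.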
@@ -117,15 +117,15 @@ def init_distributed(job_config):
os.makedirs(dump_dir, exist_ok=True)
_warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

-torch.distributed.init_process_group(
-"nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
-)

# to mitigate the memory issue that collectives using
# async_op=True hold memory longer than they should
# such as those in tensor parallelism
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

+torch.distributed.init_process_group(
+"nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+)


def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
num_params = sum(p.numel() for p in model.parameters())
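The last hunk moves the TORCH_NCCL_AVOID_RECORD_STREAMS assignment ahead of process-group initialization, presumably so the environment variable is already set when the NCCL process group is created. A minimal sketch of the resulting order; the 600-second default stands in for job_config.comm.init_timeout_seconds and is an assumption.

```python
import os
from datetime import timedelta

import torch


def init_distributed_sketch(init_timeout_seconds: int = 600) -> None:
    # Set NCCL-related environment variables before init_process_group so
    # they are in place when the NCCL process group is created.
    #
    # TORCH_NCCL_AVOID_RECORD_STREAMS=1 mitigates async_op=True collectives
    # (e.g. in tensor parallelism) holding memory longer than they should.
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=init_timeout_seconds)
    )
```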
