diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 88663c00..c9dcf2fa 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -34,14 +34,20 @@ def get_device_info():
 device_type, device_module = get_device_info()


-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh).item()
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(x, DTensor):
+        # functional collectives do not support DTensor inputs
+        x = x.full_tensor()
+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


-def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
-    tensor = torch.tensor(x).to(device_type)
-    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh).item()
+def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)


 def _warn_overwrite_env(env, val):
diff --git a/train.py b/train.py
index 21dd9f8b..d79f51b5 100644
--- a/train.py
+++ b/train.py
@@ -231,7 +231,6 @@ def loss_fn(pred, labels):
     )

     # variables used to keep info for metrics logging
-    losses_since_last_log = []
     ntokens_since_last_log = 0
     data_loading_times = []
     time_last_log = time.perf_counter()
@@ -295,10 +294,11 @@ def loss_fn(pred, labels):
                         pp_schedule.step()

                 # accumulate losses across pipeline microbatches
+                # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                 loss = (
-                    torch.mean(torch.stack(losses))
+                    torch.mean(torch.stack(losses)).to(device)
                     if is_last_stage
-                    else torch.Tensor([-1.0])
+                    else torch.tensor([-1.0], device=device)
                 )
             else:
                 # Non-PP forward / backward
@@ -330,26 +330,23 @@ def loss_fn(pred, labels):
                 # it issues a single all-reduce for all parameters at once for better performance
                 float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)

-            losses_since_last_log.append(loss)
-
             # log metrics
             if (
                 train_state.step == 1
                 or train_state.step % job_config.metrics.log_freq == 0
             ):
-                losses = [loss.item() for loss in losses_since_last_log]
-                avg_loss, max_loss = sum(losses) / len(losses), max(losses)
                 if (
                     parallel_dims.dp_replicate_enabled
                     or parallel_dims.dp_shard_enabled
                     or parallel_dims.cp_enabled
                 ):
+                    loss = loss.detach()
                     global_avg_loss, global_max_loss = (
-                        utils.dist_mean(avg_loss, world_mesh["dp_cp"]),
-                        utils.dist_max(max_loss, world_mesh["dp_cp"]),
+                        utils.dist_mean(loss, world_mesh["dp_cp"]),
+                        utils.dist_max(loss, world_mesh["dp_cp"]),
                     )
                 else:
-                    global_avg_loss, global_max_loss = avg_loss, max_loss
+                    global_avg_loss = global_max_loss = loss.item()

             # update train state
             train_state.log_steps.append(train_state.step)
@@ -399,7 +396,6 @@ def loss_fn(pred, labels):
                     f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
                 )

-                losses_since_last_log.clear()
                 ntokens_since_last_log = 0
                 data_loading_times.clear()
                 time_last_log = time.perf_counter()
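
Usage sketch (illustrative, not part of the patch): a minimal single-process example of calling the refactored helpers, assuming torchtitan is importable, a gloo process group, and a CPU tensor; the values and the MASTER_ADDR/MASTER_PORT settings are placeholders. The real call sites in train.py pass world_mesh["dp_cp"] and the detached loss tensor on the training device. dist_mean relies on ReduceOp.AVG, which is not guaranteed on every backend, so only the MAX path is exercised here.

    # Minimal sketch -- assumptions: single process, gloo backend, CPU tensor.
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    from torchtitan import utils

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # 1-D mesh standing in for world_mesh["dp_cp"] from train.py.
    mesh = init_device_mesh("cpu", (1,))

    # The helpers expect a single-element tensor (enforced by the assert in dist_reduce).
    loss = torch.tensor(2.5)

    print(utils.dist_max(loss, mesh))            # 2.5 on a single rank
    print(utils.dist_reduce(loss, "MAX", mesh))  # same reduction via the generic helper

    dist.destroy_process_group()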