diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index d741fca8c2..8b6a80929d 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -172,6 +172,7 @@ def train(config: SFTConfig): cp_enabled = parallel_dims.cp_enabled cp_rank = parallel_dims.world_mesh["cp"].get_local_rank() if cp_enabled else 0 cp_group = parallel_dims.world_mesh["cp"].get_group() if cp_enabled else None + dp_cp_group = parallel_dims.get_mesh("dp_cp").get_group() cp_size = parallel_dims.cp ce_loss = None @@ -185,7 +186,8 @@ def train(config: SFTConfig): case _: raise ValueError(f"Invalid loss implementation: {config.loss_impl}") - def compute_loss(micro_batch: dict) -> torch.Tensor: + def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass returning (loss_sum, token_count) over unmasked tokens.""" input_ids = micro_batch["input_ids"].to("cuda") position_ids = micro_batch["position_ids"].to("cuda") target_ids = micro_batch["target_ids"].to("cuda") @@ -196,69 +198,48 @@ def compute_loss(micro_batch: dict) -> torch.Tensor: target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size) loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size) + token_count = loss_mask.sum(dtype=torch.int64) + with maybe_activation_offloading(config.model.ac_offloading): if config.loss_impl == "liger_fused": masked_target_ids = target_ids.clone() masked_target_ids[~loss_mask] = FusedCrossEntropyOutputLinear.IGNORE_INDEX out = forward(model, input_ids, position_ids, labels=masked_target_ids) - loss = out["loss"] + loss_sum = out["loss"] * token_count else: out = forward(model, input_ids, position_ids) logits = out["logits"] B, L, V = logits.shape - loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) - loss = loss[loss_mask].mean() + token_loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L) + loss_sum = token_loss[loss_mask].sum() del logits del out - return loss + return loss_sum, token_count maybe_record_function = nullcontext - def run_forward_loop(data_iter, num_steps=None, *, backward=True): - total_loss = torch.tensor(0.0, device="cuda") + def run_eval_loop(data_iter): + """Validation forward loop. Returns token-weighted global mean loss.""" + total_loss_sum = torch.tensor(0.0, device="cuda") + total_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") nan_count = torch.tensor(0, device="cuda") - valid_steps = torch.tensor(0, device="cuda") - max_vio_total = torch.tensor(0.0, device="cuda") - divisor = num_steps or 1 - - ctx = nullcontext() if backward else torch.no_grad() - with ctx: - for step, micro_batch in enumerate(data_iter): - if backward and config.log.log_data: - input_ids_log = micro_batch["input_ids"].flatten().tolist() - loss_mask_log = micro_batch["loss_mask"].flatten().tolist() - print_sample(input_ids_log, loss_mask_log, tokenizer) - - with maybe_record_function("forward"): - loss = compute_loss(micro_batch) - - if not torch.isnan(loss.detach()): - total_loss += loss.detach() - valid_steps += 1 + + with torch.no_grad(): + for micro_batch in data_iter: + loss_sum, token_count = compute_loss(micro_batch) + if not torch.isnan(loss_sum.detach()): + total_loss_sum += loss_sum.detach() + total_token_count += token_count else: nan_count += 1 - if backward: - with maybe_record_function("backward"): - (loss / divisor).backward() - - if is_tt_moe_model(model): - max_vio = get_load_balance_stats(model)["max_vio"] - if max_vio is not None: - max_vio = max_vio.mean() - dist.all_reduce(max_vio, op=dist.ReduceOp.MAX) - max_vio_total += max_vio / divisor - - if num_steps is not None and step + 1 >= num_steps: - break - - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - dist.all_reduce(valid_steps, op=dist.ReduceOp.SUM) + dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM, group=dp_cp_group) + dist.all_reduce(total_token_count, op=dist.ReduceOp.SUM, group=dp_cp_group) dist.all_reduce(nan_count, op=dist.ReduceOp.SUM) - mean_loss = (total_loss / valid_steps).item() if valid_steps.item() > 0 else float("nan") - return mean_loss, nan_count.item(), max_vio_total + mean_loss = (total_loss_sum / total_token_count).item() if total_token_count.item() > 0 else float("nan") + return mean_loss, nan_count.item() def run_validation(step: int) -> None: val_dataset = setup_dataset( @@ -267,7 +248,7 @@ def run_validation(step: int) -> None: val_dataloader = setup_dataloader(val_dataset, config.val.data) # No train/eval switch: no dropout in these models, and toggling would trigger torch.compile recompilation - mean_loss, nan_count, _ = run_forward_loop(val_dataloader, backward=False) + mean_loss, nan_count = run_eval_loop(val_dataloader) if nan_count > 0: logger.warning(f"Validation at step {step}: {nan_count} batches had NaN loss") if mean_loss != mean_loss: @@ -323,9 +304,56 @@ def run_validation(step: int) -> None: step_start_time = time.perf_counter() forward_backward_start_time = time.perf_counter() - batch_loss, nan_loss_count, batch_max_vio = run_forward_loop(dataiter, grad_accum_steps, backward=True) + step_loss_sum = torch.tensor(0.0, device="cuda") + step_local_token_count = torch.tensor(0, dtype=torch.int64, device="cuda") + nan_loss_count = torch.tensor(0, device="cuda") + batch_max_vio = torch.tensor(0.0, device="cuda") + for micro_step in range(grad_accum_steps): + micro_batch = next(dataiter) + + if config.log.log_data: + print_sample( + micro_batch["input_ids"].flatten().tolist(), micro_batch["loss_mask"].flatten().tolist(), tokenizer + ) + + with maybe_record_function("forward"): + local_loss_sum, local_token_count = compute_loss(micro_batch) + + step_local_token_count += local_token_count + + if torch.isnan(local_loss_sum.detach()): + nan_loss_count += 1 + logger.warning("Local loss is nan, excluding this micro step from backward") + scaled_loss = torch.nan_to_num(local_loss_sum, nan=0.0) / grad_accum_steps + else: + step_loss_sum += local_loss_sum.detach() + scaled_loss = local_loss_sum / grad_accum_steps + + with maybe_record_function("backward"): + scaled_loss.backward() + + if is_tt_moe_model(model): + max_vio = get_load_balance_stats(model)["max_vio"] + if max_vio is not None: + max_vio = max_vio.mean() + dist.all_reduce(max_vio, op=dist.ReduceOp.MAX) + batch_max_vio += max_vio / grad_accum_steps + forward_backward_time = time.perf_counter() - forward_backward_start_time + # All-reduce token counts and rescale gradients to get a global token-weighted mean. + # FSDP already divided grads by fsdp_gradient_divide_factor, so we undo that and + # divide by the true global token count instead. + global_step_token_count = step_local_token_count.clone() + dist.all_reduce(global_step_token_count, op=dist.ReduceOp.SUM, group=dp_cp_group) + global_token_count_val = global_step_token_count.item() + + if global_token_count_val > 0: + grad_scale = parallel_dims.fsdp_gradient_divide_factor * grad_accum_steps / global_token_count_val + for param in model.parameters(): + if param.grad is not None: + param.grad.mul_(grad_scale) + # Run validation after forward-backward (so torch.compile sees training graph first) but before # optimizer step (so eval_on_start evaluates untrained weights) if config.val is not None and ( @@ -334,6 +362,15 @@ def run_validation(step: int) -> None: ): run_validation(progress.step) + # Compute the global mean loss for logging. + dist.all_reduce(step_loss_sum, op=dist.ReduceOp.SUM, group=dp_cp_group) + dist.all_reduce(nan_loss_count, op=dist.ReduceOp.SUM) + if global_token_count_val > 0: + batch_loss = (step_loss_sum / global_token_count_val).item() + else: + batch_loss = 0.0 + nan_loss_count = nan_loss_count.item() + logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}") grad_norm = clip_grad_norm_( model.parameters(), max_norm=config.optim.max_norm, ep_enabled=parallel_dims.ep_enabled