Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 82 additions & 45 deletions src/prime_rl/trainer/sft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def train(config: SFTConfig):
cp_enabled = parallel_dims.cp_enabled
cp_rank = parallel_dims.world_mesh["cp"].get_local_rank() if cp_enabled else 0
cp_group = parallel_dims.world_mesh["cp"].get_group() if cp_enabled else None
dp_cp_group = parallel_dims.get_mesh("dp_cp").get_group()
cp_size = parallel_dims.cp

ce_loss = None
Expand All @@ -185,7 +186,8 @@ def train(config: SFTConfig):
case _:
raise ValueError(f"Invalid loss implementation: {config.loss_impl}")

def compute_loss(micro_batch: dict) -> torch.Tensor:
def compute_loss(micro_batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
"""Forward pass returning (loss_sum, token_count) over unmasked tokens."""
input_ids = micro_batch["input_ids"].to("cuda")
position_ids = micro_batch["position_ids"].to("cuda")
target_ids = micro_batch["target_ids"].to("cuda")
Expand All @@ -196,69 +198,48 @@ def compute_loss(micro_batch: dict) -> torch.Tensor:
target_ids = shard_for_cp(target_ids, cp_rank=cp_rank, cp_world_size=cp_size)
loss_mask = shard_for_cp(loss_mask, cp_rank=cp_rank, cp_world_size=cp_size)

token_count = loss_mask.sum(dtype=torch.int64)

with maybe_activation_offloading(config.model.ac_offloading):
if config.loss_impl == "liger_fused":
masked_target_ids = target_ids.clone()
masked_target_ids[~loss_mask] = FusedCrossEntropyOutputLinear.IGNORE_INDEX
out = forward(model, input_ids, position_ids, labels=masked_target_ids)
loss = out["loss"]
loss_sum = out["loss"] * token_count
else:
out = forward(model, input_ids, position_ids)
logits = out["logits"]
B, L, V = logits.shape
loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L)
loss = loss[loss_mask].mean()
token_loss = ce_loss(logits.view(-1, V), target_ids.view(-1)).view(B, L)
loss_sum = token_loss[loss_mask].sum()
del logits

del out
return loss
return loss_sum, token_count

maybe_record_function = nullcontext

def run_forward_loop(data_iter, num_steps=None, *, backward=True):
total_loss = torch.tensor(0.0, device="cuda")
def run_eval_loop(data_iter):
"""Validation forward loop. Returns token-weighted global mean loss."""
total_loss_sum = torch.tensor(0.0, device="cuda")
total_token_count = torch.tensor(0, dtype=torch.int64, device="cuda")
nan_count = torch.tensor(0, device="cuda")
valid_steps = torch.tensor(0, device="cuda")
max_vio_total = torch.tensor(0.0, device="cuda")
divisor = num_steps or 1

ctx = nullcontext() if backward else torch.no_grad()
with ctx:
for step, micro_batch in enumerate(data_iter):
if backward and config.log.log_data:
input_ids_log = micro_batch["input_ids"].flatten().tolist()
loss_mask_log = micro_batch["loss_mask"].flatten().tolist()
print_sample(input_ids_log, loss_mask_log, tokenizer)

with maybe_record_function("forward"):
loss = compute_loss(micro_batch)

if not torch.isnan(loss.detach()):
total_loss += loss.detach()
valid_steps += 1

with torch.no_grad():
for micro_batch in data_iter:
loss_sum, token_count = compute_loss(micro_batch)
if not torch.isnan(loss_sum.detach()):
total_loss_sum += loss_sum.detach()
total_token_count += token_count
else:
nan_count += 1

if backward:
with maybe_record_function("backward"):
(loss / divisor).backward()

if is_tt_moe_model(model):
max_vio = get_load_balance_stats(model)["max_vio"]
if max_vio is not None:
max_vio = max_vio.mean()
dist.all_reduce(max_vio, op=dist.ReduceOp.MAX)
max_vio_total += max_vio / divisor

if num_steps is not None and step + 1 >= num_steps:
break

dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
dist.all_reduce(valid_steps, op=dist.ReduceOp.SUM)
dist.all_reduce(total_loss_sum, op=dist.ReduceOp.SUM, group=dp_cp_group)
dist.all_reduce(total_token_count, op=dist.ReduceOp.SUM, group=dp_cp_group)
dist.all_reduce(nan_count, op=dist.ReduceOp.SUM)

mean_loss = (total_loss / valid_steps).item() if valid_steps.item() > 0 else float("nan")
return mean_loss, nan_count.item(), max_vio_total
mean_loss = (total_loss_sum / total_token_count).item() if total_token_count.item() > 0 else float("nan")
return mean_loss, nan_count.item()

def run_validation(step: int) -> None:
val_dataset = setup_dataset(
Expand All @@ -267,7 +248,7 @@ def run_validation(step: int) -> None:
val_dataloader = setup_dataloader(val_dataset, config.val.data)

# No train/eval switch: no dropout in these models, and toggling would trigger torch.compile recompilation
mean_loss, nan_count, _ = run_forward_loop(val_dataloader, backward=False)
mean_loss, nan_count = run_eval_loop(val_dataloader)
if nan_count > 0:
logger.warning(f"Validation at step {step}: {nan_count} batches had NaN loss")
if mean_loss != mean_loss:
Expand Down Expand Up @@ -323,9 +304,56 @@ def run_validation(step: int) -> None:
step_start_time = time.perf_counter()
forward_backward_start_time = time.perf_counter()

batch_loss, nan_loss_count, batch_max_vio = run_forward_loop(dataiter, grad_accum_steps, backward=True)
step_loss_sum = torch.tensor(0.0, device="cuda")
step_local_token_count = torch.tensor(0, dtype=torch.int64, device="cuda")
nan_loss_count = torch.tensor(0, device="cuda")
batch_max_vio = torch.tensor(0.0, device="cuda")
for micro_step in range(grad_accum_steps):
micro_batch = next(dataiter)

if config.log.log_data:
print_sample(
micro_batch["input_ids"].flatten().tolist(), micro_batch["loss_mask"].flatten().tolist(), tokenizer
)

with maybe_record_function("forward"):
local_loss_sum, local_token_count = compute_loss(micro_batch)

step_local_token_count += local_token_count
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NaN micro-step tokens inflate gradient scaling denominator

Medium Severity

step_local_token_count += local_token_count runs unconditionally before the NaN check, so tokens from NaN micro-steps are included in global_step_token_count. Since NaN micro-steps contribute zero gradients (via nan_to_num), the inflated denominator in grad_scale under-scales valid gradients. The eval loop (run_eval_loop) correctly only adds token_count inside the non-NaN branch — the training loop is inconsistent with that.

Additional Locations (1)
Fix in Cursor Fix in Web


if torch.isnan(local_loss_sum.detach()):
nan_loss_count += 1
logger.warning("Local loss is nan, excluding this micro step from backward")
scaled_loss = torch.nan_to_num(local_loss_sum, nan=0.0) / grad_accum_steps
else:
step_loss_sum += local_loss_sum.detach()
scaled_loss = local_loss_sum / grad_accum_steps

with maybe_record_function("backward"):
scaled_loss.backward()

if is_tt_moe_model(model):
max_vio = get_load_balance_stats(model)["max_vio"]
if max_vio is not None:
max_vio = max_vio.mean()
dist.all_reduce(max_vio, op=dist.ReduceOp.MAX)
batch_max_vio += max_vio / grad_accum_steps

forward_backward_time = time.perf_counter() - forward_backward_start_time

# All-reduce token counts and rescale gradients to get a global token-weighted mean.
# FSDP already divided grads by fsdp_gradient_divide_factor, so we undo that and
# divide by the true global token count instead.
global_step_token_count = step_local_token_count.clone()
dist.all_reduce(global_step_token_count, op=dist.ReduceOp.SUM, group=dp_cp_group)
global_token_count_val = global_step_token_count.item()

if global_token_count_val > 0:
grad_scale = parallel_dims.fsdp_gradient_divide_factor * grad_accum_steps / global_token_count_val
for param in model.parameters():
if param.grad is not None:
param.grad.mul_(grad_scale)

# Run validation after forward-backward (so torch.compile sees training graph first) but before
# optimizer step (so eval_on_start evaluates untrained weights)
if config.val is not None and (
Expand All @@ -334,6 +362,15 @@ def run_validation(step: int) -> None:
):
run_validation(progress.step)

# Compute the global mean loss for logging.
dist.all_reduce(step_loss_sum, op=dist.ReduceOp.SUM, group=dp_cp_group)
dist.all_reduce(nan_loss_count, op=dist.ReduceOp.SUM)
if global_token_count_val > 0:
batch_loss = (step_loss_sum / global_token_count_val).item()
else:
batch_loss = 0.0
nan_loss_count = nan_loss_count.item()

logger.debug(f"Clipping gradients with max norm {config.optim.max_norm}")
grad_norm = clip_grad_norm_(
model.parameters(), max_norm=config.optim.max_norm, ep_enabled=parallel_dims.ep_enabled
Expand Down