17 changes: 17 additions & 0 deletions docs/source/grpo_trainer.md
@@ -144,6 +144,16 @@ $$

This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].

### CISPO: Truncated importance-sampling REINFORCE

The ScaleRL paper[^scalerl] introduces CISPO, a variant of truncated importance-sampling REINFORCE that keeps the prompt-level normalization from DAPO while replacing the PPO-style min operator with a stop-gradient truncation of the importance ratios:

$$
\mathcal{L}_{\text{CISPO}}(\theta) = - \frac{1}{T_G} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} \operatorname{sg}\!\left(\min(\rho_{i,t}, \epsilon_{\max})\right) \, \hat{A}_i \log \pi_\theta(o_{i,t} \mid q, o_{i, < t}) \,,
$$

where \\(\rho_{i,t} = \tfrac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})}\\), \\(T_G = \sum_i |o_i|\\), and \\(\operatorname{sg}\\) denotes the stop-gradient operator. Setting `loss_type="cispo"` in [`GRPOConfig`] enables this objective. The truncation threshold can be controlled through `cispo_clip_max` (default `5.0`), which corresponds to \\(\epsilon_{\max}\\) in the equation above.
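
As a quick illustration, here is a minimal configuration sketch that only sets the CISPO-related options described above; the output directory is a placeholder:

```python
from trl import GRPOConfig

# Minimal sketch: switch the trainer to the CISPO objective and set the
# truncation threshold (epsilon_max in the equation above).
training_args = GRPOConfig(
    output_dir="my-cispo-run",   # placeholder
    loss_type="cispo",           # use the CISPO loss instead of the default "dapo"
    cispo_clip_max=5.0,          # upper truncation applied to the importance ratios
)
```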

## Logged metrics

While training and evaluating, we record the following reward metrics:
@@ -174,6 +184,11 @@ While training and evaluating, we record the following reward metrics:
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
- `cispo/importance_ratio/mean`: (Only when `loss_type="cispo"`.) Average importance ratio \\(\rho_{i,t}\\) before truncation.
- `cispo/importance_ratio/truncated_mean`: (Only when `loss_type="cispo"`.) Average truncated ratio \\(\min(\rho_{i,t}, \epsilon_{\max})\\).
- `cispo/importance_ratio/max`: (Only when `loss_type="cispo"`.) Maximum observed importance ratio \\(\rho_{i,t}\\) in the batch.
- `cispo/importance_ratio/max_truncated`: (Only when `loss_type="cispo"`.) Maximum truncated ratio after applying \\(\epsilon_{\max}\\).
- `cispo/clip_fraction`: (Only when `loss_type="cispo"`.) Fraction of tokens whose importance ratio exceeded \\(\epsilon_{\max}\\).
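
To make the relationship between these statistics and \\(\epsilon_{\max}\\) concrete, the sketch below recomputes them from a tensor of importance ratios; the values are made up, and it ignores the completion mask and distributed gathering that the trainer applies:

```python
import torch

# Hypothetical per-token importance ratios rho_{i,t} for a tiny batch (made-up values)
rho = torch.tensor([[0.8, 1.2, 7.5], [0.9, 6.1, 1.0]])
eps_max = 5.0  # corresponds to `cispo_clip_max`

truncated = torch.minimum(rho, torch.tensor(eps_max))  # min(rho, eps_max)

print(rho.mean().item())                      # cispo/importance_ratio/mean
print(truncated.mean().item())                # cispo/importance_ratio/truncated_mean
print(rho.max().item())                       # cispo/importance_ratio/max
print(truncated.max().item())                 # cispo/importance_ratio/max_truncated
print((rho > eps_max).float().mean().item())  # cispo/clip_fraction
```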

## Customization

@@ -185,6 +200,8 @@ Generation is often the main bottleneck when training with online methods. To ac
pip install trl[vllm]
```

[^scalerl]: Yao et al., *ScaleRL: Scaling RL Compute Effectively and Predictably*, 2025.

We support two ways of using vLLM during training: **server mode** and **colocate mode**.

> [!TIP]
24 changes: 10 additions & 14 deletions trl/trainer/grpo_config.py
@@ -538,22 +538,18 @@ class GRPOConfig(TrainingArguments):
loss_type: str = field(
default="dapo",
metadata={
"help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', and "
"'dr_grpo'. "
"'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length "
"bias—this approach tends to prefer shorter completions with positive advantages and longer ones with "
"negative advantages. "
"'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the "
"global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. "
"'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was "
"introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
"`max_completion_length`. "
"'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. "
"Note that normalization is performed over the local batch only, so results may slightly vary depending "
"on the local batch size, despite a constant effective batch size. When using "
"`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss."
"help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', 'dr_grpo', and 'cispo'. 'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length bias-this approach tends to prefer shorter completions with positive advantages and longer ones with negative advantages. 'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. 'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to `max_completion_length`. 'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. Note that normalization is performed over the local batch only, so results may slightly vary depending on the local batch size, despite a constant effective batch size. When using `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. 'cispo': Uses the truncated importance-sampling REINFORCE loss introduced in the ScaleRL paper (Eq. 4), truncating importance ratios at `cispo_clip_max` with gradients stopped through the truncation.",
"choices": ["grpo", "dapo", "bnpo", "dr_grpo", "cispo"],
},
)

cispo_clip_max: float = field(
Collaborator comment: I'd recommend re-using the epsilon-high config instead of creating a new one here.
default=5.0,
metadata={
"help": "Upper truncation epsilon_max applied to the importance sampling ratio for the CISPO loss. Weights are set to min(rho, epsilon_max) with gradients stopped through the truncation, following ScaleRL Eq. 4.",
},
)

mask_truncated_completions: bool = field(
default=False,
metadata={
72 changes: 62 additions & 10 deletions trl/trainer/grpo_trainer.py
@@ -391,6 +391,9 @@ def __init__(
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
self.cispo_clip_max = args.cispo_clip_max
if self.loss_type == "cispo" and self.cispo_clip_max <= 0:
raise ValueError("`cispo_clip_max` must be a positive float when using the CISPO loss.")
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
@@ -445,6 +448,8 @@ def __init__(

# Liger loss
if self.use_liger_loss:
if self.loss_type == "cispo":
raise NotImplementedError("Liger kernels do not currently support the CISPO loss.")
if not is_liger_kernel_available():
raise ImportError(
"Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
@@ -1705,19 +1710,28 @@ def _compute_loss(self, model, inputs):
f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' "
"and 'sequence'."
)
# From here, log_importance_weights (and all subsequent tensors) shape depends on the importance sampling
# level: "token" level -> (B, T); "sequence" level -> (B, 1)
coef_1 = torch.exp(log_importance_weights)
cispo_truncated_weights = None
cispo_clipped_mask = None

if self.loss_type == "cispo":
    cispo_cap = torch.full_like(coef_1, self.cispo_clip_max)
    cispo_truncated_weights = torch.minimum(coef_1, cispo_cap)
    cispo_clipped_mask = coef_1 > cispo_cap
Collaborator comment: Can be simplified by `cispo_weights = torch.where(coef_1 < self.epsilon_high, coef_1, self.epsilon_high).detach()`.
    cispo_weights = cispo_truncated_weights.detach()
    per_token_loss = -cispo_weights * advantages.unsqueeze(1) * per_token_logps
else:
    coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

    # Two-sided clipping
    if self.args.delta is not None:
        coef_1 = torch.clamp(coef_1, max=self.args.delta)

    per_token_loss1 = coef_1 * advantages.unsqueeze(1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

if entropy_mask is not None:
    per_token_loss = per_token_loss * entropy_mask

@@ -1739,6 +1753,9 @@
elif self.loss_type == "dapo":
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
elif self.loss_type == "cispo":
normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
loss = (per_token_loss * completion_mask).sum() / normalizer
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

@@ -1760,6 +1777,41 @@ def masked_batch_mean(x):
mean_entropy = masked_batch_mean(entropies)
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

if self.loss_type == "cispo":
truncated_mean = masked_batch_mean(cispo_truncated_weights)
ratio_mean = masked_batch_mean(coef_1)
clip_fraction = masked_batch_mean(cispo_clipped_mask.float())

gathered_truncated_mean = self.accelerator.gather(truncated_mean)
gathered_ratio_mean = self.accelerator.gather(ratio_mean)
gathered_clip_fraction = self.accelerator.gather(clip_fraction)

self._metrics[mode]["cispo/importance_ratio/truncated_mean"].append(
gathered_truncated_mean.nanmean().item()
)
self._metrics[mode]["cispo/importance_ratio/mean"].append(gathered_ratio_mean.nanmean().item())
self._metrics[mode]["cispo/clip_fraction"].append(gathered_clip_fraction.nanmean().item())

if cispo_truncated_weights.shape[1] == 1:
flat_original = coef_1.squeeze(1)
flat_truncated = cispo_truncated_weights.squeeze(1)
else:
mask = completion_mask.bool()
flat_original = coef_1.masked_select(mask)
flat_truncated = cispo_truncated_weights.masked_select(mask)

max_ratio = flat_original.max() if flat_original.numel() > 0 else torch.tensor(0.0, device=coef_1.device)
max_truncated = (
flat_truncated.max() if flat_truncated.numel() > 0 else torch.tensor(0.0, device=coef_1.device)
)
self._metrics[mode]["cispo/importance_ratio/max"].append(
nanmax(self.accelerator.gather(max_ratio)).item()
)
self._metrics[mode]["cispo/importance_ratio/max_truncated"].append(
nanmax(self.accelerator.gather(max_truncated)).item()
)
return loss

# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)