1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`trainer.loss` (default loss)**: Dropped the Kimi-K2.5 KL term and removed the advantage-conditioned mask from the default DPPO-Binary TV loss. Tokens are now dropped symmetrically when `π_train - π_infer` falls outside `[-dppo_diff_low, dppo_diff_high]` regardless of advantage sign. Removed `kl_tau`. Renamed `dppo_mask_low` → `dppo_diff_low` (default `0.2`) and `dppo_mask_high` → `dppo_diff_high` (default `0.2`). `adv_tau` and `teacher_tau` are unchanged. (2026-05-06)
- **`orchestrator.advantage.length_penalty` → discriminated sub-config**: The scalar `length_penalty: Literal["tokens","turns"] | None` is replaced by a `LengthPenaltyConfig | None` discriminated on `type`. Token shaping now computes a weighted cost over completion and tool-response tokens. Migration: `length_penalty = "tokens"` becomes `[orchestrator.advantage.length_penalty]\ntype = "tokens"` (default weights `completion_weight = 1.0` and `tool_response_weight = 1.0`, i.e. the full context is counted). `length_penalty = "turns"` becomes `[orchestrator.advantage.length_penalty]\ntype = "turns"`. (2026-05-06)
- **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
- **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01)
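For the `AdvantageInputs` change, a minimal migration sketch in Python (the function name and the group-mean baseline are illustrative; only the `inputs.rollouts` access pattern comes from the entry above):

```python
import torch

def my_advantage(inputs):  # inputs: AdvantageInputs
    # New API: rollouts are grouped by problem; extract per-rollout fields yourself.
    rewards = torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])
    # Illustrative group-mean baseline; substitute your own advantage computation.
    return rewards - rewards.mean(dim=-1, keepdim=True)
```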
12 changes: 12 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py
@@ -1082,6 +1082,14 @@ class OrchestratorConfig(BaseConfig):
),
] = 8

max_error_reschedule_attempts: Annotated[
int | None,
Field(
ge=1,
description="Maximum number of times the scheduler will reschedule a group whose rollouts errored or returned empty trajectories. After this many consecutive failed attempts, the group is dropped from the current step's batch (the trainer proceeds with the rollouts from other groups). `None` means retry indefinitely (legacy behavior). Useful for unblocking single-example hangs in agent envs.",
),
] = None

max_async_level: Annotated[
int,
Field(
@@ -1261,6 +1269,10 @@ def validate_renderer_args(self):
renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}")
if self.renderer.pool_size is not None:
renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}")
if self.renderer.preserve_all_thinking:
renderer_args_set.append("renderer.preserve_all_thinking=True")
if self.renderer.preserve_thinking_between_tool_calls:
renderer_args_set.append("renderer.preserve_thinking_between_tool_calls=True")

if renderer_args_set:
raise ValueError(
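The boundary semantics of this field, restated as a standalone sketch (`should_drop` is a hypothetical helper; the real logic lives in `scheduler.py` below):

```python
def should_drop(failed_attempts: int, max_attempts: int | None) -> bool:
    # None preserves the legacy behavior: reschedule the group indefinitely.
    if max_attempts is None:
        return False
    # With max_attempts = 3, the group is dropped on its 3rd consecutive failure.
    return failed_attempts >= max_attempts
```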
28 changes: 28 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/shared.py
@@ -186,6 +186,34 @@ class RendererConfig(BaseConfig):
),
] = None

preserve_all_thinking: Annotated[
bool,
Field(
description=(
"If True, the renderer keeps `<think>...</think>` blocks for ALL "
"assistant turns when re-rendering the prompt for the next request. "
"By default the GLM/Qwen renderers strip thinking from older turns "
"(only the most recent assistant keeps it). Stripping breaks the "
"trajectory-step prefix property at every turn-rebuild — re-rendered "
"tokens no longer extend the streamed tokens, so the splitter opens "
"extra training samples (e.g. 1 + 2*compactions instead of 1 + "
"compactions). Turn this on for RL with multi-turn agents that "
"compact context."
),
),
] = False

preserve_thinking_between_tool_calls: Annotated[
bool,
Field(
description=(
"Narrower variant of preserve_all_thinking: keep thinking only "
"between consecutive tool-call assistant turns. See "
"renderers.Renderer.render for exact semantics."
),
),
] = False


class ElasticConfig(BaseConfig):
"""Configures elastic inference pool with DNS-based service discovery.
Expand Down
13 changes: 9 additions & 4 deletions packages/prime-rl-configs/src/prime_rl/configs/trainer.py
@@ -682,15 +682,20 @@ class CheckpointConfig(BaseConfig):


class DefaultLossConfig(BaseModel):
"""Config for the default loss."""
"""Config for the default loss (DPPO-Binary TV / arXiv:2602.04879)."""

type: Literal["default"] = "default"

dppo_mask_low: Annotated[float, Field(ge=0, description="The low threshold for masking tokens.")] = 0.2
dppo_mask_high: Annotated[float, Field(ge=0, description="The high threshold for masking tokens.")] = 0.2
dppo_diff_low: Annotated[
float,
Field(ge=0, description="Lower bound on (π_train - π_infer). Tokens below -dppo_diff_low are dropped."),
] = 0.2
dppo_diff_high: Annotated[
float,
Field(ge=0, description="Upper bound on (π_train - π_infer). Tokens above dppo_diff_high are dropped."),
] = 0.2
adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0
teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0
kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3


class SFTLossConfig(BaseModel):
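In equation form, the new default objective is a masked importance-weighted surrogate. This restates the `loss.py` code further down (the optional `teacher_tau` distillation term is omitted here), with per-token keep mask $m_t$ and defaults $\delta_{\text{low}} = \delta_{\text{high}} = 0.2$:

$$
\mathcal{L} \;=\; -\sum_t m_t \,\tau_{\text{adv}}\, A_t \,\frac{\pi_{\text{train}}(y_t)}{\pi_{\text{infer}}(y_t)},
\qquad
m_t \;=\; \mathbf{1}\!\left[\, -\delta_{\text{low}} \le \pi_{\text{train}}(y_t) - \pi_{\text{infer}}(y_t) \le \delta_{\text{high}} \,\right] \cdot \text{loss\_mask}_t
$$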
4 changes: 4 additions & 0 deletions src/prime_rl/orchestrator/orchestrator.py
@@ -926,6 +926,8 @@ async def setup_rollout_inference_pool(
renderer=config.renderer.name,
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}")
inference_pool = await setup_inference_pool(
@@ -937,6 +939,8 @@
tool_parser=config.renderer.tool_parser,
reasoning_parser=config.renderer.reasoning_parser,
renderer_pool_size=config.renderer.pool_size,
preserve_all_thinking=config.renderer.preserve_all_thinking,
preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls,
)
logger.info("Using direct renderer rollout client")
return renderer, inference_pool
23 changes: 23 additions & 0 deletions src/prime_rl/orchestrator/scheduler.py
@@ -42,6 +42,10 @@ class GroupState:
rollouts_to_schedule: int
completed_rollouts: list[vf.RolloutOutput] = field(default_factory=list)
pinned_client: vf.ClientConfig | None = None
# Number of rollout attempts in this group that returned errored or empty
# trajectories. Compared against config.max_error_reschedule_attempts to
# decide when to drop a permanently-stuck group.
failed_attempts: int = 0


class Scheduler:
Expand Down Expand Up @@ -114,6 +118,7 @@ def __init__(
self.empty_rollouts_by_env: dict[str, int] = defaultdict(int)
self.errored_rollouts_by_env: dict[str, int] = defaultdict(int)
self.total_rollouts_by_env: dict[str, int] = defaultdict(int)
self.dropped_groups_by_env: dict[str, int] = defaultdict(int)
self.last_batch_generation_time = 0.0

@property
@@ -454,6 +459,24 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
rollout["env_name"] = env_name
valid_rollouts.append(rollout)

if has_failures:
group.failed_attempts += 1
max_attempts = self.config.max_error_reschedule_attempts
if max_attempts is not None and group.failed_attempts >= max_attempts:
# Permanently-stuck group: drop it from this step and let the
# rest of the batch proceed. Avoids a single bad example (e.g.
# an agent rollout whose sandbox poll keeps timing out)
# blocking step progress forever.
self.dropped_groups_by_env[env_name] += 1
self.logger.warning(
f"Dropping group {group_id} ({env_name}) after {group.failed_attempts} "
f"failed attempts ({len(group.completed_rollouts)}/{self.rollouts_per_example} "
f"complete). Set orchestrator.max_error_reschedule_attempts higher (or to None) "
f"to retry more aggressively."
)
await self.drop_group(group_id)
continue

if has_failures and env.requires_group_scoring:
# Group scoring requires all rollouts — discard partial results, reschedule full group
group.completed_rollouts.clear()
40 changes: 23 additions & 17 deletions src/prime_rl/trainer/models/kernels/fp8_utils.py
@@ -139,17 +139,18 @@ def _unpack_grouped_rows_kernel(
actual_m = tl.load(actual_ms_ptr + pid_g)
row_offsets = (pid_blk - block_start) * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)
col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
src_rows_i64 = (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
dst_rows_i64 = (dst_start + row_offsets).to(tl.int64)
col_offsets_i64 = col_offsets.to(tl.int64)
valid_rows = row_offsets < actual_m
valid_cols = col_offsets < cols
x = tl.load(
x_ptr
+ (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M))[:, None] * stride_xm
+ col_offsets[None, :] * stride_xn,
x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn,
mask=valid_rows[:, None] & valid_cols[None, :],
other=0.0,
)
tl.store(
out_ptr + (dst_start + row_offsets)[:, None] * stride_ym + col_offsets[None, :] * stride_yn,
out_ptr + dst_rows_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn,
x,
mask=valid_rows[:, None] & valid_cols[None, :],
)
@@ -176,9 +177,11 @@ def _per_token_fp8_kernel(
pid_k = tl.program_id(axis=1)
row_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
col_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
row_offsets_i64 = row_offsets.to(tl.int64)
col_offsets_i64 = col_offsets.to(tl.int64)
mask = (row_offsets[:, None] < rows) & (col_offsets[None, :] < cols)
x = tl.load(
x_ptr + row_offsets[:, None] * stride_xm + col_offsets[None, :] * stride_xn,
x_ptr + row_offsets_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn,
mask=mask,
other=0.0,
).to(tl.float32)
Expand All @@ -188,11 +191,11 @@ def _per_token_fp8_kernel(
scale = tl.exp2(tl.ceil(tl.log2(scale)))
y = x / scale[:, None]
tl.store(
out_ptr + row_offsets[:, None] * stride_ym + col_offsets[None, :] * stride_yn,
out_ptr + row_offsets_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn,
y.to(tl.float8e4nv),
mask=mask,
)
tl.store(sf_ptr + row_offsets * stride_sm + pid_k * stride_sk, scale, mask=row_offsets < rows)
tl.store(sf_ptr + row_offsets_i64 * stride_sm + pid_k * stride_sk, scale, mask=row_offsets < rows)


@triton.jit
@@ -226,12 +229,13 @@ def _grouped_per_token_fp8_kernel(
local_block = pid_blk - block_start
row_offsets = local_block * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)
col_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
src_rows_i64 = (src_start + row_offsets).to(tl.int64)
dst_rows_i64 = (pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
col_offsets_i64 = col_offsets.to(tl.int64)
valid_rows = row_offsets < actual_m
valid_cols = col_offsets < cols
src_rows = src_start + row_offsets
dst_rows = pid_blk * GROUP_BLOCK_M + pid_sub * BLOCK_M + tl.arange(0, BLOCK_M)
x = tl.load(
x_ptr + src_rows[:, None] * stride_xm + col_offsets[None, :] * stride_xn,
x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn,
mask=valid_rows[:, None] & valid_cols[None, :],
other=0.0,
).to(tl.float32)
@@ -241,11 +245,11 @@
scale = tl.exp2(tl.ceil(tl.log2(scale)))
y = x / scale[:, None]
tl.store(
out_ptr + dst_rows[:, None] * stride_ym + col_offsets[None, :] * stride_yn,
out_ptr + dst_rows_i64[:, None] * stride_ym + col_offsets_i64[None, :] * stride_yn,
y.to(tl.float8e4nv),
mask=valid_rows[:, None] & valid_cols[None, :],
)
tl.store(sf_ptr + dst_rows * stride_sm + pid_k * stride_sk, scale, mask=valid_rows)
tl.store(sf_ptr + dst_rows_i64 * stride_sm + pid_k * stride_sk, scale, mask=valid_rows)


@triton.jit
@@ -277,11 +281,13 @@ def _grouped_per_channel_fp8_sm90_kmajor_kernel(
local_block = pid_blk - block_start
row_offsets = local_block * BLOCK_K + tl.arange(0, BLOCK_K)
col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
src_rows_i64 = (src_start + row_offsets).to(tl.int64)
row_offsets_i64 = row_offsets.to(tl.int64)
col_offsets_i64 = col_offsets.to(tl.int64)
valid_rows = row_offsets < actual_m
valid_cols = col_offsets < cols
src_rows = src_start + row_offsets
x = tl.load(
x_ptr + src_rows[:, None] * stride_xm + col_offsets[None, :] * stride_xn,
x_ptr + src_rows_i64[:, None] * stride_xm + col_offsets_i64[None, :] * stride_xn,
mask=valid_rows[:, None] & valid_cols[None, :],
other=0.0,
).to(tl.float32)
@@ -290,11 +296,11 @@
if USE_UE8M0:
scale = tl.exp2(tl.ceil(tl.log2(scale)))
y = x / scale[None, :]
flat_base = block_start * BLOCK_K * cols
out_ptrs = out_ptr + flat_base + col_offsets[:, None] * aligned_m + row_offsets[None, :]
flat_base = block_start.to(tl.int64) * BLOCK_K * cols
out_ptrs = out_ptr + flat_base + col_offsets_i64[:, None] * aligned_m + row_offsets_i64[None, :]
tl.store(out_ptrs, tl.trans(y).to(tl.float8e4nv), mask=valid_cols[:, None] & (row_offsets[None, :] < aligned_m))
tl.store(
sf_ptr + pid_blk * stride_sf0 + col_offsets * stride_sf1,
sf_ptr + pid_blk * stride_sf0 + col_offsets_i64 * stride_sf1,
scale,
mask=valid_cols,
)
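Why the kernels now cast offsets to `tl.int64` before pointer arithmetic, as a back-of-envelope check (the shape is illustrative, not taken from this PR):

```python
# int32 pointer math wraps once an element offset exceeds 2**31 - 1.
INT32_MAX = 2**31 - 1
rows, cols = 300_000, 8_192     # a plausible grouped-GEMM activation shape
print(rows * cols)              # 2_457_600_000 elements
assert rows * cols > INT32_MAX  # offsets into the last rows would overflow in int32
```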
37 changes: 14 additions & 23 deletions src/prime_rl/trainer/rl/loss.py
Original file line number Diff line number Diff line change
@@ -106,38 +106,30 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor:

def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs:
"""
DPPO+KL loss, combining:
- DPPO-Binary TV Loss (https://arxiv.org/pdf/2602.04879)
- Kimi-K2.5 KL Loss (https://arxiv.org/pdf/2602.02276)

The mask is conditioned on the advantage sign: for positive advantages,
we mask tokens whose probability increased too much (trust region violation
in the upweight direction); for negative advantages, we mask tokens whose
probability decreased too much (trust region violation in the downweight
direction).
DPPO-Binary TV loss (https://arxiv.org/pdf/2602.04879), symmetric variant.

Token-level masked importance sampling: tokens whose probability
difference π_train - π_infer falls outside [-dppo_diff_low, dppo_diff_high]
are dropped (gradient set to 0), not clipped. The mask is symmetric
(not advantage-conditioned). No KL penalty — the double-sided
difference mask is what keeps the update inside the trust region.
"""
trainer_logprobs = inputs.trainer_logprobs
inference_logprobs = inputs.inference_logprobs
teacher_logprobs = inputs.teacher_logprobs
advantages = inputs.advantages
loss_mask = inputs.loss_mask

trainer_probs = torch.exp(trainer_logprobs)
inference_probs = torch.exp(inference_logprobs)
probs_diff = trainer_probs - inference_probs
dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low)

is_masked = dppo_invalid_mask
is_masked_high = (advantages > 0) & dppo_invalid_mask_high
is_masked_low = (advantages < 0) & dppo_invalid_mask_low
keep_mask = loss_mask & ~is_masked

log_importance_ratio = trainer_logprobs - inference_logprobs
importance_ratio = torch.exp(log_importance_ratio)
mismatch_kl = importance_ratio - log_importance_ratio - 1

probs_diff = (torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)).detach()
is_masked_low = probs_diff < -loss_config.dppo_diff_low
is_masked_high = probs_diff > loss_config.dppo_diff_high
is_masked = is_masked_low | is_masked_high
keep_mask = loss_mask & ~is_masked

advantages = loss_config.adv_tau * advantages
if teacher_logprobs is not None:
teacher_kl = teacher_logprobs - trainer_logprobs
@@ -146,8 +138,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs:
teacher_kl = None

pg_loss = keep_mask * advantages * importance_ratio
kl_loss = loss_mask * log_importance_ratio**2
loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()
loss = (-pg_loss).sum()

metrics = {
"mismatch_kl": _safe_mean(mismatch_kl, loss_mask), # all trainable tokens
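A worked check of the symmetric drop rule at the default thresholds (values chosen for illustration):

```python
import torch

pi_train = torch.tensor([0.75, 0.40, 0.10])
pi_infer = torch.tensor([0.50, 0.45, 0.40])
diff = pi_train - pi_infer             # [0.25, -0.05, -0.30]
keep = (diff >= -0.2) & (diff <= 0.2)  # [False, True, False]
# Token 0 is dropped for rising too fast and token 2 for falling too fast,
# in both cases regardless of the advantage sign.
```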