From b43629f81ab96e6682d03f8830de39488a87ef0b Mon Sep 17 00:00:00 2001
From: Mika Senghaas
Date: Tue, 10 Mar 2026 22:12:53 +0000
Subject: [PATCH] fix: re-schedule empty trajectories instead of filtering after group completion

Empty trajectories were filtered after sampling from the buffer, which
could yield incomplete groups and break the advantage computation's
assumption that each group has exactly `rollouts_per_example` rollouts.

Now empty trajectories are detected per-rollout as they complete. When
one is found, `rollouts_to_schedule` is incremented so the group
naturally re-fills via `_fill_inflight_requests`, and the group is only
yielded once it has the full count of non-empty rollouts.

Co-Authored-By: Claude Opus 4.6
---
 src/prime_rl/orchestrator/scheduler.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py
index fbaefb1adc..a9a1657849 100644
--- a/src/prime_rl/orchestrator/scheduler.py
+++ b/src/prime_rl/orchestrator/scheduler.py
@@ -370,7 +370,16 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
                     group = self.groups.get(group_id)
                     if group is None:
                         continue
-                    group.completed_rollouts.append(finished_task.result())
+                    rollout = finished_task.result()
+                    if len(rollout["trajectory"]) == 0:
+                        self.empty_rollouts_count += 1
+                        group.rollouts_to_schedule += 1
+                        self.logger.warning(
+                            f"Empty trajectory in group {group_id}, re-scheduling "
+                            f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete)"
+                        )
+                        continue
+                    group.completed_rollouts.append(rollout)
                     if len(group.completed_rollouts) < self.rollouts_per_example:
                         continue
                     completed_rollouts = self.groups.pop(group_id).completed_rollouts
@@ -388,14 +397,6 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
                     self.buffer.update(completed_rollouts)
                     accepted_rollouts = self.buffer.sample_rollouts(n=self.rollouts_per_example)
 
-                    empty_count = sum(1 for r in accepted_rollouts if len(r["trajectory"]) == 0)
-                    if empty_count > 0:
-                        self.logger.warning(
-                            f"Filtered {empty_count}/{len(accepted_rollouts)} rollouts with empty trajectories"
-                        )
-                        accepted_rollouts = [r for r in accepted_rollouts if len(r["trajectory"]) > 0]
-                        self.empty_rollouts_count += empty_count
-
                     batch_rollouts.extend(accepted_rollouts)
                     progress_increment = self.get_batch_progress_increment(accepted_rollouts)
                     batch_progress += progress_increment