diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 7919764b0c..82fd426122 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -602,6 +602,7 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin "reward": [rollout["reward"] for rollout in train_rollouts], "is_truncated": [rollout["is_truncated"] for rollout in train_rollouts], "error": [rollout["error"] for rollout in train_rollouts], + "stop_condition": [rollout.get("stop_condition") for rollout in train_rollouts], "seq_len": [get_seq_len(rollout) for rollout in train_rollouts], "prefill_len": rollout_prefill_lens, "decode_len": rollout_decode_lens, @@ -666,6 +667,14 @@ def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[Trainin "is_truncated/mean": results_df.groupby("example_id").is_truncated.mean().mean(), "is_truncated/max": results_df.groupby("example_id").is_truncated.mean().max(), "is_truncated/min": results_df.groupby("example_id").is_truncated.mean().min(), + "stop_condition/generation_truncated": ( + results_df.is_truncated & (results_df.stop_condition != "prompt_too_long") + ).mean(), + # Log rate of each stop condition (e.g. max_turns, prompt_too_long, has_error) + **{ + f"stop_condition/{sc}": rate + for sc, rate in results_df.stop_condition.dropna().value_counts(normalize=True).items() + }, # Seqs per rollout metrics "samples_per_rollout/mean": results_df.groupby("example_id").samples_per_rollout.mean().mean(), "samples_per_rollout/max": results_df.groupby("example_id").samples_per_rollout.mean().max(),