# Failure reasons: the specific cause recorded for a rejected rollout.
# Each value becomes the last segment of an `error_reason/<env>/<reason>`
# metric key, so these strings are part of the metrics schema.

# Reasons attributed to the agent.
FAILURE_REASON_AGENT_NONZERO_EXIT = "agent_nonzero_exit"
FAILURE_REASON_AGENT_POLL_FAILED = "agent_poll_failed"
FAILURE_REASON_AGENT_EMPTY_TRAJECTORY = "agent_empty_trajectory"
# Reason attributed to the rollout itself.
FAILURE_REASON_ROLLOUT_TIMEOUT = "rollout_timeout"
# Reasons attributed to the sandbox.
FAILURE_REASON_SANDBOX_OOM = "sandbox_oom"
FAILURE_REASON_SANDBOX_TIMEOUT = "sandbox_timeout"
FAILURE_REASON_SANDBOX_COMMAND_FAILED = "sandbox_command_failed"
FAILURE_REASON_SANDBOX_SETUP_FAILED = "sandbox_setup_failed"
# Reasons attributed to the tunnel.
FAILURE_REASON_TUNNEL_ERROR = "tunnel_error"
FAILURE_REASON_STREAM_INTERRUPTED = "stream_interrupted"
# Reason attributed to the model.
FAILURE_REASON_MODEL_ERROR = "model_error"
# Reason recorded when the env-server request itself raises.
FAILURE_REASON_ENV_SERVER_ERROR = "env_server_error"
# Fallback when no known marker matches the error text.
FAILURE_REASON_UNKNOWN = "unknown"

# Failure origins: the component a failure is attributed to.
# Each value becomes the last segment of an `error_origin/<env>/<origin>`
# metric key.
FAILURE_ORIGIN_AGENT = "agent"
FAILURE_ORIGIN_SANDBOX = "sandbox"
FAILURE_ORIGIN_TUNNEL = "tunnel"
FAILURE_ORIGIN_MODEL = "model"
FAILURE_ORIGIN_ENV_SERVER = "env_server"
FAILURE_ORIGIN_ROLLOUT = "rollout"
FAILURE_ORIGIN_UNKNOWN = "unknown"
# NOTE: the definitions below are methods of ``Scheduler`` (the class header is
# outside this view); indent them to class-body level when splicing them in.

def _log_group_created(self, group_id: int, example: dict) -> None:
    """Log one INFO line that ties a new rollout group to its source example.

    The line is a " | "-separated list of ``key=value`` fields: the current
    step, the group id, the example id, the optional instance id and the
    environment name.

    Args:
        group_id: Identifier of the group that was just created.
        example: Example the group was created from. ``example_id`` and
            ``env_name`` are read with ``.get`` and logged as-is (``None``
            when absent); ``info["instance_id"]`` is logged only when
            ``info`` is a dict and the value is truthy.
    """
    info = example.get("info") or {}
    instance_id = info.get("instance_id") if isinstance(info, dict) else None
    parts = [
        "Created rollout group",
        f"step={self.step}",
        f"group_id={group_id}",
        f"example_id={example.get('example_id')}",
    ]
    if instance_id:
        parts.append(f"instance_id={instance_id}")
    parts.append(f"env_name={example.get('env_name')}")
    self.logger.info(" | ".join(parts))


@staticmethod
def _legacy_error_text(rollout: vf.RolloutOutput) -> str:
    """Return the best available textual description of ``rollout["error"]``.

    Args:
        rollout: Rollout output; only its ``"error"`` entry is read.

    Returns:
        For a dict error, the first non-empty value among
        ``error_chain_str``, ``error_chain_repr`` and ``error``, converted
        with ``str``. For any other error value, ``str(error)``. An empty
        string when the error is missing or falsy.
    """
    error = rollout.get("error")
    if not isinstance(error, dict):
        return str(error or "")
    return str(error.get("error_chain_str") or error.get("error_chain_repr") or error.get("error") or "")


@classmethod
def _classify_legacy_rollout_error(cls, rollout: vf.RolloutOutput) -> tuple[str, str]:
    """Derive ``(reason, origin)`` from error text by substring matching.

    Used for rollouts that carry an ``error`` but no structured ``failure``
    payload. Rules are evaluated top to bottom and the first match wins, so
    specific markers must stay ahead of generic ones (for example
    ``SandboxSetupError`` before ``SandboxError``).

    Args:
        rollout: Rollout output whose error text is classified.

    Returns:
        The ``(reason, origin)`` pair of the first matching rule, or
        ``(FAILURE_REASON_UNKNOWN, FAILURE_ORIGIN_UNKNOWN)`` when no marker
        occurs in the text.
    """
    text = cls._legacy_error_text(rollout)
    # (markers, reason, origin) -- order is the match priority.
    rules = (
        (("AgentPollError", "Agent polling failed"), FAILURE_REASON_AGENT_POLL_FAILED, FAILURE_ORIGIN_AGENT),
        (("AgentError",), FAILURE_REASON_AGENT_NONZERO_EXIT, FAILURE_ORIGIN_AGENT),
        (("StreamInterrupted",), FAILURE_REASON_STREAM_INTERRUPTED, FAILURE_ORIGIN_TUNNEL),
        (("TunnelError",), FAILURE_REASON_TUNNEL_ERROR, FAILURE_ORIGIN_TUNNEL),
        (("SandboxOOM",), FAILURE_REASON_SANDBOX_OOM, FAILURE_ORIGIN_SANDBOX),
        (("SandboxTimeout",), FAILURE_REASON_SANDBOX_TIMEOUT, FAILURE_ORIGIN_SANDBOX),
        (("SandboxSetupError",), FAILURE_REASON_SANDBOX_SETUP_FAILED, FAILURE_ORIGIN_SANDBOX),
        (("SandboxError",), FAILURE_REASON_SANDBOX_COMMAND_FAILED, FAILURE_ORIGIN_SANDBOX),
        (
            ("ModelError", "InvalidModelResponseError", "EmptyModelResponseError"),
            FAILURE_REASON_MODEL_ERROR,
            FAILURE_ORIGIN_MODEL,
        ),
        (("timeout_reached",), FAILURE_REASON_ROLLOUT_TIMEOUT, FAILURE_ORIGIN_ROLLOUT),
    )
    for markers, reason, origin in rules:
        if any(marker in text for marker in markers):
            return reason, origin
    return FAILURE_REASON_UNKNOWN, FAILURE_ORIGIN_UNKNOWN


@classmethod
def _classify_rejected_rollout(cls, rollout: vf.RolloutOutput) -> tuple[str, str]:
    """Return the ``(reason, origin)`` pair for a rejected rollout.

    Resolution order:

    1. A structured ``failure`` dict whose ``reason`` and ``origin`` are
       both non-blank strings is trusted verbatim.
    2. Otherwise, if ``error`` is set, the error text is classified by
       ``_classify_legacy_rollout_error``.
    3. Otherwise the rollout is reported as an empty agent trajectory.

    Args:
        rollout: Rollout output; its ``failure`` and ``error`` entries are
            read.

    Returns:
        The resolved ``(reason, origin)`` pair.
    """
    failure = rollout.get("failure")
    if isinstance(failure, dict):
        reason = failure.get("reason")
        origin = failure.get("origin")
        # Blank values would yield metric keys such as "error_reason/<env>/",
        # so treat them as missing and fall through to the fallbacks below.
        if isinstance(reason, str) and isinstance(origin, str) and reason.strip() and origin.strip():
            return reason, origin
    if rollout.get("error") is not None:
        return cls._classify_legacy_rollout_error(rollout)
    return FAILURE_REASON_AGENT_EMPTY_TRAJECTORY, FAILURE_ORIGIN_AGENT


def _record_rejected_rollout(
    self,
    env_name: str,
    reason: str,
    origin: str,
    *,
    count: int = 1,
    compatibility_error: bool = True,
) -> None:
    """Add rejected rollouts to the per-environment failure counters.

    Args:
        env_name: Environment the rollouts belong to.
        reason: Failure reason key to increment.
        origin: Failure origin key to increment.
        count: Number of rollouts to record.
        compatibility_error: When true, also increment
            ``errored_rollouts_by_env`` so the pre-existing errored-rollout
            counter keeps including this failure. Callers pass ``False``
            for failures that are tracked by a different legacy counter.
    """
    self.rejected_rollouts_by_env[env_name] += count
    self.error_reasons_by_env[env_name][reason] += count
    self.error_origins_by_env[env_name][origin] += count
    if compatibility_error:
        self.errored_rollouts_by_env[env_name] += count


def _record_env_server_error(self, env_name: str, count: int = 1) -> None:
    """Record rollouts lost because the env-server request failed.

    The rollouts are added to ``total_rollouts_by_env`` here as well as to
    the rejection counters, so they contribute to the denominator of the
    per-environment rates.

    Args:
        env_name: Environment the failed request targeted.
        count: Number of rollouts carried by the failed request.
    """
    self.total_rollouts_by_env[env_name] += count
    self._record_rejected_rollout(
        env_name,
        FAILURE_REASON_ENV_SERVER_ERROR,
        FAILURE_ORIGIN_ENV_SERVER,
        count=count,
        compatibility_error=True,
    )
complete): " - f"{rollout['error']['error_chain_repr']}" + f"{error_text}" ) - elif len(rollout["trajectory"]) == 0: + elif len(rollout.get("trajectory") or []) == 0: self.empty_rollouts_by_env[env_name] += 1 + reason, origin = self._classify_rejected_rollout(rollout) + self._record_rejected_rollout( + env_name, + reason, + origin, + compatibility_error=False, + ) has_failures = True self.logger.warning( f"Empty trajectory in group {group_id} ({env_name}), re-scheduling " @@ -472,7 +590,11 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: await self.drop_group(group_id) continue except Exception as e: - self.logger.warning(f"Rollout failed: {e}") + self._record_env_server_error(env_name, rollout_info.rollout_count) + self.logger.warning( + f"Rollout failed in env server for group {group_id} ({env_name}), " + f"re-scheduling: {type(e).__name__}: {e}" + ) if group_id is not None: await self.drop_group(group_id) continue @@ -521,6 +643,7 @@ def async_level(self) -> int: def get_metrics(self) -> dict[str, float]: total_rollouts = sum(self.total_rollouts_by_env.values()) + total_rejected_rollouts = sum(self.rejected_rollouts_by_env.values()) metrics = { "time/wait_for_ckpt": self.wait_for_ckpt_time, "time/update_weights": self.update_weights_time, @@ -530,13 +653,31 @@ def get_metrics(self) -> dict[str, float]: "scheduler/cancelled_rollouts": self.cancelled_rollouts_count, "empty_rollouts/all": sum(self.empty_rollouts_by_env.values()) / max(total_rollouts, 1), "errored_rollouts/all": sum(self.errored_rollouts_by_env.values()) / max(total_rollouts, 1), + "error/all/mean": total_rejected_rollouts / max(total_rollouts, 1), "off_policy_level/all/max": self.max_off_policy_level, "off_policy_level/all/mean": self.mean_off_policy_level, } + all_reasons: Counter[str] = Counter() + all_origins: Counter[str] = Counter() + for env_name, reasons in self.error_reasons_by_env.items(): + all_reasons.update(reasons) + env_total = 
max(self.total_rollouts_by_env[env_name], 1) + for reason, count in reasons.items(): + metrics[f"error_reason/{env_name}/{reason}"] = count / env_total + for env_name, origins in self.error_origins_by_env.items(): + all_origins.update(origins) + env_total = max(self.total_rollouts_by_env[env_name], 1) + for origin, count in origins.items(): + metrics[f"error_origin/{env_name}/{origin}"] = count / env_total + for reason, count in all_reasons.items(): + metrics[f"error_reason/all/{reason}"] = count / max(total_rollouts, 1) + for origin, count in all_origins.items(): + metrics[f"error_origin/all/{origin}"] = count / max(total_rollouts, 1) for env_name in self.total_rollouts_by_env: env_total = max(self.total_rollouts_by_env[env_name], 1) metrics[f"empty_rollouts/{env_name}"] = self.empty_rollouts_by_env.get(env_name, 0) / env_total metrics[f"errored_rollouts/{env_name}"] = self.errored_rollouts_by_env.get(env_name, 0) / env_total + metrics[f"error/{env_name}/mean"] = self.rejected_rollouts_by_env.get(env_name, 0) / env_total by_env: dict[str, list[int]] = {} for info in self.inflight_requests.values(): by_env.setdefault(info.env_name, []).append(info.off_policy_steps) @@ -546,6 +687,9 @@ def get_metrics(self) -> dict[str, float]: self.cancelled_rollouts_count = 0 self.empty_rollouts_by_env.clear() self.errored_rollouts_by_env.clear() + self.rejected_rollouts_by_env.clear() + self.error_reasons_by_env.clear() + self.error_origins_by_env.clear() self.total_rollouts_by_env.clear() # Add inference pool metrics (e.g. 
def test_log_group_created_includes_example_mapping():
    """The group-creation log line names the step, group, example, instance and env."""
    sched = make_scheduler()
    sched.step = 42
    example = {
        "example_id": 2332,
        "env_name": "opencode-swe",
        "info": {"instance_id": "brazilian-utils__brutils-python-126"},
    }

    sched._log_group_created(group_id=7, example=example)

    sched.logger.info.assert_called_once()
    logged = sched.logger.info.call_args.args[0]
    expected_fragments = (
        "step=42",
        "group_id=7",
        "example_id=2332",
        "instance_id=brazilian-utils__brutils-python-126",
        "env_name=opencode-swe",
    )
    for fragment in expected_fragments:
        assert fragment in logged


def test_rejected_rollout_structured_failure_metrics():
    """A structured ``failure`` payload drives every error metric for its env."""
    sched = make_scheduler()
    sched.total_rollouts_by_env["opencode-swe"] = 2
    failure_payload = {
        "reason": "agent_poll_failed",
        "origin": "agent",
        "error_type": "AgentPollError",
        "root_error_type": "AgentPollError",
        "message": "read failed",
        "logs": {},
    }
    rollout = {
        "error": {"error_chain_repr": "AgentPollError('read failed')"},
        "failure": failure_payload,
        "trajectory": [],
    }

    classified = sched._classify_rejected_rollout(rollout)
    sched._record_rejected_rollout("opencode-swe", *classified)
    metrics = sched.get_metrics()

    # One rejected rollout out of two scheduled: every rate is one half.
    half_rate_keys = (
        "errored_rollouts/opencode-swe",
        "errored_rollouts/all",
        "error/opencode-swe/mean",
        "error/all/mean",
        "error_reason/opencode-swe/agent_poll_failed",
        "error_reason/all/agent_poll_failed",
        "error_origin/opencode-swe/agent",
        "error_origin/all/agent",
    )
    for key in half_rate_keys:
        assert metrics[key] == 0.5


def test_legacy_rollout_error_fallback_metrics():
    """Without a ``failure`` payload the error text decides reason and origin."""
    sched = make_scheduler()
    sched.total_rollouts_by_env["legacy-env"] = 1
    legacy_error = {
        "error_chain_str": "SandboxSetupError -> CommandTimeoutError",
        "error_chain_repr": "SandboxSetupError('setup failed')",
    }
    rollout = {"error": legacy_error, "trajectory": []}

    classified = sched._classify_rejected_rollout(rollout)
    sched._record_rejected_rollout("legacy-env", *classified)
    metrics = sched.get_metrics()

    assert metrics["error_reason/legacy-env/sandbox_setup_failed"] == 1.0
    assert metrics["error_origin/legacy-env/sandbox"] == 1.0


def test_empty_trajectory_records_rejected_error_but_not_compat_error_metric():
    """An empty trajectory counts as rejected but not as a legacy errored rollout."""
    sched = make_scheduler()
    sched.total_rollouts_by_env["opencode-swe"] = 1
    rollout = {"error": None, "trajectory": []}

    classified = sched._classify_rejected_rollout(rollout)
    sched.empty_rollouts_by_env["opencode-swe"] += 1
    sched._record_rejected_rollout("opencode-swe", *classified, compatibility_error=False)
    metrics = sched.get_metrics()

    expected = {
        "empty_rollouts/opencode-swe": 1.0,
        "errored_rollouts/opencode-swe": 0.0,
        "error/opencode-swe/mean": 1.0,
        "error_reason/opencode-swe/agent_empty_trajectory": 1.0,
    }
    for key, value in expected.items():
        assert metrics[key] == value


def test_env_server_exception_records_env_server_error_metrics():
    """An env-server failure is counted in both the total and the error metrics."""
    sched = make_scheduler()

    sched._record_env_server_error("opencode-swe", count=2)
    metrics = sched.get_metrics()

    full_rate_keys = (
        "errored_rollouts/opencode-swe",
        "error/opencode-swe/mean",
        "error_reason/opencode-swe/env_server_error",
        "error_origin/opencode-swe/env_server",
    )
    for key in full_rate_keys:
        assert metrics[key] == 1.0