Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 149 additions & 5 deletions src/prime_rl/orchestrator/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,28 @@
wait_for_path,
)

FAILURE_REASON_AGENT_NONZERO_EXIT = "agent_nonzero_exit"
FAILURE_REASON_AGENT_POLL_FAILED = "agent_poll_failed"
FAILURE_REASON_AGENT_EMPTY_TRAJECTORY = "agent_empty_trajectory"
FAILURE_REASON_ROLLOUT_TIMEOUT = "rollout_timeout"
FAILURE_REASON_SANDBOX_OOM = "sandbox_oom"
FAILURE_REASON_SANDBOX_TIMEOUT = "sandbox_timeout"
FAILURE_REASON_SANDBOX_COMMAND_FAILED = "sandbox_command_failed"
FAILURE_REASON_SANDBOX_SETUP_FAILED = "sandbox_setup_failed"
FAILURE_REASON_TUNNEL_ERROR = "tunnel_error"
FAILURE_REASON_STREAM_INTERRUPTED = "stream_interrupted"
FAILURE_REASON_MODEL_ERROR = "model_error"
FAILURE_REASON_ENV_SERVER_ERROR = "env_server_error"
FAILURE_REASON_UNKNOWN = "unknown"

FAILURE_ORIGIN_AGENT = "agent"
FAILURE_ORIGIN_SANDBOX = "sandbox"
FAILURE_ORIGIN_TUNNEL = "tunnel"
FAILURE_ORIGIN_MODEL = "model"
FAILURE_ORIGIN_ENV_SERVER = "env_server"
FAILURE_ORIGIN_ROLLOUT = "rollout"
FAILURE_ORIGIN_UNKNOWN = "unknown"


@dataclass
class InflightRequest:
Expand Down Expand Up @@ -113,6 +135,9 @@ def __init__(
self.cancelled_rollouts_count = 0
self.empty_rollouts_by_env: dict[str, int] = defaultdict(int)
self.errored_rollouts_by_env: dict[str, int] = defaultdict(int)
self.rejected_rollouts_by_env: dict[str, int] = defaultdict(int)
self.error_reasons_by_env: dict[str, Counter[str]] = defaultdict(Counter)
self.error_origins_by_env: dict[str, Counter[str]] = defaultdict(Counter)
self.total_rollouts_by_env: dict[str, int] = defaultdict(int)
self.last_batch_generation_time = 0.0

Expand Down Expand Up @@ -262,9 +287,93 @@ async def _schedule_next_request(self) -> bool:
group_id = self.next_group_id
self.next_group_id += 1
self.groups[group_id] = GroupState(example=example, rollouts_to_schedule=self.rollouts_per_example)
self._log_group_created(group_id=group_id, example=example)
await self.schedule_rollout(group_id=group_id)
return True

def _log_group_created(self, group_id: int, example: dict) -> None:
info = example.get("info") or {}
instance_id = info.get("instance_id") if isinstance(info, dict) else None
parts = [
"Created rollout group",
f"step={self.step}",
f"group_id={group_id}",
f"example_id={example.get('example_id')}",
]
if instance_id:
parts.append(f"instance_id={instance_id}")
parts.append(f"env_name={example.get('env_name')}")
self.logger.info(" | ".join(parts))

@staticmethod
def _legacy_error_text(rollout: vf.RolloutOutput) -> str:
error = rollout.get("error")
if not isinstance(error, dict):
return str(error or "")
return str(error.get("error_chain_str") or error.get("error_chain_repr") or error.get("error") or "")

@classmethod
def _classify_legacy_rollout_error(cls, rollout: vf.RolloutOutput) -> tuple[str, str]:
text = cls._legacy_error_text(rollout)
if "AgentPollError" in text or "Agent polling failed" in text:
return FAILURE_REASON_AGENT_POLL_FAILED, FAILURE_ORIGIN_AGENT
if "AgentError" in text:
return FAILURE_REASON_AGENT_NONZERO_EXIT, FAILURE_ORIGIN_AGENT
if "StreamInterrupted" in text:
return FAILURE_REASON_STREAM_INTERRUPTED, FAILURE_ORIGIN_TUNNEL
if "TunnelError" in text:
return FAILURE_REASON_TUNNEL_ERROR, FAILURE_ORIGIN_TUNNEL
if "SandboxOOM" in text:
return FAILURE_REASON_SANDBOX_OOM, FAILURE_ORIGIN_SANDBOX
if "SandboxTimeout" in text:
return FAILURE_REASON_SANDBOX_TIMEOUT, FAILURE_ORIGIN_SANDBOX
if "SandboxSetupError" in text:
return FAILURE_REASON_SANDBOX_SETUP_FAILED, FAILURE_ORIGIN_SANDBOX
if "SandboxError" in text:
return FAILURE_REASON_SANDBOX_COMMAND_FAILED, FAILURE_ORIGIN_SANDBOX
if "ModelError" in text or "InvalidModelResponseError" in text or "EmptyModelResponseError" in text:
return FAILURE_REASON_MODEL_ERROR, FAILURE_ORIGIN_MODEL
if "timeout_reached" in text:
return FAILURE_REASON_ROLLOUT_TIMEOUT, FAILURE_ORIGIN_ROLLOUT
return FAILURE_REASON_UNKNOWN, FAILURE_ORIGIN_UNKNOWN

@classmethod
def _classify_rejected_rollout(cls, rollout: vf.RolloutOutput) -> tuple[str, str]:
failure = rollout.get("failure")
if isinstance(failure, dict):
reason = failure.get("reason")
origin = failure.get("origin")
if isinstance(reason, str) and isinstance(origin, str):
return reason, origin
if rollout.get("error") is not None:
return cls._classify_legacy_rollout_error(rollout)
return FAILURE_REASON_AGENT_EMPTY_TRAJECTORY, FAILURE_ORIGIN_AGENT

def _record_rejected_rollout(
self,
env_name: str,
reason: str,
origin: str,
*,
count: int = 1,
compatibility_error: bool = True,
) -> None:
self.rejected_rollouts_by_env[env_name] += count
self.error_reasons_by_env[env_name][reason] += count
self.error_origins_by_env[env_name][origin] += count
if compatibility_error:
self.errored_rollouts_by_env[env_name] += count

def _record_env_server_error(self, env_name: str, count: int = 1) -> None:
self.total_rollouts_by_env[env_name] += count
self._record_rejected_rollout(
env_name,
FAILURE_REASON_ENV_SERVER_ERROR,
FAILURE_ORIGIN_ENV_SERVER,
count=count,
compatibility_error=True,
)

async def _fill_inflight_requests(self) -> None:
while await self._schedule_next_request():
pass
Expand Down Expand Up @@ -435,16 +544,25 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
valid_rollouts = []
has_failures = False
for rollout in rollouts:
if rollout["error"] is not None:
self.errored_rollouts_by_env[env_name] += 1
if rollout.get("error") is not None:
reason, origin = self._classify_rejected_rollout(rollout)
self._record_rejected_rollout(env_name, reason, origin)
has_failures = True
error_text = self._legacy_error_text(rollout)
self.logger.warning(
f"Rollout error in group {group_id} ({env_name}), re-scheduling "
f"({len(group.completed_rollouts)}/{self.rollouts_per_example} complete): "
f"{rollout['error']['error_chain_repr']}"
f"{error_text}"
)
elif len(rollout["trajectory"]) == 0:
elif len(rollout.get("trajectory") or []) == 0:
self.empty_rollouts_by_env[env_name] += 1
reason, origin = self._classify_rejected_rollout(rollout)
self._record_rejected_rollout(
env_name,
reason,
origin,
compatibility_error=False,
)
has_failures = True
self.logger.warning(
f"Empty trajectory in group {group_id} ({env_name}), re-scheduling "
Expand Down Expand Up @@ -472,7 +590,11 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
await self.drop_group(group_id)
continue
except Exception as e:
self.logger.warning(f"Rollout failed: {e}")
self._record_env_server_error(env_name, rollout_info.rollout_count)
self.logger.warning(
f"Rollout failed in env server for group {group_id} ({env_name}), "
f"re-scheduling: {type(e).__name__}: {e}"
)
if group_id is not None:
await self.drop_group(group_id)
continue
Expand Down Expand Up @@ -521,6 +643,7 @@ def async_level(self) -> int:

def get_metrics(self) -> dict[str, float]:
total_rollouts = sum(self.total_rollouts_by_env.values())
total_rejected_rollouts = sum(self.rejected_rollouts_by_env.values())
metrics = {
"time/wait_for_ckpt": self.wait_for_ckpt_time,
"time/update_weights": self.update_weights_time,
Expand All @@ -530,13 +653,31 @@ def get_metrics(self) -> dict[str, float]:
"scheduler/cancelled_rollouts": self.cancelled_rollouts_count,
"empty_rollouts/all": sum(self.empty_rollouts_by_env.values()) / max(total_rollouts, 1),
"errored_rollouts/all": sum(self.errored_rollouts_by_env.values()) / max(total_rollouts, 1),
"error/all/mean": total_rejected_rollouts / max(total_rollouts, 1),
"off_policy_level/all/max": self.max_off_policy_level,
"off_policy_level/all/mean": self.mean_off_policy_level,
}
all_reasons: Counter[str] = Counter()
all_origins: Counter[str] = Counter()
for env_name, reasons in self.error_reasons_by_env.items():
all_reasons.update(reasons)
env_total = max(self.total_rollouts_by_env[env_name], 1)
for reason, count in reasons.items():
metrics[f"error_reason/{env_name}/{reason}"] = count / env_total
for env_name, origins in self.error_origins_by_env.items():
all_origins.update(origins)
env_total = max(self.total_rollouts_by_env[env_name], 1)
for origin, count in origins.items():
metrics[f"error_origin/{env_name}/{origin}"] = count / env_total
for reason, count in all_reasons.items():
metrics[f"error_reason/all/{reason}"] = count / max(total_rollouts, 1)
for origin, count in all_origins.items():
metrics[f"error_origin/all/{origin}"] = count / max(total_rollouts, 1)
for env_name in self.total_rollouts_by_env:
env_total = max(self.total_rollouts_by_env[env_name], 1)
metrics[f"empty_rollouts/{env_name}"] = self.empty_rollouts_by_env.get(env_name, 0) / env_total
metrics[f"errored_rollouts/{env_name}"] = self.errored_rollouts_by_env.get(env_name, 0) / env_total
metrics[f"error/{env_name}/mean"] = self.rejected_rollouts_by_env.get(env_name, 0) / env_total
by_env: dict[str, list[int]] = {}
for info in self.inflight_requests.values():
by_env.setdefault(info.env_name, []).append(info.off_policy_steps)
Expand All @@ -546,6 +687,9 @@ def get_metrics(self) -> dict[str, float]:
self.cancelled_rollouts_count = 0
self.empty_rollouts_by_env.clear()
self.errored_rollouts_by_env.clear()
self.rejected_rollouts_by_env.clear()
self.error_reasons_by_env.clear()
self.error_origins_by_env.clear()
self.total_rollouts_by_env.clear()

# Add inference pool metrics (e.g. elastic pool server counts)
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/orchestrator/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from collections import Counter, defaultdict
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
Expand All @@ -21,10 +22,17 @@ def make_scheduler() -> Scheduler:
scheduler.checkpoint_ready.set()
scheduler.lora_name = None
scheduler.model_name = "test-model"
scheduler.inference_pool = SimpleNamespace(get_metrics=lambda: {})
scheduler.update_weights_time = 0
scheduler.wait_for_ckpt_time = 0
scheduler.inflight_requests = {}
scheduler.groups = {}
scheduler.empty_rollouts_by_env = defaultdict(int)
scheduler.errored_rollouts_by_env = defaultdict(int)
scheduler.rejected_rollouts_by_env = defaultdict(int)
scheduler.error_reasons_by_env = defaultdict(Counter)
scheduler.error_origins_by_env = defaultdict(Counter)
scheduler.total_rollouts_by_env = defaultdict(int)
scheduler.max_off_policy_steps = 1
scheduler.cancelled_rollouts_count = 0
scheduler.policy_update_lock = asyncio.Lock()
Expand Down Expand Up @@ -174,3 +182,102 @@ def test_client_identity_distinguishes_base_url_and_dp_rank():
)

assert Scheduler._client_identity(client_a) != Scheduler._client_identity(client_b)


def test_log_group_created_includes_example_mapping():
scheduler = make_scheduler()
scheduler.step = 42

scheduler._log_group_created(
group_id=7,
example={
"example_id": 2332,
"env_name": "opencode-swe",
"info": {"instance_id": "brazilian-utils__brutils-python-126"},
},
)

scheduler.logger.info.assert_called_once()
message = scheduler.logger.info.call_args.args[0]
assert "step=42" in message
assert "group_id=7" in message
assert "example_id=2332" in message
assert "instance_id=brazilian-utils__brutils-python-126" in message
assert "env_name=opencode-swe" in message


def test_rejected_rollout_structured_failure_metrics():
scheduler = make_scheduler()
scheduler.total_rollouts_by_env["opencode-swe"] = 2
rollout = {
"error": {"error_chain_repr": "AgentPollError('read failed')"},
"failure": {
"reason": "agent_poll_failed",
"origin": "agent",
"error_type": "AgentPollError",
"root_error_type": "AgentPollError",
"message": "read failed",
"logs": {},
},
"trajectory": [],
}
reason, origin = scheduler._classify_rejected_rollout(rollout)

scheduler._record_rejected_rollout("opencode-swe", reason, origin)
metrics = scheduler.get_metrics()

assert metrics["errored_rollouts/opencode-swe"] == 0.5
assert metrics["errored_rollouts/all"] == 0.5
assert metrics["error/opencode-swe/mean"] == 0.5
assert metrics["error/all/mean"] == 0.5
assert metrics["error_reason/opencode-swe/agent_poll_failed"] == 0.5
assert metrics["error_reason/all/agent_poll_failed"] == 0.5
assert metrics["error_origin/opencode-swe/agent"] == 0.5
assert metrics["error_origin/all/agent"] == 0.5


def test_legacy_rollout_error_fallback_metrics():
scheduler = make_scheduler()
scheduler.total_rollouts_by_env["legacy-env"] = 1
rollout = {
"error": {
"error_chain_str": "SandboxSetupError -> CommandTimeoutError",
"error_chain_repr": "SandboxSetupError('setup failed')",
},
"trajectory": [],
}

reason, origin = scheduler._classify_rejected_rollout(rollout)
scheduler._record_rejected_rollout("legacy-env", reason, origin)
metrics = scheduler.get_metrics()

assert metrics["error_reason/legacy-env/sandbox_setup_failed"] == 1.0
assert metrics["error_origin/legacy-env/sandbox"] == 1.0


def test_empty_trajectory_records_rejected_error_but_not_compat_error_metric():
scheduler = make_scheduler()
scheduler.total_rollouts_by_env["opencode-swe"] = 1
rollout = {"error": None, "trajectory": []}

reason, origin = scheduler._classify_rejected_rollout(rollout)
scheduler.empty_rollouts_by_env["opencode-swe"] += 1
scheduler._record_rejected_rollout("opencode-swe", reason, origin, compatibility_error=False)
metrics = scheduler.get_metrics()

assert metrics["empty_rollouts/opencode-swe"] == 1.0
assert metrics["errored_rollouts/opencode-swe"] == 0.0
assert metrics["error/opencode-swe/mean"] == 1.0
assert metrics["error_reason/opencode-swe/agent_empty_trajectory"] == 1.0


def test_env_server_exception_records_env_server_error_metrics():
scheduler = make_scheduler()

scheduler._record_env_server_error("opencode-swe", count=2)
metrics = scheduler.get_metrics()

assert metrics["errored_rollouts/opencode-swe"] == 1.0
assert metrics["error/opencode-swe/mean"] == 1.0
assert metrics["error_reason/opencode-swe/env_server_error"] == 1.0
assert metrics["error_origin/opencode-swe/env_server"] == 1.0
Loading