304 changes: 214 additions & 90 deletions src/prime_rl/utils/monitor/prime.py
@@ -4,6 +4,7 @@
import math
import os
import time
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from threading import Thread
@@ -14,13 +15,49 @@
import pyarrow.parquet as pq
import verifiers as vf
from prime_cli.core.config import Config as PrimeConfig
from tenacity import (
AsyncRetrying,
retry_if_exception,
stop_after_attempt,
stop_after_delay,
wait_exponential,
wait_random,
)
from transformers.tokenization_utils import PreTrainedTokenizer

from prime_rl.configs.shared import PrimeMonitorConfig
from prime_rl.utils.config import BaseConfig
from prime_rl.utils.logger import get_logger
from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging

_RETRYABLE_HTTPX_EXC = (httpx.TransportError, httpx.RemoteProtocolError)
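# Note: httpx.RemoteProtocolError already subclasses httpx.TransportError, so
# the first entry alone would match; it is listed explicitly for readability.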


def _is_retryable_upload_error(exc: BaseException) -> bool:
if isinstance(exc, _RETRYABLE_HTTPX_EXC):
return True
if isinstance(exc, httpx.HTTPStatusError):
code = exc.response.status_code
return code in (408, 429) or 500 <= code < 600
return False
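# Illustrative classification (not exhaustive): a dropped connection or read
# timeout (both httpx.TransportError subclasses) and 408/429/5xx responses are
# retried; a 403/404 (e.g. an expired presigned URL or unknown run) is treated
# as permanent, so the step is dropped rather than blocking the queue head.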


# Per-HTTP-call retry budget. The wall-clock cap dominates, so a single call
# cannot stall the upload pipeline indefinitely.
_UPLOAD_MAX_ATTEMPTS = 6
_UPLOAD_MAX_DELAY_S = 120.0
_UPLOAD_BACKOFF_MAX_S = 30.0

# Bounded backlog of (step, parquet_bytes) waiting to be uploaded. Sized to
# survive a few intervals of R2 unavailability without unbounded memory growth.
_MAX_PENDING_SAMPLE_UPLOADS = 5
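# Back-of-envelope sizing (hypothetical numbers): at, say, 10 MB of Snappy-
# compressed Parquet per step, a full backlog holds roughly 50 MB in memory.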

# After a retryable upload failure (R2/API outage), suppress further drain
# attempts for this long so coroutines queued during the failed drain don't
# each consume another full tenacity retry budget back-to-back. The next
# log_samples tick after this expires picks up the backlog.
_RETRYABLE_COOLDOWN_S = 60.0
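# Worked example of the budget above (tenacity stops when either condition
# trips): backoff sleeps are ~1+2+4+8+16 s plus jitter, about 31 s across six
# attempts, so with fast failures the attempt cap trips first. With slow calls
# (the shared httpx client uses a 30 s timeout), six attempts could take
# ~210 s, so stop_after_delay(_UPLOAD_MAX_DELAY_S) instead cuts the call off
# around the two-minute mark.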


def _json(val: Any) -> str:
"""JSON-serialize dicts/lists, pass strings through, default to empty string for None."""
@@ -287,7 +324,12 @@ def log(self, metrics: dict[str, Any], step: int) -> None:
)

def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None:
"""Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload."""
"""Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload.

Adds the new step to a bounded backlog and schedules a drain on the background
event loop. Failed uploads stay in the backlog and are retried on the next
log_samples call, so a transient R2 outage does not silently drop data.
"""
if not self.is_master:
return
if not self.enabled:
Expand Down Expand Up @@ -321,12 +363,11 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None:
return

self._pending_sample_steps.add(step)

# Use presigned URL flow for uploading samples
self._upload_samples_via_presigned_url(parquet_bytes, step)
self._enqueue_and_drain_samples(step, parquet_bytes)

self.logger.debug(
f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s"
f"Queued samples upload for step {step} in {time.perf_counter() - start_time:.2f}s "
f"(backlog={len(self._sample_upload_queue) + 1})"
)

def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None:
@@ -392,105 +433,161 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int
pq.write_table(table, buf, compression="snappy", use_dictionary=True, write_statistics=True)
return buf.getvalue()

def _upload_samples_via_presigned_url(self, parquet_bytes: bytes, step: int) -> None:
"""Upload Parquet samples using presigned URL flow (fire-and-forget)."""
def _enqueue_and_drain_samples(self, step: int, parquet_bytes: bytes) -> None:
"""Append a step's parquet to the backlog and trigger a drain on the bg loop."""
future = asyncio.run_coroutine_threadsafe(
self._upload_samples_via_presigned_url_async(parquet_bytes, step),
self._enqueue_and_drain_samples_async(step, parquet_bytes),
self._loop,
)
self._pending_futures.append(future)
# Clean up completed futures to avoid memory growth
self._pending_futures = [f for f in self._pending_futures if not f.done()]

async def _upload_samples_via_presigned_url_async(
self, parquet_bytes: bytes, step: int, max_retries: int = 3
) -> None:
"""Upload Parquet bytes via presigned URL flow."""
try:
presign_data = await self._request_presigned_url(step)
if not presign_data:
self.logger.warning(f"Failed to get presigned URL for samples at step {step}")
return

presigned_url = presign_data["presigned_url"]
s3_key = presign_data["s3_key"]
async def _enqueue_and_drain_samples_async(self, step: int, parquet_bytes: bytes) -> None:
"""Backlog-aware upload: enqueue, then drain oldest-first under a lock."""
if self._sample_upload_lock is None:
self._sample_upload_lock = asyncio.Lock()

upload_success = await self._upload_to_r2(
presigned_url, parquet_bytes, content_type="application/parquet", max_retries=max_retries
self._sample_upload_queue.append((step, parquet_bytes))
while len(self._sample_upload_queue) > _MAX_PENDING_SAMPLE_UPLOADS:
dropped_step, _ = self._sample_upload_queue.popleft()
self._pending_sample_steps.discard(dropped_step)
self.logger.warning(
f"Sample upload backlog exceeded {_MAX_PENDING_SAMPLE_UPLOADS}, "
f"dropping oldest queued step {dropped_step}"
)
if not upload_success:
self.logger.warning(f"Failed to upload samples to R2 at step {step}")
return

confirm_success = await self._confirm_samples_upload(step, s3_key)
if not confirm_success:
self.logger.warning(f"Failed to confirm samples upload at step {step}")
return
await self._drain_sample_backlog()

async def _drain_sample_backlog(self, *, ignore_cooldown: bool = False) -> None:
"""Drain the sample backlog under the upload lock. Used by the regular
log_samples path and by close()/_flush() for a final shutdown attempt.
With ignore_cooldown=True, retries the head even if the retryable cooldown
is currently armed — last-chance attempt before the loop stops.
"""
if self._sample_upload_lock is None:
self._sample_upload_lock = asyncio.Lock()
if ignore_cooldown:
self._retryable_cooldown_until = None

async with self._sample_upload_lock:
while self._sample_upload_queue:
# If a previous drain just hit a retryable failure, every coroutine
# queued on this lock would otherwise pop the same head and burn
# another full tenacity budget. Bail out and let a later tick (after
# the cooldown elapses) pick up the backlog.
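# Illustrative timeline under the defaults above: R2 goes down at t=0;
# the drain holding step N burns its tenacity budget, fails retryably,
# and arms the cooldown until t~60 s. Drains queued behind it return
# immediately, and the first log_samples tick after t~60 s retries
# step N, then everything behind it, oldest-first.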
if (
not ignore_cooldown
and self._retryable_cooldown_until is not None
and time.monotonic() < self._retryable_cooldown_until
):
return

# Pop the in-flight item out of the queue entirely before awaiting,
# so concurrent eviction (which runs without the lock) cannot remove
# the head we are uploading. On retryable failure we appendleft to
# preserve ordering; on success or permanent failure the item simply
# stays gone.
pending_step, pending_bytes = self._sample_upload_queue.popleft()
start = time.perf_counter()
try:
await self._upload_one_sample_step(pending_step, pending_bytes)
except Exception as e:
if _is_retryable_upload_error(e):
# Transient — keep at the head and arm the cooldown so other
# waiters don't immediately retry the same head.
self._sample_upload_queue.appendleft((pending_step, pending_bytes))
self._retryable_cooldown_until = time.monotonic() + _RETRYABLE_COOLDOWN_S
self.logger.opt(exception=True).warning(
f"Sample upload for step {pending_step} failed after retries; "
f"keeping in backlog (size={len(self._sample_upload_queue)}, "
f"cooldown={_RETRYABLE_COOLDOWN_S:.0f}s): "
f"{type(e).__name__}: {e}"
)
return

# Permanent (e.g. 4xx that's not 408/429): drop this step so it
# doesn't block the queue head, and continue draining the rest.
self._pending_sample_steps.discard(pending_step)
self.logger.opt(exception=True).warning(
f"Sample upload for step {pending_step} failed with non-retryable error; "
f"dropping step (queue size={len(self._sample_upload_queue)}): "
f"{type(e).__name__}: {e}"
)
continue

self.last_log_samples_step = step
self.logger.debug(f"Successfully completed samples upload at step {step}")
# Success: bookkeeping (item is already popped). Clear any stale
# cooldown — R2 is healthy.
self._retryable_cooldown_until = None
self._pending_sample_steps.discard(pending_step)
self.last_log_samples_step = pending_step
self.logger.debug(f"Uploaded samples for step {pending_step} in {time.perf_counter() - start:.2f}s")

except Exception as e:
self.logger.warning(f"Failed to upload samples via presigned URL at step {step}: {type(e).__name__}: {e}")
finally:
self._pending_sample_steps.discard(step)
async def _upload_one_sample_step(self, step: int, parquet_bytes: bytes) -> None:
"""Run presign → R2 PUT → confirm for a single step. Raises on final failure."""
presign_data = await self._request_presigned_url(step)
presigned_url = presign_data["presigned_url"]
s3_key = presign_data["s3_key"]

async def _request_presigned_url(self, step: int) -> dict[str, Any] | None:
"""Request a presigned URL from the backend."""
try:
response = await self._client.post(
f"{self.base_url}/samples/presign",
headers=self._headers,
json={"run_id": self.run_id, "step": step},
await self._upload_to_r2(presigned_url, parquet_bytes, step, content_type="application/parquet")
await self._confirm_samples_upload(step, s3_key)

def _retry_policy(self, op_name: str, step: int) -> AsyncRetrying:
"""Shared tenacity policy for sample-upload HTTP calls."""

def _log_retry(retry_state) -> None:
exc = retry_state.outcome.exception() if retry_state.outcome else None
self.logger.warning(
f"Retrying {op_name} for step {step} "
f"(attempt {retry_state.attempt_number}/{_UPLOAD_MAX_ATTEMPTS}): "
f"{type(exc).__name__ if exc else 'unknown'}: {exc}"
)
response.raise_for_status()
response_data = response.json()["data"]
return {
"presigned_url": response_data["presignedUrl"],
"s3_key": response_data["s3Key"],
}
except Exception as e:
self.logger.warning(f"Failed to request presigned URL: {type(e).__name__}: {e}")
return None

return AsyncRetrying(
retry=retry_if_exception(_is_retryable_upload_error),
stop=stop_after_attempt(_UPLOAD_MAX_ATTEMPTS) | stop_after_delay(_UPLOAD_MAX_DELAY_S),
wait=wait_exponential(multiplier=1, min=1, max=_UPLOAD_BACKOFF_MAX_S) + wait_random(0, 1),
before_sleep=_log_retry,
reraise=True,
)

async def _request_presigned_url(self, step: int) -> dict[str, Any]:
"""Request a presigned URL from the backend. Raises on final failure."""
async for attempt in self._retry_policy("samples/presign", step):
with attempt:
response = await self._client.post(
f"{self.base_url}/samples/presign",
headers=self._headers,
json={"run_id": self.run_id, "step": step},
)
response.raise_for_status()
response_data = response.json()["data"]
return {
"presigned_url": response_data["presignedUrl"],
"s3_key": response_data["s3Key"],
}
raise RuntimeError("retry loop exited without returning")
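# Unreachable in practice: with reraise=True the retry loop either returns
# from inside a successful attempt or re-raises the final exception; the raise
# above exists to satisfy type checkers.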

async def _upload_to_r2(
self, presigned_url: str, data: bytes, content_type: str = "application/json", max_retries: int = 3
) -> bool:
"""Upload data to R2 using presigned URL."""
for attempt in range(max_retries):
try:
self, presigned_url: str, data: bytes, step: int, content_type: str = "application/json"
) -> None:
"""Upload data to R2 via presigned URL. Raises on final failure."""
async for attempt in self._retry_policy("R2 PUT", step):
with attempt:
response = await self._client.put(presigned_url, content=data, headers={"Content-Type": content_type})
response.raise_for_status()
return True
except Exception as e:
if attempt == max_retries - 1:
self.logger.warning(f"Failed to upload to R2 after {max_retries} attempts: {type(e).__name__}: {e}")
return False
delay = 2**attempt
self.logger.debug(f"Retrying R2 upload in {delay}s (attempt {attempt + 1}/{max_retries})")
await asyncio.sleep(delay)

async def _confirm_samples_upload(self, step: int, s3_key: str, max_retries: int = 3) -> bool:
"""Confirm samples upload with the backend. Returns True on success."""
for attempt in range(max_retries):
try:
return

async def _confirm_samples_upload(self, step: int, s3_key: str) -> None:
"""Confirm samples upload with the backend. Raises on final failure."""
async for attempt in self._retry_policy("samples/confirm", step):
with attempt:
response = await self._client.post(
f"{self.base_url}/samples/confirm",
headers=self._headers,
json={"run_id": self.run_id, "step": step, "s3_key": s3_key},
)
response.raise_for_status()
return True
except Exception as e:
if attempt == max_retries - 1:
self.logger.warning(
f"Failed to confirm samples upload after {max_retries} attempts: {type(e).__name__}: {e}"
)
return False
delay = 2**attempt
self.logger.debug(f"Retrying samples confirm in {delay}s (attempt {attempt + 1}/{max_retries})")
await asyncio.sleep(delay)
return False
return

def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None:
pass
Expand Down Expand Up @@ -615,6 +712,13 @@ def _init_async_client(self) -> None:
self._thread.start()
self._client = httpx.AsyncClient(timeout=30)
self._pending_futures: list[asyncio.Future] = []
# Sample-upload backlog. Lock is constructed lazily on the bg loop to bind
# to the right asyncio loop after fork.
self._sample_upload_queue: deque[tuple[int, bytes]] = deque()
self._sample_upload_lock: asyncio.Lock | None = None
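# (Added context, hedged: on Python versions before 3.10, asyncio primitives
# bind to an event loop at construction time, so building the Lock here, off
# the background loop, could raise a "different event loop" RuntimeError when
# first awaited on self._loop. Creating it inside the first coroutine binds it
# to the correct loop in all versions.)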
# Set to a future monotonic timestamp after a retryable failure so queued
# drains don't all reattempt the same failing head back-to-back.
self._retryable_cooldown_until: float | None = None
if hasattr(self, "_pending_sample_steps") and self._pending_sample_steps:
self._pending_sample_steps.clear()

@@ -628,21 +732,41 @@ def _run_event_loop(self) -> None:
self._loop.run_forever()

def _flush(self, timeout: float = 30.0) -> None:
"""Wait for all pending async requests to complete."""
"""Wait for all pending async requests to complete and make a last-chance
attempt to drain any sample uploads still queued in the backlog. Without
this final drain, a retryable failure on the last log_samples tick would
leave its parquet bytes parked in the queue and lost on close.
"""
if not self.enabled or not hasattr(self, "_loop"):
return

if not self._pending_futures:
return

self.logger.debug(f"Flushing {len(self._pending_futures)} pending request(s)")
for future in self._pending_futures:
if self._pending_futures:
self.logger.debug(f"Flushing {len(self._pending_futures)} pending request(s)")
for future in self._pending_futures:
try:
future.result(timeout=timeout)
except Exception as e:
self.logger.debug(f"Pending request completed with error: {e}")
self._pending_futures.clear()

# Final shutdown drain of the sample backlog. Bypass the cooldown — this
# is our last chance before the loop stops.
if hasattr(self, "_sample_upload_queue") and self._sample_upload_queue:
backlog_size = len(self._sample_upload_queue)
self.logger.info(f"Final sample-upload drain on close ({backlog_size} step(s) queued)")
try:
future = asyncio.run_coroutine_threadsafe(
self._drain_sample_backlog(ignore_cooldown=True),
self._loop,
)
future.result(timeout=timeout)
except Exception as e:
self.logger.debug(f"Pending request completed with error: {e}")

self._pending_futures.clear()
self.logger.warning(f"Final sample-upload drain failed: {type(e).__name__}: {e}")
remaining = len(self._sample_upload_queue)
if remaining:
self.logger.warning(
f"{remaining} sample upload(s) still in backlog at close — these samples will be lost"
)
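# Shutdown sketch (assumed caller behavior, not shown in this diff): close()
# is expected to invoke _flush() while self._loop is still running, so the
# final drain above executes on the live background loop; anything left in the
# backlog after the timeout is logged and dropped.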

async def _make_request_async(self, endpoint: str, data: dict[str, Any], max_retries: int = 3) -> None:
"""Make an async POST request to the Prime Intellect API with retries."""