diff --git a/src/prime_rl/utils/monitor/prime.py b/src/prime_rl/utils/monitor/prime.py
index 7e63b6c2d7..35d19352f9 100644
--- a/src/prime_rl/utils/monitor/prime.py
+++ b/src/prime_rl/utils/monitor/prime.py
@@ -4,6 +4,7 @@
 import math
 import os
 import time
+from collections import deque
 from datetime import datetime, timezone
 from pathlib import Path
 from threading import Thread
@@ -14,6 +15,14 @@
 import pyarrow.parquet as pq
 import verifiers as vf
 from prime_cli.core.config import Config as PrimeConfig
+from tenacity import (
+    AsyncRetrying,
+    retry_if_exception,
+    stop_after_attempt,
+    stop_after_delay,
+    wait_exponential,
+    wait_random,
+)
 from transformers.tokenization_utils import PreTrainedTokenizer
 
 from prime_rl.configs.shared import PrimeMonitorConfig
@@ -21,6 +30,34 @@
 from prime_rl.utils.logger import get_logger
 from prime_rl.utils.monitor.base import Monitor, sample_items_for_logging
 
+_RETRYABLE_HTTPX_EXC = (httpx.TransportError, httpx.RemoteProtocolError)
+
+
+def _is_retryable_upload_error(exc: BaseException) -> bool:
+    if isinstance(exc, _RETRYABLE_HTTPX_EXC):
+        return True
+    if isinstance(exc, httpx.HTTPStatusError):
+        code = exc.response.status_code
+        return code in (408, 429) or 500 <= code < 600
+    return False
+
+
+# Per-HTTP-call retry budget. Wall-clock cap dominates so a single call cannot
+# stall the upload pipeline indefinitely.
+_UPLOAD_MAX_ATTEMPTS = 6
+_UPLOAD_MAX_DELAY_S = 120.0
+_UPLOAD_BACKOFF_MAX_S = 30.0
+
+# Bounded backlog of (step, parquet_bytes) waiting to be uploaded. Sized to
+# survive a few intervals of R2 unavailability without unbounded memory growth.
+_MAX_PENDING_SAMPLE_UPLOADS = 5
+
+# After a retryable upload failure (R2/API outage), suppress further drain
+# attempts for this long so coroutines queued during the failed drain don't
+# each consume another full tenacity retry budget back-to-back. The next
+# log_samples tick after this expires picks up the backlog.
+_RETRYABLE_COOLDOWN_S = 60.0
+
 
 def _json(val: Any) -> str:
     """JSON-serialize dicts/lists, pass strings through, default to empty string for None."""
@@ -287,7 +324,12 @@ def log(self, metrics: dict[str, Any], step: int) -> None:
         )
 
     def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None:
-        """Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload."""
+        """Logs rollouts to Prime Intellect API using presigned URLs for direct R2 upload.
+
+        Adds the new step to a bounded backlog and schedules a drain on the background
+        event loop. Failed uploads stay in the backlog and are retried on the next
+        log_samples call, so a transient R2 outage does not silently drop data.
+ """ if not self.is_master: return if not self.enabled: @@ -321,12 +363,11 @@ def log_samples(self, rollouts: list[vf.RolloutOutput], step: int) -> None: return self._pending_sample_steps.add(step) - - # Use presigned URL flow for uploading samples - self._upload_samples_via_presigned_url(parquet_bytes, step) + self._enqueue_and_drain_samples(step, parquet_bytes) self.logger.debug( - f"Initiated samples upload at step {step} to Prime Intellect API in {time.perf_counter() - start_time:.2f}s" + f"Queued samples upload for step {step} in {time.perf_counter() - start_time:.2f}s " + f"(backlog={len(self._sample_upload_queue) + 1})" ) def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int) -> bytes | None: @@ -392,105 +433,161 @@ def _rollouts_to_parquet_bytes(self, rollouts: list[vf.RolloutOutput], step: int pq.write_table(table, buf, compression="snappy", use_dictionary=True, write_statistics=True) return buf.getvalue() - def _upload_samples_via_presigned_url(self, parquet_bytes: bytes, step: int) -> None: - """Upload Parquet samples using presigned URL flow (fire-and-forget).""" + def _enqueue_and_drain_samples(self, step: int, parquet_bytes: bytes) -> None: + """Append a step's parquet to the backlog and trigger a drain on the bg loop.""" future = asyncio.run_coroutine_threadsafe( - self._upload_samples_via_presigned_url_async(parquet_bytes, step), + self._enqueue_and_drain_samples_async(step, parquet_bytes), self._loop, ) self._pending_futures.append(future) - # Clean up completed futures to avoid memory growth self._pending_futures = [f for f in self._pending_futures if not f.done()] - async def _upload_samples_via_presigned_url_async( - self, parquet_bytes: bytes, step: int, max_retries: int = 3 - ) -> None: - """Upload Parquet bytes via presigned URL flow.""" - try: - presign_data = await self._request_presigned_url(step) - if not presign_data: - self.logger.warning(f"Failed to get presigned URL for samples at step {step}") - return - - presigned_url = presign_data["presigned_url"] - s3_key = presign_data["s3_key"] + async def _enqueue_and_drain_samples_async(self, step: int, parquet_bytes: bytes) -> None: + """Backlog-aware upload: enqueue, then drain oldest-first under a lock.""" + if self._sample_upload_lock is None: + self._sample_upload_lock = asyncio.Lock() - upload_success = await self._upload_to_r2( - presigned_url, parquet_bytes, content_type="application/parquet", max_retries=max_retries + self._sample_upload_queue.append((step, parquet_bytes)) + while len(self._sample_upload_queue) > _MAX_PENDING_SAMPLE_UPLOADS: + dropped_step, _ = self._sample_upload_queue.popleft() + self._pending_sample_steps.discard(dropped_step) + self.logger.warning( + f"Sample upload backlog exceeded {_MAX_PENDING_SAMPLE_UPLOADS}, " + f"dropping oldest queued step {dropped_step}" ) - if not upload_success: - self.logger.warning(f"Failed to upload samples to R2 at step {step}") - return - confirm_success = await self._confirm_samples_upload(step, s3_key) - if not confirm_success: - self.logger.warning(f"Failed to confirm samples upload at step {step}") - return + await self._drain_sample_backlog() + + async def _drain_sample_backlog(self, *, ignore_cooldown: bool = False) -> None: + """Drain the sample backlog under the upload lock. Used by the regular + log_samples path and by close()/_flush() for a final shutdown attempt. + With ignore_cooldown=True, retries the head even if the retryable cooldown + is currently armed — last-chance attempt before the loop stops. 
+ """ + if self._sample_upload_lock is None: + self._sample_upload_lock = asyncio.Lock() + if ignore_cooldown: + self._retryable_cooldown_until = None + + async with self._sample_upload_lock: + while self._sample_upload_queue: + # If a previous drain just hit a retryable failure, every coroutine + # queued on this lock would otherwise pop the same head and burn + # another full tenacity budget. Bail out and let a later tick (after + # the cooldown elapses) pick up the backlog. + if ( + not ignore_cooldown + and self._retryable_cooldown_until is not None + and time.monotonic() < self._retryable_cooldown_until + ): + return + + # Pop the in-flight item out of the queue entirely before awaiting, + # so concurrent eviction (which runs without the lock) cannot remove + # the head we are uploading. On retryable failure we appendleft to + # preserve ordering; on success or permanent failure the item simply + # stays gone. + pending_step, pending_bytes = self._sample_upload_queue.popleft() + start = time.perf_counter() + try: + await self._upload_one_sample_step(pending_step, pending_bytes) + except Exception as e: + if _is_retryable_upload_error(e): + # Transient — keep at the head and arm the cooldown so other + # waiters don't immediately retry the same head. + self._sample_upload_queue.appendleft((pending_step, pending_bytes)) + self._retryable_cooldown_until = time.monotonic() + _RETRYABLE_COOLDOWN_S + self.logger.opt(exception=True).warning( + f"Sample upload for step {pending_step} failed after retries; " + f"keeping in backlog (size={len(self._sample_upload_queue)}, " + f"cooldown={_RETRYABLE_COOLDOWN_S:.0f}s): " + f"{type(e).__name__}: {e}" + ) + return + + # Permanent (e.g. 4xx that's not 408/429): drop this step so it + # doesn't block the queue head, and continue draining the rest. + self._pending_sample_steps.discard(pending_step) + self.logger.opt(exception=True).warning( + f"Sample upload for step {pending_step} failed with non-retryable error; " + f"dropping step (queue size={len(self._sample_upload_queue)}): " + f"{type(e).__name__}: {e}" + ) + continue - self.last_log_samples_step = step - self.logger.debug(f"Successfully completed samples upload at step {step}") + # Success: bookkeeping (item is already popped). Clear any stale + # cooldown — R2 is healthy. + self._retryable_cooldown_until = None + self._pending_sample_steps.discard(pending_step) + self.last_log_samples_step = pending_step + self.logger.debug(f"Uploaded samples for step {pending_step} in {time.perf_counter() - start:.2f}s") - except Exception as e: - self.logger.warning(f"Failed to upload samples via presigned URL at step {step}: {type(e).__name__}: {e}") - finally: - self._pending_sample_steps.discard(step) + async def _upload_one_sample_step(self, step: int, parquet_bytes: bytes) -> None: + """Run presign → R2 PUT → confirm for a single step. 
+        presign_data = await self._request_presigned_url(step)
+        presigned_url = presign_data["presigned_url"]
+        s3_key = presign_data["s3_key"]
 
-    async def _request_presigned_url(self, step: int) -> dict[str, Any] | None:
-        """Request a presigned URL from the backend."""
-        try:
-            response = await self._client.post(
-                f"{self.base_url}/samples/presign",
-                headers=self._headers,
-                json={"run_id": self.run_id, "step": step},
+        await self._upload_to_r2(presigned_url, parquet_bytes, step, content_type="application/parquet")
+        await self._confirm_samples_upload(step, s3_key)
+
+    def _retry_policy(self, op_name: str, step: int) -> AsyncRetrying:
+        """Shared tenacity policy for sample-upload HTTP calls."""
+
+        def _log_retry(retry_state) -> None:
+            exc = retry_state.outcome.exception() if retry_state.outcome else None
+            self.logger.warning(
+                f"Retrying {op_name} for step {step} "
+                f"(attempt {retry_state.attempt_number}/{_UPLOAD_MAX_ATTEMPTS}): "
+                f"{type(exc).__name__ if exc else 'unknown'}: {exc}"
             )
-            response.raise_for_status()
-            response_data = response.json()["data"]
-            return {
-                "presigned_url": response_data["presignedUrl"],
-                "s3_key": response_data["s3Key"],
-            }
-        except Exception as e:
-            self.logger.warning(f"Failed to request presigned URL: {type(e).__name__}: {e}")
-            return None
+
+        return AsyncRetrying(
+            retry=retry_if_exception(_is_retryable_upload_error),
+            stop=stop_after_attempt(_UPLOAD_MAX_ATTEMPTS) | stop_after_delay(_UPLOAD_MAX_DELAY_S),
+            wait=wait_exponential(multiplier=1, min=1, max=_UPLOAD_BACKOFF_MAX_S) + wait_random(0, 1),
+            before_sleep=_log_retry,
+            reraise=True,
+        )
+
+    async def _request_presigned_url(self, step: int) -> dict[str, Any]:
+        """Request a presigned URL from the backend. Raises on final failure."""
+        async for attempt in self._retry_policy("samples/presign", step):
+            with attempt:
+                response = await self._client.post(
+                    f"{self.base_url}/samples/presign",
+                    headers=self._headers,
+                    json={"run_id": self.run_id, "step": step},
+                )
+                response.raise_for_status()
+                response_data = response.json()["data"]
+                return {
+                    "presigned_url": response_data["presignedUrl"],
+                    "s3_key": response_data["s3Key"],
+                }
+        raise RuntimeError("retry loop exited without returning")
 
     async def _upload_to_r2(
-        self, presigned_url: str, data: bytes, content_type: str = "application/json", max_retries: int = 3
-    ) -> bool:
-        """Upload data to R2 using presigned URL."""
-        for attempt in range(max_retries):
-            try:
+        self, presigned_url: str, data: bytes, step: int, content_type: str = "application/json"
+    ) -> None:
+        """Upload data to R2 via presigned URL. Raises on final failure."""
+        async for attempt in self._retry_policy("R2 PUT", step):
+            with attempt:
                 response = await self._client.put(presigned_url, content=data, headers={"Content-Type": content_type})
                 response.raise_for_status()
-                return True
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    self.logger.warning(f"Failed to upload to R2 after {max_retries} attempts: {type(e).__name__}: {e}")
-                    return False
-                delay = 2**attempt
-                self.logger.debug(f"Retrying R2 upload in {delay}s (attempt {attempt + 1}/{max_retries})")
-                await asyncio.sleep(delay)
-
-    async def _confirm_samples_upload(self, step: int, s3_key: str, max_retries: int = 3) -> bool:
-        """Confirm samples upload with the backend. Returns True on success."""
-        for attempt in range(max_retries):
-            try:
+                return
+
+    async def _confirm_samples_upload(self, step: int, s3_key: str) -> None:
+        """Confirm samples upload with the backend. Raises on final failure."""
+        async for attempt in self._retry_policy("samples/confirm", step):
+            with attempt:
                 response = await self._client.post(
                     f"{self.base_url}/samples/confirm",
                     headers=self._headers,
                     json={"run_id": self.run_id, "step": step, "s3_key": s3_key},
                 )
                 response.raise_for_status()
-                return True
-            except Exception as e:
-                if attempt == max_retries - 1:
-                    self.logger.warning(
-                        f"Failed to confirm samples upload after {max_retries} attempts: {type(e).__name__}: {e}"
-                    )
-                    return False
-                delay = 2**attempt
-                self.logger.debug(f"Retrying samples confirm in {delay}s (attempt {attempt + 1}/{max_retries})")
-                await asyncio.sleep(delay)
-        return False
+                return
 
     def log_eval_samples(self, rollouts: list[vf.RolloutOutput], env_name: str, step: int) -> None:
         pass
@@ -615,6 +712,13 @@ def _init_async_client(self) -> None:
         self._thread.start()
         self._client = httpx.AsyncClient(timeout=30)
         self._pending_futures: list[asyncio.Future] = []
+        # Sample-upload backlog. Lock is constructed lazily on the bg loop to bind
+        # to the right asyncio loop after fork.
+        self._sample_upload_queue: deque[tuple[int, bytes]] = deque()
+        self._sample_upload_lock: asyncio.Lock | None = None
+        # Set to a future monotonic timestamp after a retryable failure so queued
+        # drains don't all reattempt the same failing head back-to-back.
+        self._retryable_cooldown_until: float | None = None
         if hasattr(self, "_pending_sample_steps") and self._pending_sample_steps:
             self._pending_sample_steps.clear()
 
@@ -628,21 +732,41 @@ def _run_event_loop(self) -> None:
         self._loop.run_forever()
 
     def _flush(self, timeout: float = 30.0) -> None:
-        """Wait for all pending async requests to complete."""
+        """Wait for all pending async requests to complete and make a last-chance
+        attempt to drain any sample uploads still queued in the backlog. Without
+        this final drain, a retryable failure on the last log_samples tick would
+        leave its parquet bytes parked in the queue and lost on close.
+        """
         if not self.enabled or not hasattr(self, "_loop"):
             return
 
-        if not self._pending_futures:
-            return
-
-        self.logger.debug(f"Flushing {len(self._pending_futures)} pending request(s)")
-        for future in self._pending_futures:
+        if self._pending_futures:
+            self.logger.debug(f"Flushing {len(self._pending_futures)} pending request(s)")
+            for future in self._pending_futures:
+                try:
+                    future.result(timeout=timeout)
+                except Exception as e:
+                    self.logger.debug(f"Pending request completed with error: {e}")
+            self._pending_futures.clear()
+
+        # Final shutdown drain of the sample backlog. Bypass the cooldown — this
+        # is our last chance before the loop stops.
+        if hasattr(self, "_sample_upload_queue") and self._sample_upload_queue:
+            backlog_size = len(self._sample_upload_queue)
+            self.logger.info(f"Final sample-upload drain on close ({backlog_size} step(s) queued)")
             try:
+                future = asyncio.run_coroutine_threadsafe(
+                    self._drain_sample_backlog(ignore_cooldown=True),
+                    self._loop,
+                )
                 future.result(timeout=timeout)
             except Exception as e:
-                self.logger.debug(f"Pending request completed with error: {e}")
-
-        self._pending_futures.clear()
+                self.logger.warning(f"Final sample-upload drain failed: {type(e).__name__}: {e}")
+            remaining = len(self._sample_upload_queue)
+            if remaining:
+                self.logger.warning(
+                    f"{remaining} sample upload(s) still in backlog at close — these samples will be lost"
                )
 
     async def _make_request_async(self, endpoint: str, data: dict[str, Any], max_retries: int = 3) -> None:
        """Make an async POST request to the Prime Intellect API with retries."""
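
---

Reviewer note: the retry policy introduced in _retry_policy pairs an attempt cap with a wall-clock cap and adds jitter to the exponential backoff. For reference, here is a minimal standalone sketch of the same tenacity pattern; the names (is_retryable, put_with_retries) and constants are illustrative, not code from this patch:

    import httpx
    from tenacity import (
        AsyncRetrying,
        retry_if_exception,
        stop_after_attempt,
        stop_after_delay,
        wait_exponential,
        wait_random,
    )

    def is_retryable(exc: BaseException) -> bool:
        # Transport-level failures are worth retrying; for HTTP status errors,
        # retry only 408 (request timeout), 429 (rate limit), and 5xx.
        if isinstance(exc, httpx.TransportError):
            return True
        if isinstance(exc, httpx.HTTPStatusError):
            code = exc.response.status_code
            return code in (408, 429) or 500 <= code < 600
        return False

    async def put_with_retries(client: httpx.AsyncClient, url: str, data: bytes) -> None:
        async for attempt in AsyncRetrying(
            retry=retry_if_exception(is_retryable),
            stop=stop_after_attempt(6) | stop_after_delay(120.0),  # whichever cap is hit first
            wait=wait_exponential(multiplier=1, min=1, max=30.0) + wait_random(0, 1),  # backoff plus jitter
            reraise=True,  # re-raise the last exception instead of wrapping it in RetryError
        ):
            with attempt:
                response = await client.put(url, content=data)
                response.raise_for_status()
                return

The combined stop condition matters: an attempt cap alone would let a run of slow timeouts stretch one call far past the logging interval, while a delay cap alone would permit a large burst of fast failures; together they bound both attempts and total time.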
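Reviewer note: the backlog-and-cooldown drain can be reasoned about separately from the HTTP plumbing. Below is a toy model under assumed names (Backlog, TransientError, and the upload callback are hypothetical stand-ins for the monitor's queue, _is_retryable_upload_error, and _upload_one_sample_step); it is a sketch of the control flow, not the patch's implementation:

    import asyncio
    import time
    from collections import deque
    from typing import Awaitable, Callable

    class TransientError(Exception):
        """Stands in for a failure classified as retryable."""

    class Backlog:
        def __init__(self, upload: Callable[[int, bytes], Awaitable[None]],
                     max_pending: int = 5, cooldown_s: float = 60.0) -> None:
            self.upload = upload
            self.max_pending = max_pending
            self.cooldown_s = cooldown_s
            self.queue: deque[tuple[int, bytes]] = deque()
            self.lock = asyncio.Lock()
            self.cooldown_until: float | None = None

        async def enqueue_and_drain(self, step: int, payload: bytes) -> None:
            self.queue.append((step, payload))
            while len(self.queue) > self.max_pending:
                self.queue.popleft()  # bound memory by evicting the oldest step
            await self.drain()

        async def drain(self) -> None:
            async with self.lock:
                while self.queue:
                    # A recent retryable failure armed the cooldown; bail out so
                    # queued waiters don't re-burn retries on the same failing head.
                    if self.cooldown_until and time.monotonic() < self.cooldown_until:
                        return
                    step, payload = self.queue.popleft()  # pop before awaiting
                    try:
                        await self.upload(step, payload)
                    except TransientError:
                        self.queue.appendleft((step, payload))  # keep head, preserve order
                        self.cooldown_until = time.monotonic() + self.cooldown_s
                        return
                    except Exception:
                        continue  # permanent failure: drop this step, keep draining
                    self.cooldown_until = None  # success clears any stale cooldown

Popping the head before awaiting is the key design choice: eviction in enqueue_and_drain runs without holding the lock, so an item that is mid-upload must not be reachable from the queue; on a retryable failure it is pushed back with appendleft, which preserves oldest-first ordering.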