Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 62 additions & 33 deletions src/prime_rl/orchestrator/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def __init__(
self.checkpoint_ready.set()
self.update_weights_time, self.wait_for_ckpt_time = 0, 0
self.update_policy_task: asyncio.Task | None = None
self.inflight_policy_update_task: asyncio.Task | None = None
self.policy_update_lock = asyncio.Lock()
self.cancelled_rollouts_count = 0
self.last_batch_generation_time = 0.0

Expand Down Expand Up @@ -229,47 +231,71 @@ async def update_policy_loop(self):
await self.maybe_update_policy()
await asyncio.sleep(1)

async def maybe_update_policy(self):
"""Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps."""
def _compute_next_ckpt_step(self) -> int:
    """Return the checkpoint step the policy should move to next.

    The baseline is the step that lies ``max_async_level`` behind the current
    orchestrator step, floored at 0. With ``strict_async_level`` enabled that
    baseline is returned as-is; otherwise the larger of the baseline and the
    step reported by ``get_latest_ckpt_step`` is returned (a falsy result from
    that lookup is treated as step 0).
    """
    latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0
    async_away_ckpt_step = max(self.step - self.max_async_level, 0)
    if self.strict_async_level:
        return async_away_ckpt_step
    return max(async_away_ckpt_step, latest_ckpt_step)

if next_ckpt_step > self.ckpt_step:
if next_ckpt_step == async_away_ckpt_step:
self.logger.info(
f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} "
f"(>{self.max_async_level} step(s) ahead). Training is progressing normally."
)
self.checkpoint_ready.clear()
wait_for_ckpt_start_time = time.perf_counter()
await wait_for_path(get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) / "STABLE")
self.wait_for_ckpt_time = time.perf_counter() - wait_for_ckpt_start_time
self.logger.info(
f"Orchestrator resumed: checkpoint {next_ckpt_step} ready (after {self.wait_for_ckpt_time:.2f}s)"
)

self.logger.debug(
f"Got new policy with step {next_ckpt_step}. Updating weights and cancelling old rollout requests."
async def _apply_policy_update(self, next_ckpt_step: int) -> None:
    """Load checkpoint ``next_ckpt_step`` onto the inference pool and record it.

    Steps, in order:
      1. If the requested step is exactly ``max_async_level`` behind the
         current step, clear ``checkpoint_ready`` and wait for the step's
         ``STABLE`` marker to appear, timing the wait.
      2. Push the weights to the inference pool exactly once, timing the call.
      3. Record the new ``ckpt_step`` and, when a LoRA name is configured,
         switch the served model name to it.
      4. Set ``checkpoint_ready`` and run ``_update_off_policy``.

    Args:
        next_ckpt_step: Checkpoint step to apply.
    """
    weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step)

    async_away_ckpt_step = max(self.step - self.max_async_level, 0)
    if next_ckpt_step == async_away_ckpt_step:
        self.logger.info(
            f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} "
            f"(>{self.max_async_level} step(s) ahead). Training is progressing normally."
        )
        self.checkpoint_ready.clear()
        wait_for_ckpt_start_time = time.perf_counter()
        await wait_for_path(weights_path / "STABLE")
        self.wait_for_ckpt_time = time.perf_counter() - wait_for_ckpt_start_time
        self.logger.info(
            f"Orchestrator resumed: checkpoint {next_ckpt_step} ready (after {self.wait_for_ckpt_time:.2f}s)"
        )

    self.logger.debug(
        f"Got new policy with step {next_ckpt_step}. Updating weights and cancelling old rollout requests."
    )

    # Update weights on inference servers (once per applied checkpoint).
    update_weights_start_time = time.perf_counter()
    await self.inference_pool.update_weights(weights_path, lora_name=self.lora_name, step=next_ckpt_step)
    self.update_weights_time = time.perf_counter() - update_weights_start_time
    self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s")

    # Only advance ckpt_step after the weights are actually in place, so a
    # cancelled or failed update is retried by the next maybe_update_policy call.
    self.ckpt_step = next_ckpt_step
    if self.lora_name is not None:
        self.model_name = self.lora_name
        self.inference_pool.update_model_name(self.model_name)

    self.checkpoint_ready.set()
    await self._update_off_policy()

async def _get_or_start_policy_update_task(self, next_ckpt_step: int) -> asyncio.Task:
    """Return the running policy-update task, starting one if none is running.

    Guarded by ``policy_update_lock`` so concurrent callers share a single
    task. If a task is already running it is returned unchanged, even when it
    was started for a different step than ``next_ckpt_step``.

    Args:
        next_ckpt_step: Step to apply if a new task has to be created.

    Returns:
        The task running ``_apply_policy_update``.
    """
    async with self.policy_update_lock:
        task = self.inflight_policy_update_task
        if task is not None and not task.done():
            return task

        task = asyncio.create_task(self._apply_policy_update(next_ckpt_step))
        self.inflight_policy_update_task = task

        def _clear_inflight_policy_update(done_task: asyncio.Task) -> None:
            # Identity check: do not wipe a newer task that replaced this one.
            if self.inflight_policy_update_task is done_task:
                self.inflight_policy_update_task = None

        # checkpoint_ready is deliberately not touched here; the update task
        # itself sets it once the new weights are loaded.
        task.add_done_callback(_clear_inflight_policy_update)
        return task

async def maybe_update_policy(self):
    """Updates the policy to the latest available checkpoint. Aborts rollout requests that are older than the max retention steps.

    Loops until ``_compute_next_ckpt_step`` no longer exceeds ``ckpt_step``.
    The loop is needed because the shared update task may have been started
    for an older step than the one computed here. ``ckpt_step`` is advanced
    only by the update task itself, never in this method, so a caller that is
    cancelled mid-update leaves it untouched and the next call picks up the
    same task.
    """
    while True:
        next_ckpt_step = self._compute_next_ckpt_step()
        if next_ckpt_step <= self.ckpt_step:
            return

        task = await self._get_or_start_policy_update_task(next_ckpt_step)
        # shield: cancelling this caller must not cancel the shared update task.
        await asyncio.shield(task)

async def _update_off_policy(self) -> None:
stale_group_ids = {
Expand Down Expand Up @@ -392,6 +418,9 @@ async def stop(self) -> None:
if self.update_policy_task is not None:
await safe_cancel(self.update_policy_task)
self.update_policy_task = None
if self.inflight_policy_update_task is not None:
await safe_cancel(self.inflight_policy_update_task)
self.inflight_policy_update_task = None

@property
def max_off_policy_level(self) -> int:
Expand Down
101 changes: 100 additions & 1 deletion tests/unit/orchestrator/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,34 @@
import asyncio
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

from prime_rl.orchestrator.scheduler import InflightRolloutInfo, Scheduler
from prime_rl.utils.async_utils import safe_cancel


def make_scheduler() -> Scheduler:
    """Build a ``Scheduler`` without running ``__init__`` and populate the attributes listed below."""
    instance = Scheduler.__new__(Scheduler)

    # The event starts out set, i.e. a checkpoint is considered ready.
    ready_event = asyncio.Event()
    ready_event.set()

    attributes = {
        "max_async_level": 1,
        "strict_async_level": False,
        "step": 9,
        "ckpt_step": 7,
        "config": SimpleNamespace(output_dir=Path("/tmp/prime-rl-test")),
        "logger": MagicMock(),
        "checkpoint_ready": ready_event,
        "lora_name": None,
        "model_name": "test-model",
        "update_weights_time": 0,
        "wait_for_ckpt_time": 0,
        "inflight_requests": {},
        "groups": {},
        "max_off_policy_steps": 1,
        "cancelled_rollouts_count": 0,
        "policy_update_lock": asyncio.Lock(),
        "inflight_policy_update_task": None,
        "update_policy_task": None,
    }
    for attr_name, attr_value in attributes.items():
        setattr(instance, attr_name, attr_value)
    return instance


def test_update_off_policy_does_not_increment_interleaved_on_policy_tasks():
Expand Down Expand Up @@ -58,3 +84,76 @@ async def drop_group(group_id: int) -> int:
await asyncio.sleep(0)

asyncio.run(run())


def test_maybe_update_policy_reuses_inflight_update_after_cancellation():
    """A second caller must join the in-flight weight update instead of starting another one."""

    async def scenario() -> None:
        scheduler = make_scheduler()
        update_started = asyncio.Event()
        allow_finish = asyncio.Event()
        seen_steps: list[int] = []

        async def update_weights(weight_dir, lora_name=None, step=0) -> None:
            # Record the call, then block until the test lets the update finish.
            seen_steps.append(step)
            update_started.set()
            await allow_finish.wait()

        scheduler.inference_pool = SimpleNamespace(
            update_weights=update_weights,
            update_model_name=MagicMock(),
        )
        scheduler._update_off_policy = AsyncMock()

        latest_step_patch = patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8)
        wait_for_path_patch = patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock())
        with latest_step_patch, wait_for_path_patch:
            # First caller starts the update and is cancelled while it is blocked.
            cancelled_caller = asyncio.create_task(scheduler.maybe_update_policy())
            await update_started.wait()
            await safe_cancel(cancelled_caller)

            # Second caller must not trigger another update_weights call.
            joining_caller = asyncio.create_task(scheduler.maybe_update_policy())
            await asyncio.sleep(0)
            assert seen_steps == [8]

            allow_finish.set()
            await joining_caller

        assert seen_steps == [8]
        assert scheduler.ckpt_step == 8

    asyncio.run(scenario())


def test_stop_cancels_inflight_policy_update_task():
    """``stop()`` must cancel a blocked weight update and clear both task references."""

    async def scenario() -> None:
        scheduler = make_scheduler()
        update_started = asyncio.Event()
        update_cancelled = asyncio.Event()

        async def update_weights(weight_dir, lora_name=None, step=0) -> None:
            update_started.set()
            try:
                # Never resolves on its own; only cancellation can end this await.
                await asyncio.Future()
            finally:
                update_cancelled.set()

        scheduler.inference_pool = SimpleNamespace(
            update_weights=update_weights,
            update_model_name=MagicMock(),
        )
        scheduler._update_off_policy = AsyncMock()

        latest_step_patch = patch("prime_rl.orchestrator.scheduler.get_latest_ckpt_step", return_value=8)
        wait_for_path_patch = patch("prime_rl.orchestrator.scheduler.wait_for_path", new=AsyncMock())
        with latest_step_patch, wait_for_path_patch:
            scheduler.update_policy_task = asyncio.create_task(scheduler.maybe_update_policy())
            await update_started.wait()
            # The timeout turns a hang in stop() into a test failure.
            await asyncio.wait_for(scheduler.stop(), timeout=0.2)

        assert update_cancelled.is_set()
        assert scheduler.update_policy_task is None
        assert scheduler.inflight_policy_update_task is None

    asyncio.run(scenario())