Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

 Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

+- **`inference.kv_cache_offload.cpu_bytes` removed → discriminated `type` config**: The flat `[inference.kv_cache_offload]` block with a single `cpu_bytes` field is replaced by a backend-discriminated union with composable `cpu`/`disk` tiers. Migrate native CPU offload from `[inference.kv_cache_offload]\ncpu_bytes = N` to `[inference.kv_cache_offload]\ntype = "native"` plus `[inference.kv_cache_offload.cpu]\nnum_bytes = N`. A `type = "mooncake"` backend (per-node distributed store; multi-node/SLURM only) and an optional `[inference.kv_cache_offload.disk]\npath = "..."` tier (layered behind cpu) are also available. `extra="forbid"` rejects the old `cpu_bytes` key, so existing configs must migrate. (2026-06-02)
 - **Orchestrator async-pipeline rewrite** (collection of removals/renames). The orchestrator was rewritten to overlap train/eval rollouts on a shared concurrency limiter; several config fields were removed or renamed.
 - **`orchestrator.seed` removed**: was only consumed by the deleted buffer; no replacement.
 - **`orchestrator.eval.eval_base_model` → `orchestrator.eval.skip_first_step`** (semantics inverted): `eval_base_model = true` becomes `skip_first_step = false` (the default — run the step-0 eval before any train rollouts). No alias; configs setting `eval_base_model` must rename.
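
The `kv_cache_offload` migration described in the first entry can be sketched as a small helper over an already-parsed config dict. This is only an illustration: `migrate_kv_cache_offload` is a hypothetical name, not part of prime-rl, and it assumes the old flat block was used for native CPU offload.

```python
from copy import deepcopy
from typing import Any, Dict


def migrate_kv_cache_offload(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `config` with the flat `cpu_bytes` key rewritten to the tiered layout."""
    migrated = deepcopy(config)  # leave the caller's dict untouched
    offload = migrated.get("inference", {}).get("kv_cache_offload")
    if not offload or "cpu_bytes" not in offload:
        return migrated  # nothing to migrate

    cpu_bytes = offload.pop("cpu_bytes")
    offload.setdefault("type", "native")  # assumption: old configs map to the native backend
    offload["cpu"] = {"num_bytes": cpu_bytes}
    return migrated


old = {"inference": {"kv_cache_offload": {"cpu_bytes": 128_000_000_000}}}
new = migrate_kv_cache_offload(old)
```

Here `new["inference"]["kv_cache_offload"]` holds `type = "native"` and a `cpu` table with `num_bytes`, matching the TOML shape named in the entry.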
32 changes: 27 additions & 5 deletions docs/inference.md
@@ -192,16 +192,38 @@ X-Session-ID = "trajectory_id" # this is the default - each rollout has a unique

 ### KV Cache Offload

-Maximizing KV-Cache space is crucial to support high-concurrency workloads. We allow you to offload the KV cache to CPU memory, which can increase the space 10-fold in some cases. You can configure the amount of CPU memory to use for the KV cache by setting `inference.deployment.kv_cache_offload.cpu_bytes`.
+Maximizing KV-Cache space is crucial to support high-concurrency workloads. You can offload the KV cache to CPU memory (and, behind it, disk) by setting `inference.kv_cache_offload`. It is a discriminated config with two composable tiers, `cpu` and `disk`: a `cpu` tier is always required, and an optional `disk` tier is layered behind it (GPU → DRAM → disk). Disk-only is not supported.

+The `type` field selects the backend:
+
+- `native` — vLLM's built-in offloading. CPU-only uses `OffloadingConnector`; CPU+disk uses `TieringOffloadingSpec` (a CPU primary tier with a filesystem secondary tier). Fully self-contained — no extra processes.
+- `mooncake` — a [Mooncake](https://github.com/kvcache-ai/Mooncake) **shared distributed store** (SLURM only). One `mooncake_master` + metadata server runs on the head inference node; every inference node runs a `mooncake_client` that contributes its DRAM (and, with `disk`, SSD) segment to that *single* pool. Because blocks are keyed by model + parallel rank + content hash (no instance id), a prefix cached by one node/replica is reusable by all of them over RDMA — pooling every node's CPU RAM into one KV cache. Use `native` for local/single-process runs.
+
 ```toml
+# Native CPU offload (reserves 128GB of CPU KV cache for this instance)
 [inference.kv_cache_offload]
-cpu_bytes = 128_000_000_000 # 128GB
-```
+type = "native"
+[inference.kv_cache_offload.cpu]
+num_bytes = 128_000_000_000 # 128GB

-This will reserve 128GB of CPU memory per worker. If you use dp=8, this will reserve 1TB of CPU memory per node.
+# Native CPU + disk tiering (self-contained)
+[inference.kv_cache_offload]
+type = "native"
+[inference.kv_cache_offload.cpu]
+num_bytes = 128_000_000_000
+[inference.kv_cache_offload.disk]
+path = "/scratch/kv" # disk capacity is bounded by the filesystem
+
+# Mooncake CPU + disk (per-node distributed store, RDMA)
+[inference.kv_cache_offload]
+type = "mooncake"
+[inference.kv_cache_offload.cpu]
+num_bytes = 128_000_000_000
+[inference.kv_cache_offload.disk]
+path = "/scratch/kv"
+```

-We aim to support more offloading options in the future, such as multi-tier offloading to also utilize disk-based KV cache, or distributed storage options like Mooncake Connector.
+For `native`, `cpu.num_bytes` is the aggregate CPU KV pool for the instance (vLLM shards it across workers). For `mooncake`, `cpu.num_bytes` is the DRAM each node contributes to the shared pool (so the total pool ≈ `num_bytes × #inference-nodes`); the store uses RDMA, so it requires an RDMA-capable fabric. Enabling offload automatically enables prefix caching.
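
As a rough illustration of the sizing rule for `cpu.num_bytes`, the sketch below computes the approximate total CPU pool for each backend. `total_cpu_pool_bytes` and its parameters are hypothetical names used only for this example; they are not part of prime-rl or vLLM, and the Mooncake figure is an approximation.

```python
def total_cpu_pool_bytes(backend: str, num_bytes: int, num_inference_nodes: int = 1) -> int:
    """Approximate CPU KV pool size implied by `cpu.num_bytes` for a given backend."""
    if backend == "native":
        # Aggregate pool for the instance; vLLM shards it across workers.
        return num_bytes
    if backend == "mooncake":
        # Each inference node contributes `num_bytes` of DRAM to the shared pool.
        return num_bytes * num_inference_nodes
    raise ValueError(f"unknown backend: {backend!r}")


# 128 GB per node on 4 inference nodes
native_pool = total_cpu_pool_bytes("native", 128_000_000_000, num_inference_nodes=4)
mooncake_pool = total_cpu_pool_bytes("mooncake", 128_000_000_000, num_inference_nodes=4)
```

With these inputs the native pool stays at 128 GB, while the Mooncake pool comes to roughly 512 GB across the four nodes.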


 ### Optimized P/D disaggregation deployment
118 changes: 102 additions & 16 deletions packages/prime-rl-configs/src/prime_rl/configs/inference.py
@@ -86,9 +86,71 @@ class WeightBroadcastConfig(BaseConfig):
     """Weight broadcast transport."""


-class KVCacheOffloadConfig(BaseConfig):
-    cpu_bytes: int = Field(1_000_000_000, gt=0)
-    """CPU bytes available for KV cache offloading per worker."""
+class CPUOffloadTier(BaseConfig):
+    num_bytes: int = Field(..., gt=0)
+    """CPU/DRAM offload capacity. For the ``native`` backend this is vLLM's aggregate ``cpu_bytes_to_use`` (scaled across workers internally). For the ``mooncake`` backend this is the per-node store client's DRAM segment (``-global_segment_size``)."""
+
+
+class DiskOffloadTier(BaseConfig):
+    path: Path
+    """Filesystem root for the disk tier. For ``native`` this is the ``fs_python`` secondary tier's ``root_dir``; for ``mooncake`` it is the store client's ``MOONCAKE_OFFLOAD_FILE_STORAGE_PATH``. Capacity is bounded by the filesystem at ``path`` (neither backend enforces a byte quota)."""
+
+
+class BaseKVCacheOffloadConfig(BaseConfig):
+    cpu: CPUOffloadTier | None = None
+    """CPU/DRAM offload tier. Always required — disk-only offload is not supported."""
+
+    disk: DiskOffloadTier | None = None
+    """Optional disk tier, layered behind the CPU tier (GPU → DRAM → disk)."""
+
+    @model_validator(mode="after")
+    def valid_tiers(self):
+        # Both backends support only two shapes: cpu-only or cpu+disk. Native disk
+        # tiering needs a CPU primary tier; Mooncake standalone-store needs a DRAM
+        # staging tier. Disk-only is rejected for both.
+        if self.cpu is None:
+            raise ValueError("inference.kv_cache_offload requires a cpu tier (disk-only offload is not supported).")
+        return self
+
+
+class NativeKVCacheOffloadConfig(BaseKVCacheOffloadConfig):
+    type: Literal["native"] = "native"
+    """vLLM-native offloading. cpu-only uses ``OffloadingConnector`` + ``CPUOffloadingSpec``; cpu+disk uses ``TieringOffloadingSpec`` (CPU primary tier + ``fs_python`` disk secondary). Fully self-contained — no external processes."""
+
+    def to_connector_dict(self) -> dict[str, Any]:
+        assert self.cpu is not None
+        extra: dict[str, Any] = {"cpu_bytes_to_use": int(self.cpu.num_bytes)}
+        if self.disk is not None:
+            extra["spec_name"] = "TieringOffloadingSpec"
+            extra["secondary_tiers"] = [{"type": "fs_python", "root_dir": str(self.disk.path)}]
+        return {
+            "kv_connector": "OffloadingConnector",
+            "kv_role": "kv_both",
+            "kv_connector_extra_config": extra,
+        }
+
+
+class MooncakeKVCacheOffloadConfig(BaseKVCacheOffloadConfig):
+    type: Literal["mooncake"] = "mooncake"
+    """Mooncake distributed store offloading (SLURM only). One ``mooncake_master`` + metadata server runs on the head inference node; every node runs a ``mooncake_client`` contributing its segment to the single shared pool, so prefixes cached on any node are reusable by all. The cpu tier sizes each node's DRAM segment; the optional disk tier adds an SSD tier."""
+
+    device_name: str = ""
+    """RDMA device name(s) for the store (empty = auto-detect)."""
+
+    def to_connector_dict(self) -> dict[str, Any]:
+        # Addresses/sizes/tiers are realized by the per-node store launch in the sbatch
+        # template (MOONCAKE_CONFIG_PATH JSON); blocks are keyed by model + parallel rank +
+        # content hash (no instance id), so the shared pool is reused across nodes/replicas.
+        return {
+            "kv_connector": "MooncakeStoreConnector",
+            "kv_role": "kv_both",
+            "kv_connector_extra_config": {},
+        }
+
+
+KVCacheOffloadConfig: TypeAlias = Annotated[
+    NativeKVCacheOffloadConfig | MooncakeKVCacheOffloadConfig, Field(discriminator="type")
+]


 # Valid vLLM max_lora_rank values (from vllm/config/lora.py)
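
The tier rule enforced by `valid_tiers` (a cpu tier is mandatory, disk is optional and only valid behind it) can be tried out in isolation with the stdlib-only sketch below. It deliberately uses a plain dataclass instead of the pydantic models in this diff, and `OffloadTiers` with its field names is a hypothetical stand-in, not the real config class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OffloadTiers:
    """Hypothetical stand-in for the cpu/disk tier pair (dataclass, not pydantic)."""

    cpu_num_bytes: Optional[int] = None
    disk_path: Optional[str] = None

    def __post_init__(self) -> None:
        # Only two shapes are accepted: cpu-only or cpu+disk.
        if self.cpu_num_bytes is None:
            raise ValueError("kv_cache_offload requires a cpu tier (disk-only offload is not supported).")
        # Mirrors the gt=0 constraint on num_bytes.
        if self.cpu_num_bytes <= 0:
            raise ValueError("cpu.num_bytes must be greater than 0")


cpu_only = OffloadTiers(cpu_num_bytes=128_000_000_000)
cpu_and_disk = OffloadTiers(cpu_num_bytes=128_000_000_000, disk_path="/scratch/kv")

try:
    OffloadTiers(disk_path="/scratch/kv")  # disk-only
    disk_only_rejected = False
except ValueError:
    disk_only_rejected = True
```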
@@ -261,7 +323,10 @@ class InferenceConfig(BaseConfig):
     weight_broadcast: WeightBroadcastConfig = WeightBroadcastConfig()

     kv_cache_offload: KVCacheOffloadConfig | None = None
-    """CPU KV cache offload for inference workers. Standard inference uses vLLM's ``OffloadingConnector``. Disaggregated P/D deployments combine it with NIXL through ``MultiConnector`` in the SLURM launcher."""
+    """KV cache offload for inference workers, as composable CPU/disk tiers. Discriminated on ``type``: ``native`` (vLLM ``OffloadingConnector``/``TieringOffloadingSpec``, self-contained) or ``mooncake`` (per-node Mooncake distributed store). Disaggregated P/D combines the chosen connector with NIXL through ``MultiConnector``."""
+
+    use_pd_kv_transfer: bool = False
+    """Auto-set for disaggregated P/D: emit the NIXL transfer connector. Persisted into the per-node config (which drops ``deployment``) so the connector is still built per worker. Not meant to be set by hand."""

     enable_return_routed_experts: bool = False
     """Return routed experts in responses. Forwarded as ``--enable-return-routed-experts``."""
@@ -307,6 +372,7 @@ def auto_setup_kv_cache_offload(self):
     def auto_setup_disaggregated(self):
         """Auto-configure inference for disaggregated P/D: enable EP and compute DP."""
         if self.deployment.type == "disaggregated":
+            self.use_pd_kv_transfer = True
             if "enable_expert_parallel" not in self.model_fields_set:
                 self.enable_expert_parallel = True
             if "enable_eplb" not in self.model_fields_set:
@@ -368,6 +434,35 @@ def auto_setup_api_server_count(self):
             self.api_server_count = 1  # LoRA requires only one API server
         return self

+    def build_kv_transfer_config(self) -> dict[str, Any] | None:
+        """Build the single vLLM ``kv_transfer_config`` from the transfer + offload connectors.
+
+        Disaggregated P/D always uses NIXL for prefill→decode transfer. KV cache offload (if
+        configured) contributes its own connector. When both are present they are composed via
+        ``MultiConnector``. Returns None when neither applies.
+        """
+        connectors: list[dict[str, Any]] = []
+        if self.use_pd_kv_transfer:
+            connectors.append(
+                {
+                    "kv_connector": "NixlConnector",
+                    "kv_role": "kv_both",
+                    "kv_connector_extra_config": {"num_threads": 1},
+                }
+            )
+        if self.kv_cache_offload is not None:
+            connectors.append(self.kv_cache_offload.to_connector_dict())
+
+        if not connectors:
+            return None
+        if len(connectors) == 1:
+            return connectors[0]
+        return {
+            "kv_connector": "MultiConnector",
+            "kv_role": "kv_both",
+            "kv_connector_extra_config": {"connectors": connectors},
+        }
+
     def to_vllm(self) -> Namespace:
         """Convert InferenceConfig to vLLM-compatible Namespace."""
         namespace = Namespace()
@@ -411,18 +506,9 @@ def to_vllm(self) -> Namespace:
         # Set `logprobs_mode` to `processed_logprobs` by default
         rsetattr(namespace, "logprobs_mode", "processed_logprobs")

-        if self.kv_cache_offload is not None:
-            rsetattr(
-                namespace,
-                "kv_transfer_config",
-                {
-                    "kv_connector": "OffloadingConnector",
-                    "kv_role": "kv_both",
-                    "kv_connector_extra_config": {
-                        "cpu_bytes_to_use": int(self.kv_cache_offload.cpu_bytes),
-                    },
-                },
-            )
+        kv_transfer_config = self.build_kv_transfer_config()
+        if kv_transfer_config is not None:
+            rsetattr(namespace, "kv_transfer_config", kv_transfer_config)

         # Pass prime-rl-specific flags through vLLM's additional_config dict;
         # workers read these via get_current_vllm_config().additional_config.
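
To see what `build_kv_transfer_config` yields in each case (no connector, a single connector, or a `MultiConnector` wrapping both), the logic can be exercised outside the config classes. The sketch below is a stdlib-only approximation: `native_connector` and `compose_kv_transfer_config` are hypothetical free functions that mirror the dict shapes in this diff and do not import vLLM or prime-rl.

```python
from typing import Any, Dict, List, Optional


def native_connector(cpu_bytes: int, disk_path: Optional[str] = None) -> Dict[str, Any]:
    """Mirror of the native backend's connector dict (cpu-only or cpu+disk)."""
    extra: Dict[str, Any] = {"cpu_bytes_to_use": cpu_bytes}
    if disk_path is not None:
        extra["spec_name"] = "TieringOffloadingSpec"
        extra["secondary_tiers"] = [{"type": "fs_python", "root_dir": disk_path}]
    return {"kv_connector": "OffloadingConnector", "kv_role": "kv_both", "kv_connector_extra_config": extra}


def compose_kv_transfer_config(
    use_pd_kv_transfer: bool, offload: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
    """Combine the P/D transfer connector and the offload connector, if any."""
    connectors: List[Dict[str, Any]] = []
    if use_pd_kv_transfer:
        connectors.append(
            {"kv_connector": "NixlConnector", "kv_role": "kv_both", "kv_connector_extra_config": {"num_threads": 1}}
        )
    if offload is not None:
        connectors.append(offload)

    if not connectors:
        return None
    if len(connectors) == 1:
        return connectors[0]  # a lone connector is passed through unwrapped
    return {
        "kv_connector": "MultiConnector",
        "kv_role": "kv_both",
        "kv_connector_extra_config": {"connectors": connectors},
    }


neither = compose_kv_transfer_config(False, None)
offload_only = compose_kv_transfer_config(False, native_connector(128_000_000_000))
pd_plus_tiered = compose_kv_transfer_config(True, native_connector(128_000_000_000, "/scratch/kv"))
```

In the combined case the NIXL connector comes first in the `connectors` list and the offload connector second, following the order in which the method appends them.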
14 changes: 14 additions & 0 deletions packages/prime-rl-configs/src/prime_rl/configs/rl.py
@@ -448,6 +448,20 @@ def validate_router_replay_without_kv_offload(self):
             )
         return self

+    @model_validator(mode="after")
+    def validate_mooncake_offload_requires_slurm(self):
+        if (
+            self.slurm is None
+            and self.inference is not None
+            and self.inference.kv_cache_offload is not None
+            and self.inference.kv_cache_offload.type == "mooncake"
+        ):
+            raise ValueError(
+                "Mooncake KV offload requires SLURM — the per-node store is launched by the sbatch "
+                "template. Use inference.kv_cache_offload.type='native' for local runs."
+            )
+        return self
+
     @model_validator(mode="after")
     def auto_setup_deployment(self):
         if self.deployment.type == "single_node":  # single-node
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "torchdata>=0.11.0",
     "transformers",
     "vllm>=0.22.0",
+    "mooncake-transfer-engine>=0.3.10.post2",
     "wandb>=0.26.1",
     "ring-flash-attn>=0.1.8",
     "prime>=0.6.4",
10 changes: 8 additions & 2 deletions src/prime_rl/entrypoints/inference.py
@@ -36,6 +36,9 @@ def write_slurm_script(config: InferenceConfig, config_path: Path, script_path:
     is_disaggregated = config.deployment.type == "disaggregated"
     dp_per_node = config.deployment.gpus_per_node // config.parallel.tp

+    offload = config.kv_cache_offload
+    is_mooncake = offload is not None and offload.type == "mooncake"
+
     template_vars = dict(
         **config.slurm.template_vars,
         config_path=config_path,
@@ -45,7 +48,11 @@ def write_slurm_script(config: InferenceConfig, config_path: Path, script_path:
         num_nodes=getattr(config.deployment, "num_nodes", 1),
         port=config.server.port,
         disaggregated=is_disaggregated,
-        kv_offload=config.kv_cache_offload is not None,
+        kv_offload=offload is not None,
+        kv_offload_mooncake=is_mooncake,
+        kv_offload_cpu_bytes=int(offload.cpu.num_bytes) if is_mooncake else 0,
+        kv_offload_disk_path=str(offload.disk.path) if (is_mooncake and offload.disk is not None) else "",
+        kv_offload_device_name=offload.device_name if is_mooncake else "",
     )

     is_multi_node = config.deployment.type == "multi_node"
@@ -64,7 +71,6 @@ def write_slurm_script(config: InferenceConfig, config_path: Path, script_path:
             use_deep_gemm=config.use_deep_gemm,
             prefill_env_overrides=config.deployment.prefill_env_overrides,
             decode_env_overrides=config.deployment.decode_env_overrides,
-            kv_offload_cpu_bytes=int(config.kv_cache_offload.cpu_bytes) if config.kv_cache_offload else 0,
         )
     elif is_multi_node:
         template_vars.update(
17 changes: 12 additions & 5 deletions src/prime_rl/entrypoints/rl.py
@@ -336,6 +336,16 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
     env = Environment(loader=FileSystemLoader(config.slurm.template_path.parent), keep_trailing_newline=True)
     template = env.get_template(config.slurm.template_path.name)

+    offload = config.inference.kv_cache_offload if config.inference is not None else None
+    is_mooncake = offload is not None and offload.type == "mooncake"
+    mooncake_vars = dict(
+        kv_offload=offload is not None,
+        kv_offload_mooncake=is_mooncake,
+        kv_offload_cpu_bytes=int(offload.cpu.num_bytes) if is_mooncake else 0,
+        kv_offload_disk_path=str(offload.disk.path) if (is_mooncake and offload.disk is not None) else "",
+        kv_offload_device_name=offload.device_name if is_mooncake else "",
+    )
+
     if config.deployment.type == "single_node":
         script = template.render(
             **config.slurm.template_vars,
@@ -370,10 +380,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
             prefill_env_overrides=infer_deploy.prefill_env_overrides,
             decode_env_overrides=infer_deploy.decode_env_overrides,
             dp_per_node=config.deployment.gpus_per_node // config.inference.parallel.tp,
-            kv_offload=config.inference.kv_cache_offload is not None,
-            kv_offload_cpu_bytes=int(config.inference.kv_cache_offload.cpu_bytes)
-            if config.inference.kv_cache_offload
-            else 0,
+            **mooncake_vars,
             use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
             ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)),
         )
@@ -395,7 +402,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
             inference_enable_expert_parallel=config.inference.enable_expert_parallel if config.inference else False,
             inference_data_parallel_rpc_port=config.inference.data_parallel_rpc_port if config.inference else 29600,
             dp_per_node=(config.deployment.gpus_per_node // config.inference.parallel.tp) if config.inference else 1,
-            kv_offload=config.inference is not None and config.inference.kv_cache_offload is not None,
+            **mooncake_vars,
             use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
             ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)),
         )
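
One consequence of the `mooncake_vars` block is that the size, path, and device variables are only populated for the Mooncake backend; a native config sets `kv_offload` but leaves the rest at their empty defaults. The sketch below reproduces that derivation with `types.SimpleNamespace` objects standing in for the real config models. `offload_template_vars` is a hypothetical helper name, not a function in prime-rl.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional


def offload_template_vars(offload: Optional[Any]) -> Dict[str, Any]:
    """Derive the sbatch template variables from an offload config (or None)."""
    is_mooncake = offload is not None and offload.type == "mooncake"
    return dict(
        kv_offload=offload is not None,
        kv_offload_mooncake=is_mooncake,
        # The fields below are only read for Mooncake; native is configured through vLLM itself.
        kv_offload_cpu_bytes=int(offload.cpu.num_bytes) if is_mooncake else 0,
        kv_offload_disk_path=str(offload.disk.path) if (is_mooncake and offload.disk is not None) else "",
        kv_offload_device_name=offload.device_name if is_mooncake else "",
    )


# Stand-ins for the pydantic models (illustrative values only).
native = SimpleNamespace(type="native", cpu=SimpleNamespace(num_bytes=128_000_000_000), disk=None)
mooncake = SimpleNamespace(
    type="mooncake",
    cpu=SimpleNamespace(num_bytes=128_000_000_000),
    disk=SimpleNamespace(path="/scratch/kv"),
    device_name="",
)

no_offload_vars = offload_template_vars(None)
native_vars = offload_template_vars(native)
mooncake_vars = offload_template_vars(mooncake)
```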