src/prime_rl/inference/vllm/worker/filesystem.py (8 additions, 6 deletions)
```diff
@@ -2,7 +2,8 @@
 
 from torch.nn import Module
 from vllm.model_executor.model_loader import DefaultModelLoader, get_model_loader
-from vllm.model_executor.model_loader.utils import process_weights_after_loading
 
+from prime_rl.inference.vllm.worker.weight_transfer import load_weights_checkpoint_layerwise
+
 # This is to get type hints for the Worker class but not actually extend it at runtime as this is required by vLLM worker extension
 if TYPE_CHECKING:
@@ -46,8 +47,9 @@ def update_weights_from_path(self, weight_path: str) -> None:
             allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
         )
         weights_iterator = model_loader._get_weights_iterator(local_source)
-        model.load_weights(weights_iterator)  # type: ignore
-
-        # Process weights after loading (important for some models)
-        device = next(model.parameters()).device
-        process_weights_after_loading(model, self.model_runner.model_config, device)
+        load_weights_checkpoint_layerwise(
+            model,
+            weights_iterator,
+            self.model_runner.model_config,
+            self.vllm_config,
+        )
```
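For context, a minimal sketch of how a worker-extension method like `update_weights_from_path` is typically invoked from the engine side. The extension class path and model name below are illustrative assumptions, not part of this diff; vLLM's `worker_extension_cls` and `collective_rpc` are the standard hooks for this pattern.

```python
from vllm import LLM

# Hypothetical wiring (class path and model are placeholders): register the
# worker extension so its methods become callable via collective_rpc.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    worker_extension_cls="prime_rl.inference.vllm.worker.filesystem.FileSystemWorkerExtension",
)

# Invoke update_weights_from_path on every worker; each one reloads its
# weights layerwise from the given path.
llm.collective_rpc("update_weights_from_path", args=("/path/to/checkpoint",))
```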
src/prime_rl/inference/vllm/worker/nccl.py (13 additions, 15 deletions)
```diff
@@ -1,5 +1,5 @@
 import pickle
-from typing import TYPE_CHECKING, Callable, Generator, cast
+from typing import TYPE_CHECKING, Generator, cast
 
 import torch
 from torch.nn import Module
@@ -8,10 +8,9 @@
 from vllm.logger import init_logger
 
 from prime_rl.inference.vllm.worker.weight_transfer import (
-    load_weights_checkpoint,
+    load_weights_checkpoint_layerwise,
     load_weights_kernel,
-    postprocess_weights_checkpoint,
-    postprocess_weights_kernel,
+    update_mla_absorbed_weights,
 )
 from prime_rl.utils.nccl import disable_nccl_p2p_if_unavailable
 
@@ -144,15 +143,14 @@ def update_weights_from_path(self, weight_dir: str) -> None:
         assert isinstance(model, Module)
 
         state_iter = self.nccl_broadcast_receiver.receive_state_dict()
-        device = next(model.parameters()).device
-        loader_fn: Callable[[Module, Generator[tuple[str, torch.Tensor], None, None]], None]
-        postprocess_fn: Callable[[Module, object, torch.device], None]
         if self.quantize_in_weight_transfer:
-            loader_fn = load_weights_kernel
-            postprocess_fn = postprocess_weights_kernel
-        else:
-            loader_fn = load_weights_checkpoint
-            postprocess_fn = postprocess_weights_checkpoint
-
-        loader_fn(model, state_iter)
-        postprocess_fn(model, self.model_runner.model_config, device)
+            load_weights_kernel(model, state_iter)
+            update_mla_absorbed_weights(model)
+            return
+
+        load_weights_checkpoint_layerwise(
+            model,
+            state_iter,
+            self.model_runner.model_config,
+            self.vllm_config,
+        )
```
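The `state_iter` consumed above is just an iterator of `(name, tensor)` pairs. A self-contained toy, independent of vLLM, showing that streaming contract and why the receiver never needs to materialize the full state dict at once:

```python
import torch
from torch import nn


def stream_state_dict(sd: dict[str, torch.Tensor]):
    """Yield (name, tensor) pairs one at a time, standing in for a
    state dict streamed over an NCCL broadcast."""
    yield from sd.items()


src = nn.Linear(4, 4)
dst = nn.Linear(4, 4)

# Copy parameters one tensor at a time, as a layerwise loader would,
# instead of holding the whole state dict in memory simultaneously.
params = dict(dst.named_parameters())
with torch.no_grad():
    for name, tensor in stream_state_dict(src.state_dict()):
        params[name].copy_(tensor)

assert torch.equal(dst.weight, src.weight) and torch.equal(dst.bias, src.bias)
```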
src/prime_rl/inference/vllm/worker/weight_transfer.py (15 additions, 12 deletions)
```diff
@@ -1,19 +1,26 @@
-from typing import Generator
+from typing import Generator, Iterable
 
 import torch
 from torch.nn import Module
+from vllm.config import set_current_vllm_config
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.utils import process_weights_after_loading
+from vllm.model_executor.model_loader.reload import finalize_layerwise_reload, initialize_layerwise_reload
 
 logger = init_logger("vllm.inference.vllm.worker_weight_transfer")
 
 
-def load_weights_checkpoint(model: Module, state_iter: Generator[tuple[str, torch.Tensor], None, None]) -> None:
-    model.load_weights(state_iter)  # type: ignore
-
-
-def postprocess_weights_checkpoint(model: Module, model_config, device: torch.device) -> None:
-    process_weights_after_loading(model, model_config, device)
+def load_weights_checkpoint_layerwise(
+    model: Module,
+    state_iter: Iterable[tuple[str, torch.Tensor]],
+    model_config,
+    vllm_config,
+) -> None:
+    logger.info("Reloading checkpoint-format weights with vLLM layerwise processing")
+    device = next(model.parameters()).device
+    with torch.device(device), set_current_vllm_config(vllm_config):
+        initialize_layerwise_reload(model)
+        model.load_weights(state_iter)  # type: ignore
+        finalize_layerwise_reload(model, model_config)
 
 
 def _invert_logical_to_physical_map(logical_to_physical_map: torch.Tensor, num_physical_experts: int) -> torch.Tensor:
```
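Two details of the new helper are worth noting: `torch.device(...)` used as a context manager routes tensor allocations made during the reload onto the model's device, and `set_current_vllm_config` makes the engine config visible to layer code that reads it during re-initialization. A tiny runnable demo of the first, device-routing behavior (PyTorch 2.0+):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Inside the context, factory functions allocate on `device` without an
# explicit device= argument, which is why the helper wraps the whole reload.
with torch.device(device):
    fresh = torch.zeros(4, 4)

assert fresh.device.type == device.type
```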
```diff
@@ -143,7 +150,3 @@ def update_mla_absorbed_weights(model: Module) -> None:
             module.W_UK_T.copy_(w_uk.permute(1, 2, 0))
 
         logger.debug(f"Updated MLA absorbed weights for module {name}")
-
-
-def postprocess_weights_kernel(model: Module, _model_config, _device: torch.device) -> None:
-    update_mla_absorbed_weights(model)
```
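For the `W_UK_T` copy in the surviving context above, a shape-only sketch of what `permute(1, 2, 0)` does. The dimension names are assumptions based on vLLM's absorbed-MLA layout, not spelled out in this diff:

```python
import torch

# Assumed layout: w_uk is (kv_lora_rank, num_heads, head_dim); permuting to
# (num_heads, head_dim, kv_lora_rank) yields the per-head transposed W_UK
# that the absorbed-MLA attention path multiplies against.
w_uk = torch.randn(16, 8, 64)
w_uk_t = w_uk.permute(1, 2, 0)
assert w_uk_t.shape == (8, 64, 16)
```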