From a479a028124b7f09a71c575af9a03cded434990b Mon Sep 17 00:00:00 2001 From: hallerite Date: Sun, 10 May 2026 22:13:00 +0000 Subject: [PATCH 1/6] feat: drop use_renderer=True VLM skip; pack pixel_values from renderer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit interleave_rollout now consumes renderer-emitted multi_modal_data on each trajectory step (when present), packs the per-image pixel_values and image_grid_thw onto the TrainingSample, and computes mm_token_type_ids — no VLMImageCache lookup required when the rollout went through a multimodal renderer. VLMImageCache stays as the fallback for MITO/chat-completions rollouts. build_trajectory_step (renderers package) was updated separately to surface mm_data on its output, so the pretokenize-fallback path also carries images through correctly. Other changes: - orchestrator config: removed validate_renderer_vs_vlm validator that previously rejected use_renderer=True for VLMs (it's now supported). - e2e test: a real Qwen3VLRenderer + RendererClient -> /inference/v1/generate mock POST, plus a roundtrip through vllm's GenerateRequest pydantic model and decode_mm_kwargs_item. Strongest end-to-end check we can run without a GPU. - 20-step A/B configs for color-codeword (feat-renderer vs main-mito), both logging to W&B project 'multimodal-renderer'. 126 orchestrator unit tests pass (+2 new for renderer-mm packing). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../rl_color_codeword_feat_renderer.toml | 69 +++++++ .../rl_color_codeword_main_mito.toml | 61 ++++++ .../src/prime_rl/configs/orchestrator.py | 14 -- src/prime_rl/orchestrator/trajectories.py | 79 ++++++-- tests/unit/orchestrator/test_qwen3_vl_e2e.py | 188 ++++++++++++++++++ tests/unit/orchestrator/test_trajectories.py | 168 ++++++++++++++++ 6 files changed, 553 insertions(+), 26 deletions(-) create mode 100644 configs/multimodal/rl_color_codeword_feat_renderer.toml create mode 100644 configs/multimodal/rl_color_codeword_main_mito.toml create mode 100644 tests/unit/orchestrator/test_qwen3_vl_e2e.py diff --git a/configs/multimodal/rl_color_codeword_feat_renderer.toml b/configs/multimodal/rl_color_codeword_feat_renderer.toml new file mode 100644 index 0000000000..576bc409dd --- /dev/null +++ b/configs/multimodal/rl_color_codeword_feat_renderer.toml @@ -0,0 +1,69 @@ +# 20-step Qwen3-VL-4B RL run on color-codeword using the renderer multimodal path. +# +# Pair with rl_color_codeword_main_mito.toml for an A/B comparison: same env, +# same hyperparameters, same step count — only difference is the inference +# client. The feat-branch run uses the new RendererClient + Qwen3VLRenderer +# (renderers package with multimodal support); the main-branch baseline uses +# the existing TITO chat-completions path through the inference server. +# +# Compare in W&B project ``multimodal-renderer``: +# - ``kl/sampler_vs_trainer`` should be ~0 on this branch (the renderer +# produces byte-identical tokens to what the trainer re-tokenizes) and +# can spike on main when BPE drifts mid-rollout. +# - ``reward`` and ``loss`` should track within noise — same model, same +# env, same hyperparameters. +# - ``bridge_break_rate`` is renderer-only; surfaces multi-turn extension +# failures. 
+ +max_steps = 20 +seq_len = 4096 +output_dir = "outputs/rl_color_codeword_feat_renderer" +clean_output_dir = true + +[model] +name = "Qwen/Qwen3-VL-4B-Instruct" + +[model.vlm] +vision_encoder_attr = "model.visual" +language_model_attr = "model.language_model" + +[deployment] +num_train_gpus = 1 +num_infer_gpus = 1 +gpus_per_node = 2 + +[orchestrator] +batch_size = 16 +rollouts_per_example = 4 +use_renderer = true +use_token_client = false + +[orchestrator.train.sampling] +max_completion_tokens = 64 + +[[orchestrator.train.env]] +id = "color-codeword" +args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 } + +[orchestrator.renderer] +name = "auto" + +[trainer] + +[trainer.model] +optimization_dtype = "bfloat16" +reduce_dtype = "bfloat16" + +[trainer.optim] +lr = 3e-6 + +[inference] + +[inference.parallel] +dp = 1 +tp = 1 + +[wandb] +project = "multimodal-renderer" +name = "feat-renderer-20step" +tags = ["qwen3vl-4b", "color-codeword", "renderer", "feat-branch"] diff --git a/configs/multimodal/rl_color_codeword_main_mito.toml b/configs/multimodal/rl_color_codeword_main_mito.toml new file mode 100644 index 0000000000..9ae481fb69 --- /dev/null +++ b/configs/multimodal/rl_color_codeword_main_mito.toml @@ -0,0 +1,61 @@ +# 20-step Qwen3-VL-4B RL run on color-codeword using the existing MITO +# (chat-completions) inference path — baseline for the renderer A/B. +# +# Mirrors rl_color_codeword_feat_renderer.toml exactly except for the +# orchestrator client flags: this run goes through TITO/MITO so the +# inference server applies the chat template, runs the image processor +# server-side, and the orchestrator re-tokenizes locally. +# +# Run on ``main`` (renderers, verifiers, primerlmain). Compare metrics in +# W&B project ``multimodal-renderer`` against the feat run. + +max_steps = 20 +seq_len = 4096 +output_dir = "outputs/rl_color_codeword_main_mito" +clean_output_dir = true + +[model] +name = "Qwen/Qwen3-VL-4B-Instruct" + +[model.vlm] +vision_encoder_attr = "model.visual" +language_model_attr = "model.language_model" + +[deployment] +num_train_gpus = 1 +num_infer_gpus = 1 +gpus_per_node = 2 + +[orchestrator] +batch_size = 16 +rollouts_per_example = 4 +# MITO baseline: server-side chat templating + image processing. +use_renderer = false +use_token_client = true + +[orchestrator.train.sampling] +max_completion_tokens = 64 + +[[orchestrator.train.env]] +id = "color-codeword" +args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 } + +[trainer] + +[trainer.model] +optimization_dtype = "bfloat16" +reduce_dtype = "bfloat16" + +[trainer.optim] +lr = 3e-6 + +[inference] + +[inference.parallel] +dp = 1 +tp = 1 + +[wandb] +project = "multimodal-renderer" +name = "main-mito-20step" +tags = ["qwen3vl-4b", "color-codeword", "mito", "main-branch"] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 3911816227..b466bafcc6 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1230,20 +1230,6 @@ def validate_client_mode(self): ) return self - @model_validator(mode="after") - def validate_renderer_vs_vlm(self): - """The renderer client takes plain message dicts and tokenizes - them client-side. 
VLMs need server-side image preprocessing and - chat templating, so they must use the token client (TITO) — fail - loudly when both are set.""" - if self.use_renderer and self.model.vlm is not None: - raise ValueError( - "orchestrator.use_renderer is not supported for VLMs. Use the token client " - "(``use_token_client=true``, the default) so image preprocessing and chat " - "templating stay on the inference server." - ) - return self - @model_validator(mode="after") def validate_renderer_args(self): """``[orchestrator.renderer]`` knobs are only meaningful when diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 3a45ee9ada..ef69a4ffe2 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -303,6 +303,12 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any "completion_mask": [bool(i) for i in tokens["completion_mask"]], "completion_logprobs": list(tokens["completion_logprobs"]), "routed_experts": tokens.get("routed_experts"), + # Renderer-emitted multimodal sidecar (placeholders + per-item + # processed tensors). Populated when the rollout went through + # a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent + # for text-only or VLM-via-MITO rollouts (those fall back to + # the vlm_cache path below). + "multi_modal_data": tokens.get("multi_modal_data"), } logger.warning(f"Missing rollout tokens for example {output['example_id']} step {step_idx}.") @@ -406,22 +412,71 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] active_samples.append((new_prefix, make_sample(tokens), step_idx)) - # Attach images once per sample using only the last merged step. Prompt - # tokens already contain fully expanded <|image_pad|> placeholders because - # VLMs go through MITO (chat completions), which runs apply_chat_template - # server-side; the orchestrator re-tokenizes via the same processor in the - # fallback path so features and tokens stay 1:1. - if vlm_cache is not None: - key = output["example_id"] if cache_key is None else cache_key - for _, sample, last_step_idx in active_samples: + # Attach images once per sample using only the last merged step. Two + # sources, in priority order: + # + # 1. **Renderer-emitted mm_data.** When the rollout client uses a + # multimodal-aware Renderer (e.g. Qwen3VLRenderer via RendererClient), + # each step's ``multi_modal_data`` carries the cumulative per-image + # ``pixel_values`` + ``image_grid_thw`` tensors plus placeholder + # offsets. The bridge merges previous-turn images into the new turn's + # mm_data, so the last merged step's sidecar covers every image in + # the sample. We pack those tensors into the bytes-shaped contract + # TrainingSample expects. + # + # 2. **VLMImageCache fallback.** For VLMs that go through MITO (chat + # completions), the renderer never sees the images — vLLM applies + # the chat template server-side and the orchestrator re-extracts + # PIL images from the rollout's data-URLs in a separate pass. + # Prompt tokens still contain fully expanded ``<|image_pad|>`` + # placeholders because the orchestrator re-tokenizes through the + # same processor. + # + # ``mm_token_type_ids`` is computed identically in both branches once + # the sample's prompt+completion ids are known. 
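+    #
+    # Worked example of the mapping (token ids illustrative, not the real
+    # Qwen3-VL special-token ids):
+    #   mm_token_type_ids_mapping   = {7001: 1, 7002: 2}
+    #   prompt_ids + completion_ids = [11, 7001, 7001, 42, 7002]
+    #   mm_token_type_ids           = [0, 1, 1, 0, 2]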
+ def _pack_pixel_values_from_renderer( + mm_data: Any, + ) -> tuple[bytes | None, list[int] | None, list[list[int]] | None]: + items = (mm_data.mm_items or {}).get("image") or [] + if not items: + return None, None, None + pv_tensors = [it["pixel_values"] for it in items] + thw_tensors = [it["image_grid_thw"] for it in items] + pv = torch.cat(pv_tensors, dim=0) + thw = torch.cat(thw_tensors, dim=0) + # TrainingSample wants raw float32 bytes; the trainer-side decoder + # reads them back into a tensor of pixel_values_shape on load. + return ( + pv.to(torch.float32).contiguous().numpy().tobytes(), + list(pv.shape), + thw.to(torch.int64).tolist(), + ) + + def _apply_mm_token_type_ids(sample: TrainingSample) -> None: + if mm_token_type_ids_mapping is None: + return + sample.mm_token_type_ids = [ + mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids + ] + + for _, sample, last_step_idx in active_samples: + renderer_mm = prepared_steps[last_step_idx].get("multi_modal_data") + if renderer_mm is not None: + pv, shape, grids = _pack_pixel_values_from_renderer(renderer_mm) + if pv is not None: + sample.pixel_values = pv + sample.pixel_values_shape = shape + sample.image_grid_thw = grids + _apply_mm_token_type_ids(sample) + continue + + if vlm_cache is not None: + key = output["example_id"] if cache_key is None else cache_key pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) sample.pixel_values = pv sample.pixel_values_shape = shape sample.image_grid_thw = grids - if mm_token_type_ids_mapping is not None: - sample.mm_token_type_ids = [ - mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids - ] + _apply_mm_token_type_ids(sample) return [sample for _, sample, _ in active_samples] diff --git a/tests/unit/orchestrator/test_qwen3_vl_e2e.py b/tests/unit/orchestrator/test_qwen3_vl_e2e.py new file mode 100644 index 0000000000..3319657022 --- /dev/null +++ b/tests/unit/orchestrator/test_qwen3_vl_e2e.py @@ -0,0 +1,188 @@ +"""End-to-end integration test for the Qwen3-VL renderer path. + +Walks a multimodal request through the full client stack — RendererClient +→ renderers.client.generate → /inference/v1/generate features payload — +with the HTTP layer mocked, and verifies that vLLM can deserialize the +features back into engine inputs identical to what its own server-side +processor would have produced for the same messages. + +This is the strongest end-to-end check we can run without a GPU. The +remaining missing piece (vLLM actually consuming the engine input, +sampling tokens, and returning them) is exercised in real rollouts. +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest + +_HF_CACHE = Path("~/.cache/huggingface/hub").expanduser() +_MODEL = "Qwen/Qwen3-VL-4B-Instruct" + + +def _model_cached() -> bool: + safe = "models--" + _MODEL.replace("/", "--") + snapshots = _HF_CACHE / safe / "snapshots" + if not snapshots.is_dir(): + return False + return any(p.is_dir() for p in snapshots.iterdir()) + + +pytestmark = pytest.mark.skipif( + not _model_cached(), + reason=f"{_MODEL}: HF snapshot not cached locally", +) + + +class _FakeOpenAI: + """Minimal AsyncOpenAI stand-in that captures POST bodies. + + The renderer client calls ``client.post(absolute_url, body=...)``; + we capture the body for assertions and return a canned generate + response so the parse-side of the flow runs. 
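+
+    Sketch of the captured body (abridged; values illustrative, field names
+    matching the assertions in the test below)::
+
+        {
+            "token_ids": [...],                        # fully rendered prompt ids
+            "sampling_params": {"max_tokens": 16, ...},
+            "features": {
+                "mm_hashes": {"image": ["..."]},
+                "mm_placeholders": {"image": [{"offset": ..., "length": ...}]},
+                "kwargs_data": {"image": ["<encoded MultiModalKwargsItem>"]},
+            },
+        }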
+ """ + + def __init__(self): + self.calls: list[dict[str, Any]] = [] + self.base_url = "http://fake-host:8000/v1" + + async def post(self, path, *, cast_to=dict, body=None, options=None): + self.calls.append({"path": path, "body": body, "options": options}) + # Reply with two sampled tokens + <|im_end|>. The renderer's + # parse_response slices the content tokens. + return { + "request_id": "qwen-vl-e2e", + "choices": [ + { + "index": 0, + "token_ids": [50, 60, 151645], + "logprobs": { + "content": [ + {"token": "t1", "logprob": -0.1}, + {"token": "t2", "logprob": -0.2}, + {"token": "t3", "logprob": -0.3}, + ] + }, + "finish_reason": "stop", + }, + ], + } + + +def test_renderer_client_qwen3_vl_e2e_features_payload_roundtrips_through_vllm(): + """Walk a Qwen3-VL multimodal turn through the renderer client and + verify the resulting ``/inference/v1/generate`` body has a valid + ``features`` payload that: + + 1. parses through vLLM's ``GenerateRequest`` pydantic model, + 2. decodes back to ``MultiModalKwargsItem`` instances carrying + ``pixel_values`` + ``image_grid_thw`` of the right shapes, + 3. has placeholder ranges that exactly cover the ``<|image_pad|>`` + runs in the prompt token sequence. + """ + from PIL import Image + from renderers.base import load_tokenizer + from renderers.qwen3_vl import Qwen3VLRenderer + from transformers import AutoProcessor + from verifiers.clients.renderer_client import RendererClient + from verifiers.types import ( + ClientConfig, + UserMessage, + ) + from vllm.entrypoints.serve.disagg.mm_serde import decode_mm_kwargs_item + from vllm.entrypoints.serve.disagg.protocol import GenerateRequest + + # ── Build a real Qwen3VLRenderer with a real processor. ───────────── + tokenizer = load_tokenizer(_MODEL) + processor = AutoProcessor.from_pretrained(_MODEL) + renderer = Qwen3VLRenderer(tokenizer, processor=processor) + + image_pad_id = tokenizer.convert_tokens_to_ids("<|image_pad|>") + + # ── Manually wire a RendererClient bypassing the pool factory. ────── + client_cfg = ClientConfig(client_type="renderer", base_url="http://fake-host:8000/v1") + rc = object.__new__(RendererClient) + rc._config = client_cfg + rc._renderer = renderer + rc._pool_size = 1 + rc._client = _FakeOpenAI() + rc.logger = MagicMock() + + # ── Build a verifiers-shaped user message with an image. ──────────── + img = Image.new("RGB", (224, 224), color=(64, 128, 255)) + # The renderer accepts the OpenAI ``image_url`` content-part shape — + # the same shape verifiers' UserMessage carries through. + user = UserMessage( + content=[ + {"type": "text", "text": "What's in this picture?"}, + # Embed the PIL image directly. The verifiers→renderer message + # converter forwards content unchanged for our purposes. + {"type": "image", "image": img}, + ] + ) + + # to_native_prompt converts to renderer-shaped messages. + prompt, _ = asyncio.run(rc.to_native_prompt([user])) + sampling = {"max_tokens": 16} + + response = asyncio.run( + rc.get_native_response( + prompt=prompt, + model=_MODEL, + sampling_args=sampling, + tools=None, + ) + ) + + # ── The HTTP body should carry a features payload. ────────────────── + fake = rc.client + assert isinstance(fake, _FakeOpenAI) + assert len(fake.calls) == 1 + body = fake.calls[0]["body"] + assert "features" in body, "RendererClient should ship features for image content" + features = body["features"] + + # ── Pydantic-roundtrip through vLLM's GenerateRequest model. 
──────── + gen_req = GenerateRequest( + token_ids=body["token_ids"], + features=features, + sampling_params=body["sampling_params"], + ) + assert gen_req.features is not None + assert "image" in gen_req.features.mm_hashes + assert len(gen_req.features.mm_hashes["image"]) == 1 + + # ── Placeholder anchoring: the offset/length in features must land + # exactly on a run of <|image_pad|> ids in the prompt. ─────────── + placeholders = gen_req.features.mm_placeholders["image"] + assert len(placeholders) == 1 + ph = placeholders[0] + pad_slice = body["token_ids"][ph.offset : ph.offset + ph.length] + assert all(t == image_pad_id for t in pad_slice), ( + f"placeholder span ({ph.offset}, {ph.length}) does not cover image_pad tokens; slice={pad_slice[:8]}..." + ) + + # ── kwargs_data decodes to MultiModalKwargsItem with the right keys. ─ + assert gen_req.features.kwargs_data is not None + encoded_items = gen_req.features.kwargs_data["image"] + assert len(encoded_items) == 1 + item = decode_mm_kwargs_item(encoded_items[0]) + assert set(item.keys()) == {"pixel_values", "image_grid_thw"} + + # The image_grid_thw must match what the processor would have produced + # for the same PIL image — byte-identity with the MITO path on this + # field is the strongest signal that the engine will see the same + # image features either way. + direct_proc_out = processor.image_processor(images=[img], return_tensors="pt") + expected_grid = direct_proc_out["image_grid_thw"][0].tolist() + assert item["image_grid_thw"].data.tolist() == expected_grid + + # ── Response parsed through renderer's parse_response. ────────────── + assert response["completion_ids"] == [50, 60, 151645] + # multi_modal_data surfaces on the result so the caller can persist it. + assert response["multi_modal_data"] is not None + assert len(response["multi_modal_data"].mm_items["image"]) == 1 diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 6fa169760c..b9982a398e 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -2188,3 +2188,171 @@ def test_build_vlm_image_cache_uses_store(): assert pv is not None assert shape == [1, 1] assert grid == [[1, 1, 1]] + + +# ── Renderer-emitted multimodal data ─────────────────────────────────── + + +def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): + """When the rollout's trajectory step carries renderer-emitted + ``multi_modal_data`` (e.g. from Qwen3VLRenderer via RendererClient), + ``interleave_rollout`` packs pixel_values / image_grid_thw / + mm_token_type_ids onto the TrainingSample from that sidecar — no + VLMImageCache lookup required. + + The bridge in the renderer merges previous-turn images into the new + turn's mm_data, so the last merged step's sidecar covers every image + in the sample (cumulative semantics, matching VLMImageCache). + """ + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange + + # Two synthetic single-image items — values are arbitrary, what + # matters is that the packer concatenates them correctly. 
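+    # Expected packing result (asserted at the bottom of this test):
+    #   pixel_values   -> cat(dim=0) = [[1.0, 2.0], [3.0, 4.0]]  (shape [2, 2])
+    #   image_grid_thw -> [[1, 2, 3], [1, 4, 4]]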
+ item1_pv = _torch.tensor([[1.0, 2.0]], dtype=_torch.float32) + item2_pv = _torch.tensor([[3.0, 4.0]], dtype=_torch.float32) + item1_thw = _torch.tensor([[1, 2, 3]], dtype=_torch.int64) + item2_thw = _torch.tensor([[1, 4, 4]], dtype=_torch.int64) + + mm_step_0 = MultiModalData( + mm_hashes={"image": ["h1"]}, + mm_placeholders={"image": [PlaceholderRange(offset=1, length=1)]}, + mm_items={"image": [{"pixel_values": item1_pv, "image_grid_thw": item1_thw}]}, + ) + mm_step_1 = MultiModalData( + mm_hashes={"image": ["h1", "h2"]}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=1, length=1), + PlaceholderRange(offset=4, length=1), + ] + }, + mm_items={ + "image": [ + {"pixel_values": item1_pv, "image_grid_thw": item1_thw}, + {"pixel_values": item2_pv, "image_grid_thw": item2_thw}, + ] + }, + ) + + output = vf.RolloutOutput( + example_id=1, + trajectory=[ + vf.TrajectoryStep( + prompt=[{"role": "user", "content": "Turn 1"}], + completion=[{"role": "assistant", "content": "Response 1"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4], + completion_mask=[1, 1], + completion_logprobs=[-0.1, -0.2], + overlong_prompt=False, + is_truncated=False, + multi_modal_data=mm_step_0, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + vf.TrajectoryStep( + prompt=[{"role": "user", "content": "Turn 2"}], + completion=[{"role": "assistant", "content": "Response 2"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2, 3, 4, 5], + prompt_mask=[0, 0, 0, 0, 0], + completion_ids=[6, 7], + completion_mask=[1, 1], + completion_logprobs=[-0.3, -0.4], + overlong_prompt=False, + is_truncated=False, + multi_modal_data=mm_step_1, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + ], + sampling_args={"temperature": 1.0}, + error=None, + ) + + # Token 2 is the image placeholder, token 5 is the video placeholder. + mm_mapping = {2: 1, 5: 2} + # No vlm_cache — the renderer sidecar should fully cover the path. + rollouts = interleave_rollout(output, vlm_cache=None, mm_token_type_ids_mapping=mm_mapping) + + assert rollouts is not None and len(rollouts) == 1 + sample = rollouts[0] + # Extension holds; both steps merge into one sample with the last + # step's cumulative mm_data. + assert sample.prompt_ids == [1, 2] + assert sample.completion_ids == [3, 4, 5, 6, 7] + # Pixel values packed from step 1's two items, concatenated. + assert _decode_pixels(sample.pixel_values, sample.pixel_values_shape) == [ + [1.0, 2.0], + [3.0, 4.0], + ] + assert sample.image_grid_thw == [[1, 2, 3], [1, 4, 4]] + # mm_token_type_ids: image at token 2, video at token 5, rest 0. 
+ assert sample.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] + + +def test_interleave_rollout_renderer_mm_data_wins_over_vlm_cache(): + """When both renderer mm_data AND vlm_cache are present, renderer + mm_data wins — the rollout came through a multimodal-aware renderer + so the placeholder offsets and processed tensors are authoritative.""" + import torch as _torch + from renderers.base import MultiModalData, PlaceholderRange + + renderer_pv = _torch.tensor([[7.0]], dtype=_torch.float32) + renderer_thw = _torch.tensor([[1, 9, 9]], dtype=_torch.int64) + mm = MultiModalData( + mm_hashes={"image": ["render"]}, + mm_placeholders={"image": [PlaceholderRange(offset=1, length=1)]}, + mm_items={"image": [{"pixel_values": renderer_pv, "image_grid_thw": renderer_thw}]}, + ) + + # VLM cache populated with a DIFFERENT image — we shouldn't see this. + cache_data = {1: [(*_pixels([[99.0]]), [[1, 2, 2]])]} + cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) + + output = vf.RolloutOutput( + example_id=1, + trajectory=[ + vf.TrajectoryStep( + prompt=[{"role": "user", "content": "Turn 1"}], + completion=[{"role": "assistant", "content": "Response 1"}], + response=MagicMock(), + tokens=vf.TrajectoryStepTokens( + prompt_ids=[1, 2], + prompt_mask=[0, 0], + completion_ids=[3, 4], + completion_mask=[1, 1], + completion_logprobs=[-0.1, -0.2], + overlong_prompt=False, + is_truncated=False, + multi_modal_data=mm, + ), + reward=None, + advantage=None, + is_truncated=False, + trajectory_id="1", + extras={}, + ), + ], + sampling_args={"temperature": 1.0}, + error=None, + ) + + rollouts = interleave_rollout(output, vlm_cache=cache, mm_token_type_ids_mapping={2: 1}) + + assert rollouts is not None and len(rollouts) == 1 + assert _decode_pixels(rollouts[0].pixel_values, rollouts[0].pixel_values_shape) == [[7.0]] + assert rollouts[0].image_grid_thw == [[1, 9, 9]] From b385fbf2bdb5857cf0d94396f8994509f9c09d8c Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 11 May 2026 14:39:39 +0000 Subject: [PATCH 2/6] refactor(orchestrator): rip MITO multimodal path, renderer-only for VLMs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Delete the MITO chat-completions multimodal branch from the orchestrator and the ~370 lines of image-cache/preprocess machinery in trajectories.py that supported it. VLM training now goes through the renderer path exclusively — the renderer owns the processor, ships byte-identical pixel_values to both vLLM (via /inference/v1/generate features) and the trainer (via mm_kwargs). Renderer-shipped mm_token_type_ids: the orchestrator reads the renderer's `mm_token_type_id_map` property (1=image_pad, 2=video_pad) and stamps a per-token list onto each TrainingSample. Trainer's `_get_qwen3_vl_mm_token_type_ids` auto-path remains as a fallback but the renderer is now the source of truth. forward() now takes a generic `mm_kwargs: dict` (e.g. {pixel_values, image_grid_thw}) instead of the Qwen3-VL-specific (pixel_values, image_grid_thw) keyword pair, so adding new VLM families (Gemma3, LLaVA, etc.) doesn't require touching forward. Config validator: orchestrator.use_renderer must be true when model.vlm is set — fail at config-load instead of producing cryptic runtime errors. Test cleanup: drop 25 tests for removed helpers (VLMImageCache, _extract_images, etc.); update the one remaining renderer-mm trajectory test to pass `mm_token_type_ids_mapping` directly. 
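Rough sketch of the new contract end-to-end (abridged; the decode step is
illustrative pseudocode, not a specific helper in this diff):

    # orchestrator: stamp per-token modality markers from the renderer's map
    mapping = renderer.mm_token_type_id_map        # e.g. {image_pad_id: 1, video_pad_id: 2}
    sample.mm_token_type_ids = [mapping.get(t, 0)
                                for t in sample.prompt_ids + sample.completion_ids]

    # trainer: whatever mm kwargs the sample carries go straight into forward()
    mm_kwargs = {k: decode_tensor(v) for k, v in micro_batch.mm_kwargs.items()}
    logits = forward(model, input_ids, ..., mm_kwargs=mm_kwargs)  # forward() does model(..., **mm_kwargs)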
Co-Authored-By: Claude Opus 4.7 (1M context)
---
 .../rl_color_codeword_feat_renderer.toml      |   39 +-
 .../src/prime_rl/configs/orchestrator.py      |   22 +
 src/prime_rl/orchestrator/orchestrator.py     |  105 +-
 src/prime_rl/orchestrator/trajectories.py     |  566 +-------
 src/prime_rl/trainer/batch.py                 |   10 +-
 src/prime_rl/trainer/model.py                 |   41 +-
 src/prime_rl/trainer/rl/data.py               |   48 +-
 src/prime_rl/trainer/rl/train.py              |   18 +-
 src/prime_rl/transport/types.py               |   30 +-
 tests/unit/orchestrator/test_trajectories.py  | 1280 +----------------
 10 files changed, 260 insertions(+), 1899 deletions(-)

diff --git a/configs/multimodal/rl_color_codeword_feat_renderer.toml b/configs/multimodal/rl_color_codeword_feat_renderer.toml
index 576bc409dd..68c0c66f4b 100644
--- a/configs/multimodal/rl_color_codeword_feat_renderer.toml
+++ b/configs/multimodal/rl_color_codeword_feat_renderer.toml
@@ -19,6 +19,10 @@ max_steps = 20
 seq_len = 4096
 output_dir = "outputs/rl_color_codeword_feat_renderer"
 clean_output_dir = true
+# Pure on-policy: inference can't run ahead of training, so every rollout
+# is generated from the latest policy weights. Removes async/off-policy
+# drift as a confound for the sampler-vs-trainer KL.
+max_async_level = 0

 [model]
 name = "Qwen/Qwen3-VL-4B-Instruct"
@@ -34,19 +38,38 @@ gpus_per_node = 2

 [orchestrator]
 batch_size = 16
-rollouts_per_example = 4
+rollouts_per_example = 8
 use_renderer = true
 use_token_client = false

+# Track zero-advantage groups but don't drop them — we're validating the
+# multimodal renderer path on 20 steps, not optimizing training efficiency.
+# Step 0 on Qwen3-VL-4B vs color-codeword is likely uniform (all-correct or
+# all-wrong) so enforce=True would crash before any training happens.
+[[orchestrator.filters]]
+type = "gibberish"
+
+[[orchestrator.filters]]
+type = "repetition"
+
+[[orchestrator.filters]]
+type = "zero_advantage"
+enforce = false
+
 [orchestrator.train.sampling]
 max_completion_tokens = 64

 [[orchestrator.train.env]]
 id = "color-codeword"
-args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 }
+args = { images_per_turn = 2, max_turns = 2, num_examples = 100, seed = 42 }

 [orchestrator.renderer]
 name = "auto"
+# 128 concurrent rollouts (batch_size=16 × rollouts_per_example=8) want
+# more than one tokenizer slot to avoid serialization queueing. The
+# image processor (CPU-bound) dominates for VLMs so returns diminish
+# past 4; bump to 4 as the default for multimodal runs.
+pool_size = 4

 [trainer]

@@ -59,11 +82,19 @@ lr = 3e-6

 [inference]

+[inference.model]
+# Workaround for vLLM 0.20.1 Qwen3-VL deepstack buffer bug: when num_scheduled_tokens
+# (188) gets padded up to the next cudagraph_capture_size (192), the model's
+# _set_deepstack_input_embeds sizes the buffer to 188 but forward() runs with 192,
+# triggering "Requested more deepstack tokens than available in buffer". Eager mode
+# skips the padding so num_input_tokens == num_scheduled_tokens.
+enforce_eager = true + [inference.parallel] dp = 1 tp = 1 [wandb] project = "multimodal-renderer" -name = "feat-renderer-20step" -tags = ["qwen3vl-4b", "color-codeword", "renderer", "feat-branch"] +name = "feat-renderer-20step-r8-i2-t2-onpolicy" +tags = ["qwen3vl-4b", "color-codeword", "renderer", "feat-branch", "mm-kwargs-generic", "on-policy"] diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index b466bafcc6..4ba86fe70a 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1255,6 +1255,28 @@ def validate_renderer_args(self): ) return self + @model_validator(mode="after") + def vlm_requires_renderer(self): + """VLMs (``[model.vlm]`` block set) must go through the renderer. + + The MITO path for VLMs (chat-completions + server-side image + stripping + orchestrator-side AutoProcessor + VLMImageCache) was + removed: it duplicated processor work, hardcoded a Qwen-VL + tensor schema, and produced a token stream the trainer could + only reconstruct because the orchestrator re-tokenized through + the same processor. The renderer path owns the processor + per-slot, produces byte-identical tokens, and ships generic + ``mm_kwargs`` keyed by whatever the model's forward signature + expects. + """ + if self.model.vlm is not None and not self.use_renderer: + raise ValueError( + "orchestrator.use_renderer must be true when model.vlm is set. " + "The MITO path for VLMs has been removed; VLMs must go through " + "a renderer (e.g. Qwen3VLRenderer) that owns the processor." + ) + return self + @model_validator(mode="after") def nccl_max_async_level(self): if self.weight_broadcast.type == "nccl": diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index bc1128ebc7..87a4253835 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -12,7 +12,6 @@ from prime_rl.orchestrator.inference_metrics import InferenceMetricsCollector from prime_rl.orchestrator.patches import monkey_patch_chat_completion_logprobs, monkey_patch_oai_iterable_types from prime_rl.orchestrator.trajectories import ( - build_vlm_image_cache, interleave_rollout, offload_images_to_disk, pretokenize_rollout_trajectory, @@ -33,7 +32,6 @@ import pandas as pd import verifiers as vf from renderers.base import create_renderer -from transformers import AutoProcessor from prime_rl.configs.orchestrator import OrchestratorConfig from prime_rl.orchestrator.buffer import Buffer @@ -134,20 +132,9 @@ async def orchestrate(config: OrchestratorConfig): else: teacher_inference_pool = None - # Check if this is a vision-language model (used throughout for VLM-specific paths) - is_vlm = config.model.vlm is not None - - # Load tokenizer and processor (processor only for VLM models) logger.info(f"Initializing tokenizer ({config.tokenizer})") tokenizer = setup_tokenizer(config.tokenizer) - processor = None - if is_vlm: - logger.info(f"Loading VLM processor for {config.model.name}") - processor = AutoProcessor.from_pretrained( - config.model.name, trust_remote_code=config.model.trust_remote_code, use_fast=True - ) - renderer, inference_pool = await setup_rollout_inference_pool( config=config, rollout_client_config=rollout_client_config, @@ -156,6 +143,18 @@ async def orchestrate(config: OrchestratorConfig): logger=logger, ) + # Token-id → modality marker (1 = image patch, 2 = video patch) 
used + # to build ``mm_token_type_ids`` per sample. The renderer is the + # single source of truth — it already knows its own special-token + # IDs (``<|image_pad|>`` etc.) from the tokenizer it owns, so the + # orchestrator never needs to load a separate ``AutoProcessor``. + # Text-only renderers expose an empty map (or no attribute). + mm_token_type_ids_mapping: dict[int, int] | None = ( + getattr(renderer, "mm_token_type_id_map", None) if renderer is not None else None + ) + if mm_token_type_ids_mapping == {}: + mm_token_type_ids_mapping = None + # Setup monitor (may register the run and set RUN_ID in the environment) logger.info(f"Initializing monitor (wandb={config.wandb}, prime_monitor={config.prime_monitor})") monitor = setup_monitor( @@ -470,69 +469,41 @@ async def orchestrate(config: OrchestratorConfig): save_rollouts, train_rollouts, step_path / "train_rollouts.jsonl", exclude_keys={"trajectory"} ) - # VLM: offload base64 images to disk immediately to free memory - if is_vlm: - offload_start = time.perf_counter() - num_offloaded = offload_images_to_disk(train_rollouts, config.output_dir) - if num_offloaded: - logger.info( - f"VLM offloaded {num_offloaded} unique images to disk in {time.perf_counter() - offload_start:.2f}s" - ) + # Offload base64 images to disk to free memory. No-op for text-only + # rollouts (no ``data:image`` URLs to find); cheap to call always. + offload_start = time.perf_counter() + num_offloaded = offload_images_to_disk(train_rollouts, config.output_dir) + if num_offloaded: + logger.info( + f"Offloaded {num_offloaded} unique images to disk in {time.perf_counter() - offload_start:.2f}s" + ) # Convert rollouts to training samples parallel_preprocess_start = time.perf_counter() - # Stage 1: pretokenize + (for VLM) build image cache concurrently. # Pretokenize is a no-op when the renderer client already populated - # `tokens` on each trajectory step, but the fallback-tokenizer path - # and image-cache build are both CPU-heavy. Running them on threads - # and awaiting a single gather lets whichever finishes first free - # the event loop immediately and, with max_async_level >= 2, overlaps - # this whole stage with inference for the next batch. - async def _pretokenize_all() -> None: - await asyncio.gather( - *( - asyncio.to_thread( - pretokenize_rollout_trajectory, - rollout, - tokenizer, - processor=processor, - renderer=renderer, - ) - for rollout in train_rollouts + # ``tokens`` on each trajectory step (renderer path); the fallback + # tokenizer-only branch handles text-only rollouts whose tokens + # were not pre-rendered. Run on threads so CPU work overlaps with + # inference for the next batch (via max_async_level >= 2). 
+ await asyncio.gather( + *( + asyncio.to_thread( + pretokenize_rollout_trajectory, + rollout, + tokenizer, + renderer=renderer, ) + for rollout in train_rollouts ) - - if is_vlm: - mm_token_type_ids_mapping = {} - if hasattr(processor, "image_token_id") and processor.image_token_id is not None: - mm_token_type_ids_mapping[processor.image_token_id] = 1 - if hasattr(processor, "video_token_id") and processor.video_token_id is not None: - mm_token_type_ids_mapping[processor.video_token_id] = 2 - _, vlm_cache = await asyncio.gather( - _pretokenize_all(), - asyncio.to_thread(build_vlm_image_cache, train_rollouts, processor), - ) - logger.info( - f"VLM timing: extract={vlm_cache.extract_time:.2f}s, preprocess={vlm_cache.preprocess_time:.2f}s " - f"({vlm_cache.num_unique_images} unique images from {vlm_cache.num_unique_examples} examples)" - ) - else: - await _pretokenize_all() - vlm_cache = None - mm_token_type_ids_mapping = None + ) # Process rollouts in parallel - def process_rollout(rollout: vf.RolloutOutput, rollout_idx: int) -> list[TrainingSample] | None: - return interleave_rollout( - rollout, - vlm_cache=vlm_cache, - cache_key=rollout_idx, - mm_token_type_ids_mapping=mm_token_type_ids_mapping, - ) - results = await asyncio.gather( - *(asyncio.to_thread(process_rollout, r, rollout_idx) for rollout_idx, r in enumerate(train_rollouts)) + *( + asyncio.to_thread(interleave_rollout, r, mm_token_type_ids_mapping=mm_token_type_ids_mapping) + for r in train_rollouts + ) ) # Collect results and assign advantages. Metrics are computed over all @@ -794,7 +765,7 @@ def compute_solve_rates(df): is_first_step = False # Free large per-step objects to prevent memory accumulation - del train_rollouts, train_examples, training_batch, vlm_cache + del train_rollouts, train_examples, training_batch del results_df, metrics_df gc.collect() diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index ef69a4ffe2..28613a2174 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -1,14 +1,10 @@ import base64 import hashlib -import time -from concurrent.futures import ThreadPoolExecutor -from io import BytesIO from pathlib import Path from typing import Any import torch import verifiers as vf -from PIL import Image from transformers.tokenization_utils import PreTrainedTokenizer from prime_rl.transport import TrainingSample @@ -67,62 +63,19 @@ def _render_messages( messages: list[dict[str, Any]], add_generation_prompt: bool = False, tools: list[dict[str, Any]] | None = None, - processor=None, ) -> list[int]: return render_messages( tokenizer, messages, add_generation_prompt=add_generation_prompt, tools=tools, - processor=processor, ) -def _prepare_messages_for_processor(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Convert messages to the format expected by the VLM processor. - - - Converts image_url items to image items with loaded PIL Images - - Strips extra fields (e.g. 
image_url on text items) that confuse the processor - - Ensures all message content is in list format (processor requires this) - """ - prepared = [] - for msg in messages: - content = msg.get("content") - if isinstance(content, str): - prepared.append({**msg, "content": [{"type": "text", "text": content}]}) - continue - - if not isinstance(content, list): - prepared.append(msg) - continue - - new_content = [] - for item in content: - if item.get("type") == "image_url": - url = item.get("image_url", {}).get("url", "") - if url.startswith(_FILE_URL_PREFIX): - img = _load_file_image(url) - elif url.startswith("data:image"): - b64_data = url.split(",", 1)[1] - img = Image.open(BytesIO(base64.b64decode(b64_data))) - else: - new_content.append(item) - continue - new_content.append({"type": "image", "image": img}) - elif item.get("type") == "text": - new_content.append({"type": "text", "text": item.get("text", "")}) - else: - new_content.append(item) - prepared.append({**msg, "content": new_content}) - - return prepared - - def _tokenize_step_from_messages( step: vf.TrajectoryStep, tokenizer: PreTrainedTokenizer, tools: list[dict[str, Any]] | None = None, - processor=None, ) -> dict[str, Any]: prompt = _normalize_messages(step.get("prompt"), default_role="user") completion = _normalize_messages(step.get("completion"), default_role="assistant") @@ -135,10 +88,6 @@ def _tokenize_step_from_messages( f"got roles: {[m.get('role') for m in completion]}" ) - if processor is not None: - prompt = _prepare_messages_for_processor(prompt) - completion = _prepare_messages_for_processor(completion) - all_messages = prompt + completion prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant" prompt_ids = _render_messages( @@ -146,13 +95,11 @@ def _tokenize_step_from_messages( prompt, add_generation_prompt=prompt_has_assistant_completion, tools=tools, - processor=processor, ) full_ids = _render_messages( tokenizer, all_messages, tools=tools, - processor=processor, ) split_idx = _common_prefix_len(prompt_ids, full_ids) @@ -217,7 +164,6 @@ def _tokenize_step_with_renderer( def pretokenize_rollout_trajectory( output: vf.RolloutOutput, tokenizer: PreTrainedTokenizer, - processor=None, renderer=None, ) -> bool: """Populate missing step tokens from prompt/completion messages. @@ -235,7 +181,7 @@ def pretokenize_rollout_trajectory( if renderer is not None: step["tokens"] = _tokenize_step_with_renderer(step, renderer, tools=tools) else: - reconstructed = _tokenize_step_from_messages(step, tokenizer, tools=tools, processor=processor) + reconstructed = _tokenize_step_from_messages(step, tokenizer, tools=tools) if reconstructed["prompt_prefix_len"] < reconstructed["original_prompt_len"]: logger.debug( f"Prompt tokenization was non-prefix for example {output['example_id']} step {step_idx}. " @@ -251,8 +197,6 @@ def pretokenize_rollout_trajectory( def interleave_rollout( output: vf.RolloutOutput, - vlm_cache: "VLMImageCache | None" = None, - cache_key: int | None = None, mm_token_type_ids_mapping: dict[int, int] | None = None, ) -> list[TrainingSample] | None: """ @@ -270,13 +214,9 @@ def interleave_rollout( Returns a list of samples - could be 1 (extension always held) or up to T (extension never held). - For VLM models, pass vlm_cache to attach cumulative pixel_values per sample. - Each sample gets the images accumulated up to its last merged step. 
- - Args: - output: vf.RolloutOutput containing trajectory data - vlm_cache: Pre-computed VLM image cache for multimodal training - cache_key: Cache key to use when retrieving images from the VLM cache + For VLM models, each renderer-produced trajectory step carries its + per-image processed tensors inline on ``multi_modal_data``; the last + merged step's sidecar covers every image in the sample. """ logger = get_logger() @@ -306,8 +246,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any # Renderer-emitted multimodal sidecar (placeholders + per-item # processed tensors). Populated when the rollout went through # a multimodal-aware renderer (e.g. Qwen3VLRenderer); absent - # for text-only or VLM-via-MITO rollouts (those fall back to - # the vlm_cache path below). + # for text-only rollouts. "multi_modal_data": tokens.get("multi_modal_data"), } @@ -412,78 +351,73 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] active_samples.append((new_prefix, make_sample(tokens), step_idx)) - # Attach images once per sample using only the last merged step. Two - # sources, in priority order: - # - # 1. **Renderer-emitted mm_data.** When the rollout client uses a - # multimodal-aware Renderer (e.g. Qwen3VLRenderer via RendererClient), - # each step's ``multi_modal_data`` carries the cumulative per-image - # ``pixel_values`` + ``image_grid_thw`` tensors plus placeholder - # offsets. The bridge merges previous-turn images into the new turn's - # mm_data, so the last merged step's sidecar covers every image in - # the sample. We pack those tensors into the bytes-shaped contract - # TrainingSample expects. - # - # 2. **VLMImageCache fallback.** For VLMs that go through MITO (chat - # completions), the renderer never sees the images — vLLM applies - # the chat template server-side and the orchestrator re-extracts - # PIL images from the rollout's data-URLs in a separate pass. - # Prompt tokens still contain fully expanded ``<|image_pad|>`` - # placeholders because the orchestrator re-tokenizes through the - # same processor. - # - # ``mm_token_type_ids`` is computed identically in both branches once - # the sample's prompt+completion ids are known. - def _pack_pixel_values_from_renderer( - mm_data: Any, - ) -> tuple[bytes | None, list[int] | None, list[list[int]] | None]: - items = (mm_data.mm_items or {}).get("image") or [] - if not items: - return None, None, None - pv_tensors = [it["pixel_values"] for it in items] - thw_tensors = [it["image_grid_thw"] for it in items] - pv = torch.cat(pv_tensors, dim=0) - thw = torch.cat(thw_tensors, dim=0) - # TrainingSample wants raw float32 bytes; the trainer-side decoder - # reads them back into a tensor of pixel_values_shape on load. - return ( - pv.to(torch.float32).contiguous().numpy().tobytes(), - list(pv.shape), - thw.to(torch.int64).tolist(), - ) - - def _apply_mm_token_type_ids(sample: TrainingSample) -> None: - if mm_token_type_ids_mapping is None: - return - sample.mm_token_type_ids = [ - mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids - ] - + # Attach images once per sample using only the last merged step's + # renderer-emitted mm_data. The bridge merges previous-turn images + # into the new turn's mm_data so the last step's sidecar covers + # every image in the sample. 
for _, sample, last_step_idx in active_samples: renderer_mm = prepared_steps[last_step_idx].get("multi_modal_data") if renderer_mm is not None: - pv, shape, grids = _pack_pixel_values_from_renderer(renderer_mm) - if pv is not None: - sample.pixel_values = pv - sample.pixel_values_shape = shape - sample.image_grid_thw = grids - _apply_mm_token_type_ids(sample) - continue - - if vlm_cache is not None: - key = output["example_id"] if cache_key is None else cache_key - pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) - sample.pixel_values = pv - sample.pixel_values_shape = shape - sample.image_grid_thw = grids - _apply_mm_token_type_ids(sample) + mm_kwargs = _pack_mm_kwargs_from_renderer(renderer_mm) + if mm_kwargs is not None: + sample.mm_kwargs = mm_kwargs + # ``mm_token_type_ids``: 1 for image-placeholder tokens, 2 for + # video, 0 otherwise. The trainer has an auto-compute path + # (``_get_qwen3_vl_mm_token_type_ids``) but empirically the + # orchestrator-shipped explicit list produces ~30x lower + # mismatch KL on color-codeword. Shipping explicit until the + # divergence is understood. + if mm_token_type_ids_mapping is not None: + sample.mm_token_type_ids = [ + mm_token_type_ids_mapping.get(token_id, 0) + for token_id in sample.prompt_ids + sample.completion_ids + ] return [sample for _, sample, _ in active_samples] -# ============================================================================= -# VLM-specific functions -# ============================================================================= +def _pack_mm_kwargs_from_renderer(mm_data: Any) -> "dict[str, Any] | None": + """Batch the renderer's per-image ``mm_items`` into model-agnostic + forward kwargs. + + ``mm_data`` may arrive as a ``MultiModalData`` instance (in-process + for tests) or as a plain dict (after msgpack round-trip from the + env-worker). Each item is a dict keyed by the names the model's + ``forward`` expects (``pixel_values`` + ``image_grid_thw`` for + Qwen3-VL, just ``pixel_values`` for Gemma3-VL, etc.). We batch by + ``torch.cat(..., dim=0)`` per key — generic because every HF VLM + processor emits a leading batch/patch dimension, and the renderer + always processes one image per call. + + Returns a dict of ``EncodedTensor`` payloads keyed by kwarg name, + or ``None`` when no multimodal data is present. + """ + from verifiers.utils.serve_utils import decode_tensor_payload + + from prime_rl.transport.types import EncodedTensor + + mm_items = mm_data.mm_items if hasattr(mm_data, "mm_items") else (mm_data or {}).get("mm_items") or {} + # Flatten across modalities into one kwarg dict — the model's + # forward signature is the schema. ``mm_items`` is typically + # ``{"image": [...], "video": [...]}`` but each modality's keys + # don't collide for any HF VLM we ship today. 
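+    #
+    # Shape-level sketch (patch counts hypothetical, D = per-patch feature dim):
+    #   mm_items  = {"image": [{"pixel_values": (4, D),  "image_grid_thw": (1, 3)},
+    #                          {"pixel_values": (16, D), "image_grid_thw": (1, 3)}]}
+    #   -> output = {"pixel_values":   EncodedTensor(shape=[20, D], ...),
+    #                "image_grid_thw": EncodedTensor(shape=[2, 3], ...)}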
+ per_kwarg: dict[str, list] = {} + for _modality, items in mm_items.items(): + for item in items or []: + for key, payload in item.items(): + per_kwarg.setdefault(key, []).append(decode_tensor_payload(payload)) + if not per_kwarg: + return None + out: dict[str, EncodedTensor] = {} + for key, tensors in per_kwarg.items(): + cat = torch.cat(tensors, dim=0).contiguous() + arr = cat.detach().cpu().numpy() + out[key] = EncodedTensor( + dtype=str(arr.dtype), + shape=list(arr.shape), + data=arr.tobytes(), + ) + return out _FILE_URL_PREFIX = "file://" @@ -529,373 +463,3 @@ def offload_images_to_disk(rollouts: list[vf.RolloutOutput], output_dir: Path) - item["image_url"]["url"] = f"{_FILE_URL_PREFIX}{path}" return len(written) - - -def _load_file_image(path_str: str) -> Image.Image: - """Load an image from a file:// path.""" - return Image.open(path_str.removeprefix(_FILE_URL_PREFIX)) - - -def _extract_images_from_messages(messages: list) -> list[tuple[Image.Image, str]]: - """Extract (image, key) pairs from OpenAI-style chat messages. - - Handles both base64 data URLs and file:// paths from disk offloading. - """ - images = [] - if not messages or not isinstance(messages, list): - return images - - for msg in messages: - content = msg.get("content", []) - if isinstance(content, list): - for item in content: - if item.get("type") == "image_url": - url = item.get("image_url", {}).get("url", "") - if url.startswith(_FILE_URL_PREFIX): - img = _load_file_image(url) - images.append((img, url)) - elif url.startswith("data:image"): - b64_data = url.split(",", 1)[1] - img_bytes = base64.b64decode(b64_data) - img = Image.open(BytesIO(img_bytes)) - images.append((img, b64_data)) - return images - - -def _collect_image_keys_from_messages(messages: list) -> list[str]: - """Extract image keys from OpenAI-style chat messages without decoding. - - Handles both base64 data URLs and file:// paths from disk offloading. - """ - keys = [] - if not messages or not isinstance(messages, list): - return keys - for msg in messages: - content = msg.get("content", []) - if isinstance(content, list): - for item in content: - if item.get("type") == "image_url": - url = item.get("image_url", {}).get("url", "") - if url.startswith("data:image"): - keys.append(url.split(",", 1)[1]) - elif url.startswith(_FILE_URL_PREFIX): - keys.append(url) - return keys - - -def _decode_image(key: str) -> Image.Image: - """Decode an image from a base64 string or load from a file:// path.""" - if key.startswith(_FILE_URL_PREFIX): - return _load_file_image(key) - return Image.open(BytesIO(base64.b64decode(key))) - - -_PARALLEL_DECODE_THRESHOLD = 4 - - -_IMAGE_STRIPPED_PLACEHOLDER = "[preprocessed image]" - - -def strip_base64_images(examples: list[tuple[int, vf.RolloutOutput]]) -> None: - """Strip image data from rollout prompts to free memory. - - Handles both base64 data URLs and file:// paths from disk offloading. - The images have been decoded and indexed; the original data is no longer needed. 
- """ - for _, output in examples: - for step in output.get("trajectory", []): - prompt = step.get("prompt") - if not prompt or not isinstance(prompt, list): - continue - for msg in prompt: - content = msg.get("content", []) - if isinstance(content, list): - for item in content: - if item.get("type") == "image_url": - url = item.get("image_url", {}).get("url", "") - if url.startswith("data:image") or url.startswith(_FILE_URL_PREFIX): - item["image_url"]["url"] = _IMAGE_STRIPPED_PLACEHOLDER - - -def _extract_images_from_examples( - examples: list[tuple[int, vf.RolloutOutput]], -) -> tuple[list[Image.Image], dict[int, list[list[int]]]]: - """ - Extract images from all trajectory steps of each example. - - Two-pass approach: first collects unique base64 keys (fast, string-only), - then decodes unique images in parallel via ThreadPoolExecutor. - - Args: - examples: List of (cache_key, output) tuples where output contains a "trajectory" - list with steps that have "prompt" messages in OpenAI chat format. - - Returns: - Tuple of (all_images, step_image_indices_per_example) - - all_images: deduplicated flat list of decoded PIL images - - step_image_indices_per_example: dict mapping cache_key to per-step lists of - indices into all_images (e.g., [[0], [0, 1], [1]] for the decreasing-images case) - """ - # Pass 1: collect unique b64 keys and build step indices - unique_keys: list[str] = [] - key_to_index: dict[str, int] = {} - step_image_indices_per_example: dict[int, list[list[int]]] = {} - - for eid, output in examples: - trajectory = output.get("trajectory", []) - if not trajectory: - step_image_indices_per_example[eid] = [] - continue - - step_image_indices = [] - for step in trajectory: - prompt = step.get("prompt") - image_keys = _collect_image_keys_from_messages(prompt) - indices = [] - for key in image_keys: - if key not in key_to_index: - key_to_index[key] = len(unique_keys) - unique_keys.append(key) - indices.append(key_to_index[key]) - step_image_indices.append(indices) - - step_image_indices_per_example[eid] = step_image_indices - - # Pass 2: decode unique images (parallel when worthwhile) - if len(unique_keys) > _PARALLEL_DECODE_THRESHOLD: - with ThreadPoolExecutor(max_workers=min(len(unique_keys), 16)) as pool: - all_images = list(pool.map(_decode_image, unique_keys)) - else: - all_images = [_decode_image(k) for k in unique_keys] - del unique_keys, key_to_index - - strip_base64_images(examples) - - return all_images, step_image_indices_per_example - - -_DEFAULT_IMAGE_CHUNK_SIZE = 32 - - -class _ImageStore: - """Holds per-unique-image data, assembled lazily on demand. - - Instead of duplicating pixel bytes for every step that references an image, - we store each image's bytes once and assemble the concatenation at retrieval time. - """ - - def __init__( - self, - image_bytes: list[bytes], - image_num_patches: list[int], - patch_dim: int, - image_grids: list[list[int]], - ): - self.image_bytes = image_bytes - self.image_num_patches = image_num_patches - self.patch_dim = patch_dim - self.image_grids = image_grids - self._cache: dict[tuple[int, ...], tuple[bytes, list[int], list[list[int]]]] = {} - - def assemble(self, indices: list[int]) -> tuple[bytes, list[int], list[list[int]]]: - """Assemble pixel bytes, shape, and grids for a set of image indices. - - Results are cached by index tuple — multi-turn rollouts with the same - cumulative image set (common across rollouts of the same example) hit - the cache and skip the join. 
- """ - cache_key = tuple(indices) - cached = self._cache.get(cache_key) - if cached is not None: - return cached - - total_patches = sum(self.image_num_patches[i] for i in indices) - pixel_bytes = b"".join(self.image_bytes[i] for i in indices) - shape = [total_patches, self.patch_dim] - grids = [self.image_grids[i] for i in indices] - result = (pixel_bytes, shape, grids) - self._cache[cache_key] = result - return result - - -def _preprocess_images_batched( - images: list[Image.Image], - step_image_indices_per_example: dict[int, list[list[int]]], - processor, - chunk_size: int = _DEFAULT_IMAGE_CHUNK_SIZE, -) -> tuple["_ImageStore | None", dict[int, list[list[int]]]]: - """ - Preprocess all images in chunked batches, returning an _ImageStore and step indices. - - Images are processed in chunks to avoid OOM on large batches. Per-image bytes are - stored once in the _ImageStore and assembled lazily at retrieval time. - - Returns: - Tuple of (_ImageStore or None, step_image_indices_per_example). - The store is None when there are no images or no processor. - """ - if not images or processor is None: - return None, step_image_indices_per_example - - logger = get_logger() - image_sizes = [(img.width, img.height) for img in images] - - # Process images in chunks to avoid OOM, parallelized across threads - # (PIL/numpy release the GIL so threads give real concurrency here) - chunks = [images[i : i + chunk_size] for i in range(0, len(images), chunk_size)] - - def _process_chunk(chunk: list[Image.Image]) -> tuple[torch.Tensor, torch.Tensor]: - processed = processor.image_processor(images=chunk, return_tensors="pt") - return processed["pixel_values"], processed["image_grid_thw"] - - if len(chunks) > 1: - with ThreadPoolExecutor(max_workers=min(len(chunks), 8)) as pool: - results = list(pool.map(_process_chunk, chunks)) - else: - results = [_process_chunk(chunks[0])] - - # Free PIL images now that preprocessing is done - del chunks - images.clear() - - all_pixel_values_list = [r[0] for r in results] - all_grid_thw_list = [r[1] for r in results] - - all_pixel_values = torch.cat(all_pixel_values_list, dim=0) - all_grid_thw = torch.cat(all_grid_thw_list, dim=0) - del all_pixel_values_list, all_grid_thw_list, results - - logger.debug( - f"VLM image processing: {len(image_sizes)} images, sizes={image_sizes}, " - f"pixel_values={all_pixel_values.shape}, grid_thw={all_grid_thw.tolist()}" - ) - - # Pre-compute patch start offset for each image - patch_starts = [0] - for g in all_grid_thw: - patch_starts.append(patch_starts[-1] + int(g[0] * g[1] * g[2])) - - patch_dim = all_pixel_values.shape[1] - - # Convert to bytes per-image and free the tensor immediately after - image_bytes_list: list[bytes] = [] - image_num_patches_list: list[int] = [] - image_grids_list: list[list[int]] = [] - for i in range(len(image_sizes)): - img_slice = all_pixel_values[patch_starts[i] : patch_starts[i + 1]] - image_bytes_list.append(img_slice.numpy().tobytes()) - image_num_patches_list.append(img_slice.shape[0]) - image_grids_list.append(all_grid_thw[i].tolist()) - del all_pixel_values, all_grid_thw - - store = _ImageStore( - image_bytes=image_bytes_list, - image_num_patches=image_num_patches_list, - patch_dim=patch_dim, - image_grids=image_grids_list, - ) - - return store, step_image_indices_per_example - - -class VLMImageCache: - """Result of building VLM image cache with per-step image data.""" - - def __init__( - self, - cache: dict[int, list[tuple[bytes | None, list[int] | None, list[list[int]] | None]]], - 
num_unique_examples: int, - extract_time: float, - preprocess_time: float, - ): - self._store: _ImageStore | None = None - self._step_indices: dict[int, list[list[int]]] | None = None - self.cache = cache - self.num_unique_examples = num_unique_examples - self.num_unique_images = 0 - self.extract_time = extract_time - self.preprocess_time = preprocess_time - - @classmethod - def from_store( - cls, - store: _ImageStore | None, - step_indices: dict[int, list[list[int]]], - num_unique_examples: int, - num_unique_images: int, - extract_time: float, - preprocess_time: float, - ) -> "VLMImageCache": - """Create a store-backed cache that assembles bytes lazily.""" - obj = cls.__new__(cls) - obj._store = store - obj._step_indices = step_indices - obj.cache = {} - obj.num_unique_examples = num_unique_examples - obj.num_unique_images = num_unique_images - obj.extract_time = extract_time - obj.preprocess_time = preprocess_time - return obj - - def _assemble(self, indices: list[int]) -> tuple[bytes | None, list[int] | None, list[list[int]] | None]: - if not indices: - return (None, None, None) - return self._store.assemble(indices) - - def get_for_step( - self, cache_key: int, step_idx: int - ) -> tuple[bytes | None, list[int] | None, list[list[int]] | None]: - """Get cumulative images up to and including the given step.""" - if self._store is not None: - steps = self._step_indices.get(cache_key, []) - if not steps or step_idx >= len(steps): - return (None, None, None) - return self._assemble(steps[step_idx]) - - steps = self.cache.get(cache_key, []) - if not steps or step_idx >= len(steps): - return (None, None, None) - return steps[step_idx] - - def get_all(self, cache_key: int) -> tuple[bytes | None, list[int] | None, list[list[int]] | None]: - """Get all images for the cache key (last step's cumulative images).""" - if self._store is not None: - steps = self._step_indices.get(cache_key, []) - if not steps: - return (None, None, None) - return self._assemble(steps[-1]) - - steps = self.cache.get(cache_key, []) - if not steps: - return (None, None, None) - return steps[-1] - - -def build_vlm_image_cache(rollouts: list[vf.RolloutOutput], processor) -> VLMImageCache: - """ - Build image cache for VLM training by extracting and preprocessing images. - - Caches per rollout to keep images aligned with divergent multi-turn trajectories. 
- """ - examples = [(idx, rollout) for idx, rollout in enumerate(rollouts)] - unique_example_ids = {(rollout.get("env_name"), rollout["example_id"]) for rollout in rollouts} - - # Extract images (also strips base64 data from rollout prompts to free memory) - extract_start = time.perf_counter() - all_images, images_per_example = _extract_images_from_examples(examples) - num_unique_images = len(all_images) - extract_time = time.perf_counter() - extract_start - - # Preprocess images (clears PIL image list when done) - preprocess_start = time.perf_counter() - store, step_indices = _preprocess_images_batched(all_images, images_per_example, processor) - preprocess_time = time.perf_counter() - preprocess_start - - return VLMImageCache.from_store( - store=store, - step_indices=step_indices, - num_unique_examples=len(unique_example_ids), - num_unique_images=num_unique_images, - extract_time=extract_time, - preprocess_time=preprocess_time, - ) diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 662df36a80..8247f66996 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -72,17 +72,17 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch temperatures=temperatures, routed_experts=routed_experts, mm_token_type_ids=mm_token_type_ids, - # Multimodal fields (Qwen3-VL) - passed through without modification - pixel_values=training_example.pixel_values, - pixel_values_shape=training_example.pixel_values_shape, - image_grid_thw=training_example.image_grid_thw, + # Generic multimodal kwargs (e.g. {"pixel_values": ..., "image_grid_thw": + # ...}) — passed straight through; the trainer ``**`` -unpacks into the + # model's forward signature so prime-rl stays model-agnostic. + mm_kwargs=training_example.mm_kwargs, sft_loss=training_example.sft_loss, ) def _is_multimodal_sample(sample: MicroBatch) -> bool: """Check if a sample contains multimodal data (images).""" - return sample.pixel_values is not None + return sample.mm_kwargs is not None def packed_samples_into_micro_bs( diff --git a/src/prime_rl/trainer/model.py b/src/prime_rl/trainer/model.py index 910a978a66..1e270c68c8 100644 --- a/src/prime_rl/trainer/model.py +++ b/src/prime_rl/trainer/model.py @@ -12,9 +12,8 @@ import torch import torch._dynamo import torch.nn as nn -from beartype import beartype as typechecker from huggingface_hub import snapshot_download -from jaxtyping import Float, Int, jaxtyped +from jaxtyping import Int from torch import Tensor from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper from torch.distributed.checkpoint.hf_storage import HuggingFaceStorageReader @@ -1118,7 +1117,6 @@ def _get_qwen3_vl_mm_token_type_ids(model: nn.Module, input_ids: Tensor) -> Tens return mm_token_type_ids -@jaxtyped(typechecker=typechecker) def forward( model: nn.Module, input_ids: Int[Tensor, "batch seq"], @@ -1126,9 +1124,13 @@ def forward( labels: Int[Tensor, "batch seq"] | None = None, temperature: Tensor | None = None, routed_experts: Int[Tensor, "batch seq layers topk"] | None = None, - # Multimodal fields (Qwen3-VL) - pixel_values: Float[Tensor, "num_patches patch_dim"] | None = None, - image_grid_thw: Int[Tensor, "num_images 3"] | None = None, + # Generic multimodal kwargs (e.g. {"pixel_values": ..., + # "image_grid_thw": ...} for Qwen3-VL; just {"pixel_values": ...} + # for Gemma3). Passed straight through to ``model(**kwargs)`` so + # the model's HF forward signature is the schema. 
``mm_token_type_ids`` + # is split out because it's prime-rl-computed (from token ids), + # not a renderer/processor output. + mm_kwargs: dict[str, Tensor] | None = None, mm_token_type_ids: Int[Tensor, "batch seq"] | None = None, ) -> PrimeLmOutput: # Build kwargs for model forward @@ -1138,15 +1140,24 @@ def forward( "temperature": temperature, } - # For multimodal (VLM), don't pass position_ids - let the model compute MRoPE internally - # using image_grid_thw. Qwen3-VL only computes proper MRoPE when position_ids is None. - if pixel_values is not None: - assert image_grid_thw is not None, "pixel_values requires image_grid_thw for MRoPE computation" - kwargs["pixel_values"] = pixel_values - kwargs["image_grid_thw"] = image_grid_thw - mm_token_type_ids = _get_qwen3_vl_mm_token_type_ids(model, input_ids) - if mm_token_type_ids is not None: - kwargs["mm_token_type_ids"] = mm_token_type_ids + if mm_kwargs: + # Forward the per-model multimodal tensors verbatim. For Qwen-VL + # specifically, ``position_ids`` must be ``None`` for MRoPE to + # compute correct 3D positions from ``image_grid_thw``; this + # special-case is the only family-specific branch left. + kwargs.update(mm_kwargs) + if "image_grid_thw" in mm_kwargs: + mm_token_type_ids_auto = _get_qwen3_vl_mm_token_type_ids(model, input_ids) + if mm_token_type_ids_auto is not None: + kwargs["mm_token_type_ids"] = mm_token_type_ids_auto + elif mm_token_type_ids is not None: + kwargs["mm_token_type_ids"] = mm_token_type_ids + # Skip position_ids — MRoPE in Qwen-VL recomputes them from grid_thw. + else: + # Other VLM families (Gemma3, LLaVA, ...) still want position_ids. + kwargs["position_ids"] = position_ids + if mm_token_type_ids is not None: + kwargs["mm_token_type_ids"] = mm_token_type_ids else: kwargs["position_ids"] = position_ids diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index ffc4bc627f..9c189c729e 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -31,11 +31,12 @@ class TensorMicroBatch(TypedDict): # MoE router replay routed_experts: Int[Tensor, "batch seq layers topk"] | None - # Multimodal fields (Qwen3-VL) - # pixel_values: flattened image patches [num_patches, patch_dim] where patch_dim=1176 for Qwen3-VL - pixel_values: Float[Tensor, "num_patches patch_dim"] | None - # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] - image_grid_thw: Int[Tensor, "num_images 3"] | None + # Generic multimodal kwargs — flat dict matching the model's forward + # signature (e.g. ``{"pixel_values": ..., "image_grid_thw": ...}`` for + # Qwen3-VL; ``{"pixel_values": ...}`` for Gemma3-VL). The trainer + # ``**`` -unpacks this into the forward call, so any HF VLM whose + # processor and forward agree on kwarg names works out of the box. 
+ mm_kwargs: dict[str, Tensor] | None # mm_token_type_ids: token type per token [batch seq], int64 (0=text, 1=image, 2=video) mm_token_type_ids: Int[Tensor, "batch seq"] | None @@ -111,8 +112,7 @@ def _get_sample_micro_batch(self, generator: torch.Generator) -> TensorMicroBatc "loss_mask": loss_mask.unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, - "pixel_values": None, - "image_grid_thw": None, + "mm_kwargs": None, "mm_token_type_ids": None, "sft_loss": False, } @@ -138,8 +138,7 @@ def _get_micro_batch(self, generator: torch.Generator) -> TensorMicroBatch: "loss_mask": torch.ones(self.seq_len, dtype=torch.bool).unsqueeze(0), "lora_num_tokens": lora_num_tokens, "routed_experts": None, - "pixel_values": None, - "image_grid_thw": None, + "mm_kwargs": None, "mm_token_type_ids": None, "sft_loss": False, } @@ -195,6 +194,15 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: if micro_batch.lora_num_tokens is None: micro_batch.lora_num_tokens = [0] * self.multi_run_manager.max_runs micro_batch.lora_num_tokens[0] = len(micro_batch.input_ids) + mm_kwargs: dict[str, Tensor] | None = None + if micro_batch.mm_kwargs: + # Each value is an EncodedTensor (dtype, shape, raw bytes). + # No batch dim — the orchestrator concatenates per-image along + # dim=0 generically, matching what each HF VLM's forward expects. + mm_kwargs = { + key: torch.frombuffer(bytearray(payload.data), dtype=_torch_dtype(payload.dtype)).reshape(payload.shape) + for key, payload in micro_batch.mm_kwargs.items() + } return TensorMicroBatch( input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0), position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0), @@ -206,15 +214,7 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: loss_mask=torch.tensor(micro_batch.loss_mask, dtype=torch.bool).unsqueeze(0), temperatures=torch.tensor(micro_batch.temperatures, dtype=torch.float).unsqueeze(0), lora_num_tokens=torch.tensor(micro_batch.lora_num_tokens, dtype=torch.int32), - # Multimodal fields - no batch dimension for these as they are variable-sized - pixel_values=torch.frombuffer(bytearray(micro_batch.pixel_values), dtype=torch.float32).reshape( - micro_batch.pixel_values_shape - ) - if micro_batch.pixel_values is not None - else None, - image_grid_thw=torch.tensor(micro_batch.image_grid_thw, dtype=torch.long) - if micro_batch.image_grid_thw is not None - else None, + mm_kwargs=mm_kwargs, mm_token_type_ids=torch.tensor(micro_batch.mm_token_type_ids, dtype=torch.long).unsqueeze(0) if micro_batch.mm_token_type_ids is not None else None, @@ -225,3 +225,15 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: else None, sft_loss=micro_batch.sft_loss, ) + + +def _torch_dtype(name: str) -> torch.dtype: + """Resolve a numpy/torch dtype name (e.g. ``"float32"``) to torch.dtype.""" + # Strip the ``numpy.`` prefix some dtype reprs carry. + name = name.replace("numpy.", "") + if hasattr(torch, name): + return getattr(torch, name) + # numpy ↔ torch alias mismatches (rare but possible) — fall back via numpy. 
+ import numpy as np + + return torch.from_numpy(np.zeros(1, dtype=np.dtype(name))).dtype diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index 4b2b932297..dbc444c8a0 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -370,13 +370,12 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: # we could've gotten routed experts from the inference server, but we didn't enable router replay routed_experts = None - # Multimodal fields (Qwen3-VL) - only present for VLM training - pixel_values = ( - micro_batch["pixel_values"].to("cuda") if micro_batch.get("pixel_values") is not None else None - ) - image_grid_thw = ( - micro_batch["image_grid_thw"].to("cuda") if micro_batch.get("image_grid_thw") is not None else None - ) + # Multimodal kwargs are an opaque per-model dict (e.g. + # {"pixel_values": ..., "image_grid_thw": ...} for Qwen3-VL, + # just {"pixel_values": ...} for Gemma3-VL) — we move every + # tensor to CUDA and let the model's forward sort them. + mm_kwargs_raw = micro_batch.get("mm_kwargs") + mm_kwargs = {k: v.to("cuda") for k, v in mm_kwargs_raw.items()} if mm_kwargs_raw else None mm_token_type_ids = ( micro_batch["mm_token_type_ids"].to("cuda") if micro_batch.get("mm_token_type_ids") is not None @@ -386,7 +385,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: labels = shift_tensor_left(input_ids) # VLM + CP is not supported: MRoPE requires global positions but CP shards the sequence - if cp_enabled and pixel_values is not None: + if cp_enabled and mm_kwargs is not None: raise NotImplementedError("Context parallelism is not supported with VLM/multimodal training") if cp_enabled: @@ -425,8 +424,7 @@ def load_run_checkpoint(_optimizer, idx: int) -> None: forward_position_ids, labels=labels, temperature=temperatures, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, + mm_kwargs=mm_kwargs, mm_token_type_ids=mm_token_type_ids, routed_experts=routed_experts, ) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 4bc594f06d..4f2dd2f211 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -1,6 +1,15 @@ import msgspec +# Encoded tensor: {dtype: "float32", shape: [...], data: }. +# Mirrors verifiers.utils.serve_utils.msgpack_encoder so the same wire +# shape is used end-to-end from renderer → orchestrator → trainer. +class EncodedTensor(msgspec.Struct, array_like=True, gc=False): + dtype: str + shape: list[int] + data: bytes + + # Orchestrator -> Packer class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): """A single training example.""" @@ -15,11 +24,15 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr advantage: float | None = None reward: float | None = None - # Multimodal fields (Qwen3-VL) — pixel_values stored as raw float32 bytes for efficient serialization - pixel_values: bytes | None = None - pixel_values_shape: list[int] | None = None # [num_patches, patch_dim] - # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] - image_grid_thw: list[list[int]] | None = None + # Generic multimodal kwargs: flat dict keyed by the kwarg names the + # model's forward expects (e.g. {"pixel_values": ..., "image_grid_thw": + # ...} for Qwen3-VL; just {"pixel_values": ...} for Gemma3). The + # orchestrator batches per-image renderer items by torch.cat along + # dim=0 generically — no model-specific knowledge in prime-rl. 
The + # trainer ``**`` -unpacks this into the model forward, so any VLM + # whose HF processor / forward agree on kwarg names works without + # touching this transport. + mm_kwargs: dict[str, EncodedTensor] | None = None routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk] @@ -51,11 +64,8 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): lora_num_tokens: list[int] | None = None routed_experts: list[list[list[int]]] | None = None - # Multimodal fields (Qwen3-VL) — pixel_values stored as raw float32 bytes for efficient serialization - pixel_values: bytes | None = None - pixel_values_shape: list[int] | None = None # [num_patches, patch_dim] - # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] - image_grid_thw: list[list[int]] | None = None + # See TrainingSample.mm_kwargs. + mm_kwargs: dict[str, EncodedTensor] | None = None # mm_token_type_ids: token type ids per token [batch seq], int64 (0=text, 1=image, 2=video) mm_token_type_ids: list[int] | None = None diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index b9982a398e..242dbf3523 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -1,33 +1,26 @@ -import base64 -from io import BytesIO from unittest.mock import MagicMock import numpy as np import pytest import verifiers as vf -from PIL import Image from prime_rl.orchestrator.trajectories import ( - VLMImageCache, _align_routed_experts, _deserialize_tool_calls, - _extract_images_from_examples, - _extract_images_from_messages, - _ImageStore, - build_vlm_image_cache, interleave_rollout, ) -def _pixels(data: list[list[float]]) -> tuple[bytes, list[int]]: - """Convert pixel values list to (bytes, shape) for test cache data.""" - arr = np.array(data, dtype=np.float32) - return arr.tobytes(), list(arr.shape) +def _decode_mm_pixels(sample) -> list: + """Decode ``sample.mm_kwargs['pixel_values']`` to a nested list.""" + p = sample.mm_kwargs["pixel_values"] + return np.frombuffer(p.data, dtype=np.dtype(p.dtype)).reshape(p.shape).tolist() -def _decode_pixels(pixel_bytes: bytes, shape: list[int]) -> list[list[float]]: - """Decode raw pixel bytes back to nested list for assertions.""" - return np.frombuffer(pixel_bytes, dtype=np.float32).reshape(shape).tolist() +def _decode_mm_thw(sample) -> list: + """Decode ``sample.mm_kwargs['image_grid_thw']`` to a nested list.""" + g = sample.mm_kwargs["image_grid_thw"] + return np.frombuffer(g.data, dtype=np.dtype(g.dtype)).reshape(g.shape).tolist() def test_deserialize_tool_calls_does_not_inject_missing_key(): @@ -730,803 +723,6 @@ def test_interleave_rollout_interleaved_agents(interleaved_agents_trajectory): assert agent2_sample.completion_logprobs == [-0.5, -0.6] -# ============================================================================= -# VLM Multi-Turn Tests -# ============================================================================= - - -def _create_test_image(color: str = "red") -> str: - """Create a small test image and return its base64 data URL.""" - colors = {"red": (255, 0, 0), "green": (0, 255, 0), "blue": (0, 0, 255)} - img = Image.new("RGB", (10, 10), colors.get(color, (255, 255, 255))) - buffer = BytesIO() - img.save(buffer, format="PNG") - b64 = base64.b64encode(buffer.getvalue()).decode() - return f"data:image/png;base64,{b64}" - - -def _create_image_message(image_url: str, text: str = "What is this?") -> dict: - 
"""Create an OpenAI-style user message with an image.""" - return { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": image_url}}, - {"type": "text", "text": text}, - ], - } - - -def test_extract_images_from_messages_no_images(): - messages = [{"role": "user", "content": "Hello"}] - images = _extract_images_from_messages(messages) - assert images == [] - - -def test_extract_images_from_messages_single_image(): - image_url = _create_test_image("red") - messages = [_create_image_message(image_url)] - images = _extract_images_from_messages(messages) - assert len(images) == 1 - assert isinstance(images[0][0], Image.Image) - - -def test_extract_images_from_messages_multiple_images(): - messages = [ - _create_image_message(_create_test_image("red")), - {"role": "assistant", "content": "I see a red image"}, - _create_image_message(_create_test_image("green")), - ] - images = _extract_images_from_messages(messages) - assert len(images) == 2 - - -def test_extract_images_from_examples_single_turn(): - image_url = _create_test_image("red") - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[_create_image_message(image_url)], - completion=[{"role": "assistant", "content": "A red square"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - all_images, images_per_step = _extract_images_from_examples([(1, output)]) - - assert len(all_images) == 1 - assert images_per_step == {1: [[0]]} # step 0 has image at index 0 - - -def test_extract_images_from_examples_multi_turn_new_image_each_turn(): - """Test that new images in later turns are correctly extracted.""" - red_url = _create_test_image("red") - green_url = _create_test_image("green") - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - # Turn 1: just the red image - vf.TrajectoryStep( - prompt=[_create_image_message(red_url, "What color is this?")], - completion=[{"role": "assistant", "content": "Red"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Turn 2: cumulative prompt with red image + green image - vf.TrajectoryStep( - prompt=[ - _create_image_message(red_url, "What color is this?"), - {"role": "assistant", "content": "Red"}, - _create_image_message(green_url, "And this one?"), - ], - completion=[{"role": "assistant", "content": "Green"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - all_images, images_per_step = _extract_images_from_examples([(1, output)]) - - assert len(all_images) == 2 # 2 unique images total - assert images_per_step == {1: [[0], [0, 1]]} # step 0: [red], step 1: [red, green] - - -def test_extract_images_from_examples_multi_turn_no_new_images(): - """Test turns where no new images are added.""" - red_url = _create_test_image("red") - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[_create_image_message(red_url)], - completion=[{"role": "assistant", "content": "Red"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Turn 2: same image, no new ones - vf.TrajectoryStep( - prompt=[ - 
_create_image_message(red_url), - {"role": "assistant", "content": "Red"}, - {"role": "user", "content": "Are you sure?"}, # text only - ], - completion=[{"role": "assistant", "content": "Yes"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - all_images, images_per_step = _extract_images_from_examples([(1, output)]) - - assert len(all_images) == 1 # Only 1 unique image (deduped) - assert images_per_step == {1: [[0], [0]]} # both steps reference the same image - - -def test_extract_images_from_examples_step_with_fewer_images_than_prior_steps(): - """Test that image counts reflect the prompt's actual images, not a monotonically increasing total. - - When a later step's prompt contains fewer images than the cumulative total from prior steps - (i.e. the prompt is not strictly cumulative), the count for that step should match - the number of images actually present in that step's prompt. - """ - red_url = _create_test_image("red") - green_url = _create_test_image("green") - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - # Step 0: red only - vf.TrajectoryStep( - prompt=[_create_image_message(red_url)], - completion=[], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 1: cumulative — red + green - vf.TrajectoryStep( - prompt=[_create_image_message(red_url), _create_image_message(green_url)], - completion=[], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 2: only green — fewer images than cumulative total - vf.TrajectoryStep( - prompt=[_create_image_message(green_url)], - completion=[], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - _, images_per_step = _extract_images_from_examples([(1, output)]) - - # Step 0: [red] → index 0; Step 1: [red, green] → indices [0, 1]; Step 2: [green] → index [1] - assert images_per_step == {1: [[0], [0, 1], [1]]} - - -def test_vlm_image_cache_get_for_step(): - cache_data = { - 1: [ - (*_pixels([[1.0, 2.0]]), [[1, 2, 3]]), # Step 0: 1 image - (*_pixels([[1.0, 2.0], [3.0, 4.0]]), [[1, 2, 3], [1, 4, 4]]), # Step 1: 2 images cumulative - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - # Step 0 should have 1 image - pv, shape, grid = cache.get_for_step(1, 0) - assert _decode_pixels(pv, shape) == [[1.0, 2.0]] - assert grid == [[1, 2, 3]] - - # Step 1 should have 2 images - pv, shape, grid = cache.get_for_step(1, 1) - assert _decode_pixels(pv, shape) == [[1.0, 2.0], [3.0, 4.0]] - assert grid == [[1, 2, 3], [1, 4, 4]] - - -def test_vlm_image_cache_get_all(): - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 3]]), - (*_pixels([[1.0], [2.0]]), [[1, 2, 3], [1, 4, 4]]), - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - # get_all should return the last step's data - pv, shape, grid = cache.get_all(1) - assert _decode_pixels(pv, shape) == [[1.0], [2.0]] - assert grid == [[1, 2, 3], [1, 4, 4]] - - -def test_vlm_image_cache_step_out_of_range(): - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 3]]), - ], - } - cache = 
VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - pv, shape, grid = cache.get_for_step(1, 2) - assert pv is None - assert shape is None - assert grid is None - - -def test_vlm_image_cache_missing_example(): - cache = VLMImageCache({}, num_unique_examples=0, extract_time=0.0, preprocess_time=0.0) - - pv, shape, grid = cache.get_for_step(999, 0) - assert pv is None - assert shape is None - assert grid is None - - pv, shape, grid = cache.get_all(999) - assert pv is None - assert shape is None - assert grid is None - - -def test_interleave_rollout_with_vlm_cache(): - """Test that interleave_rollout correctly uses per-step images from VLM cache.""" - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 3]]), # Step 0 - (*_pixels([[1.0], [2.0]]), [[1, 2, 3], [1, 4, 4]]), # Step 1 - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 1"}], - completion=[{"role": "assistant", "content": "Response 1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 2"}], - completion=[{"role": "assistant", "content": "Response 2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5], - prompt_mask=[0, 0, 0, 0, 0], - completion_ids=[6, 7], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - # Token 2 is an image token, token 5 is a video token - mm_mapping = {2: 1, 5: 2} - rollouts = interleave_rollout(output, vlm_cache=cache, mm_token_type_ids_mapping=mm_mapping) - - # Extension holds (step 1 prompt [1,2,3,4,5] extends prefix [1,2,3,4]) - # so both steps merge into a single sample with cumulative images from step 1 - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - assert rollout.completion_ids == [3, 4, 5, 6, 7] - assert rollout.completion_mask == [True, True, False, True, True] - assert rollout.completion_logprobs == [-0.1, -0.2, 0.0, -0.3, -0.4] - # Images: cumulative from last merged step (step 1 has 2 images) - assert _decode_pixels(rollout.pixel_values, rollout.pixel_values_shape) == [[1.0], [2.0]] - assert rollout.image_grid_thw == [[1, 2, 3], [1, 4, 4]] - # mm_token_type_ids: full sequence [1,2,3,4,5,6,7] → [0,1,0,0,2,0,0] - assert rollout.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] - - -def test_interleave_rollout_uses_cache_key_override(): - cache_data = { - 7: [ - (*_pixels([[9.0]]), [[1, 2, 3]]), - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=123, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 1"}], - completion=[{"role": "assistant", "content": "Response 1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], 
- completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache, cache_key=7) - - assert rollouts is not None - assert len(rollouts) == 1 - assert _decode_pixels(rollouts[0].pixel_values, rollouts[0].pixel_values_shape) == [[9.0]] - assert rollouts[0].image_grid_thw == [[1, 2, 3]] - - -def test_interleave_rollout_vlm_image_then_text_turns(): - """ - VLM 3-step trajectory: image in step 0, text-only in steps 1 and 2. - Extension holds throughout so all steps merge into 1 sample carrying - step 0's pixel_values (no new images added in later steps). - """ - cache_data = { - 1: [ - (*_pixels([[1.0, 2.0]]), [[1, 3, 3]]), # Step 0: 1 image - (*_pixels([[1.0, 2.0]]), [[1, 3, 3]]), # Step 1: same 1 image (no new) - (*_pixels([[1.0, 2.0]]), [[1, 3, 3]]), # Step 2: same 1 image (no new) - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - # Step 0: user sends image - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Describe"}], - completion=[{"role": "assistant", "content": "A cat"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 1: text-only follow-up (extension holds) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "More detail"}], - completion=[{"role": "assistant", "content": "Fluffy"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 2: another text-only follow-up (extension holds) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Color?"}], - completion=[{"role": "assistant", "content": "Orange"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache) - - # All 3 steps merge into 1 sample (extension always holds) - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - # completion: step0 [3,4] + step1 new prompt [5,6] + step1 completion [7,8] - # + step2 new prompt [9,10] + step2 completion [11,12] - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - assert rollout.completion_mask == [True, True, False, False, True, True, False, False, True, True] - # pixel_values from step 2 (cumulative = same 1 image throughout) - assert _decode_pixels(rollout.pixel_values, 
rollout.pixel_values_shape) == [[1.0, 2.0]] - assert rollout.image_grid_thw == [[1, 3, 3]] - - -def test_interleave_rollout_vlm_new_image_mid_conversation(): - """ - VLM 3-step trajectory: image in step 0, text in step 1, NEW image in step 2. - Extension holds throughout, so 1 merged sample with cumulative images from step 2. - """ - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 3]]), # Step 0: 1 image - (*_pixels([[1.0]]), [[1, 2, 3]]), # Step 1: still 1 image - (*_pixels([[1.0], [2.0]]), [[1, 2, 3], [1, 4, 4]]), # Step 2: 2 images - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image 1"}], - completion=[{"role": "assistant", "content": "A"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Text only"}], - completion=[{"role": "assistant", "content": "B"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image 2"}], - completion=[{"role": "assistant", "content": "C"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache) - - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - # Cumulative images from last merged step (step 2): both images - assert _decode_pixels(rollout.pixel_values, rollout.pixel_values_shape) == [[1.0], [2.0]] - assert rollout.image_grid_thw == [[1, 2, 3], [1, 4, 4]] - - -def test_interleave_rollout_vlm_extension_break(): - """ - VLM 3-step trajectory where extension breaks at step 2. - Step 0 has image, step 1 extends (text-only), step 2 breaks (different prefix). - Should produce 2 samples, each with their own cumulative images. 
- """ - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 3]]), # Step 0: 1 image - (*_pixels([[1.0]]), [[1, 2, 3]]), # Step 1: still 1 image - (*_pixels([[1.0], [2.0]]), [[1, 2, 3], [1, 4, 4]]), # Step 2: 2 images (new image added) - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image 1"}], - completion=[{"role": "assistant", "content": "A"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 1: extends step 0 - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Follow-up"}], - completion=[{"role": "assistant", "content": "B"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 2: extension breaks (different prefix, e.g. context compaction) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image 2"}], - completion=[{"role": "assistant", "content": "C"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[100, 101, 102, 103], - prompt_mask=[0, 0, 0, 0], - completion_ids=[104, 105], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache) - - assert rollouts is not None - assert len(rollouts) == 2 - - # Sample 1: steps 0-1 merged, images from step 1 (still 1 image) - assert rollouts[0].prompt_ids == [1, 2] - assert rollouts[0].completion_ids == [3, 4, 5, 6, 7, 8] - assert _decode_pixels(rollouts[0].pixel_values, rollouts[0].pixel_values_shape) == [[1.0]] - assert rollouts[0].image_grid_thw == [[1, 2, 3]] - - # Sample 2: step 2 alone (extension broke), images from step 2 (2 images) - assert rollouts[1].prompt_ids == [100, 101, 102, 103] - assert rollouts[1].completion_ids == [104, 105] - assert _decode_pixels(rollouts[1].pixel_values, rollouts[1].pixel_values_shape) == [[1.0], [2.0]] - assert rollouts[1].image_grid_thw == [[1, 2, 3], [1, 4, 4]] - - -def test_interleave_rollout_vlm_image_appears_late(): - """ - VLM 3-step trajectory: text-only in steps 0 and 1, first image in step 2. - Extension holds throughout so all steps merge into 1 sample. - The sample should have pixel_values=None until step 2 sets them. 
- """ - cache_data = { - 1: [ - (None, None, None), # Step 0: no images - (None, None, None), # Step 1: no images - (*_pixels([[5.0, 6.0]]), [[1, 3, 3]]), # Step 2: first image appears - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - # Step 0: text-only - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Hello"}], - completion=[{"role": "assistant", "content": "Hi"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 1: text-only (extension holds) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Question"}], - completion=[{"role": "assistant", "content": "Answer"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 2: user sends image (extension holds) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Describe this"}], - completion=[{"role": "assistant", "content": "A photo"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache) - - assert rollouts is not None - assert len(rollouts) == 1 - rollout = rollouts[0] - assert rollout.prompt_ids == [1, 2] - assert rollout.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - assert rollout.completion_mask == [True, True, False, False, True, True, False, False, True, True] - # pixel_values from step 2 (the first step with an image) - assert _decode_pixels(rollout.pixel_values, rollout.pixel_values_shape) == [[5.0, 6.0]] - assert rollout.image_grid_thw == [[1, 3, 3]] - - def test_interleave_rollout_empty_trajectory(): """Empty trajectory returns None.""" output = vf.RolloutOutput( @@ -1601,257 +797,6 @@ def test_interleave_rollout_error_masks_all_false(): assert rollout.completion_temperatures == [0.8] * 6 -def test_interleave_rollout_vlm_interleaved_agents(): - """ - VLM + interleaved agents: agent1 and agent2 interleaved, each with images. - agent1 gets cumulative images from its own steps, agent2 from its step. 
- - Steps (0-indexed): - 0: agent1-step1 (image A) - 1: agent1-step2 (extends step 0, image A still) - 2: agent2-step1 (different prefix, image B) - 3: agent1-step3 (extends step 0+1, image A + new image C) - - Expected: 2 samples - - agent1: merged steps 0,1,3 → pixel_values from step 3 (images A+C) - - agent2: step 2 alone → pixel_values from step 2 (image B) - """ - cache_data = { - 1: [ - (*_pixels([[1.0]]), [[1, 2, 2]]), # Step 0: image A - (*_pixels([[1.0]]), [[1, 2, 2]]), # Step 1: still image A - (*_pixels([[9.0]]), [[1, 5, 5]]), # Step 2: image B (agent2) - (*_pixels([[1.0], [3.0]]), [[1, 2, 2], [1, 3, 3]]), # Step 3: images A+C (agent1) - ], - } - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - # Step 0: agent1-step1 - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image A"}], - completion=[{"role": "assistant", "content": "A1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 1: agent1-step2 (extends step 0) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Follow-up"}], - completion=[{"role": "assistant", "content": "A2"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6], - prompt_mask=[0, 0, 0, 0, 0, 0], - completion_ids=[7, 8], - completion_mask=[1, 1], - completion_logprobs=[-0.3, -0.4], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 2: agent2-step1 (different prefix) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image B"}], - completion=[{"role": "assistant", "content": "B1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[100, 101], - prompt_mask=[0, 0], - completion_ids=[102, 103], - completion_mask=[1, 1], - completion_logprobs=[-0.5, -0.6], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - # Step 3: agent1-step3 (extends agent1, merges back) - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Image C added"}], - completion=[{"role": "assistant", "content": "A3"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - prompt_mask=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - completion_ids=[11, 12], - completion_mask=[1, 1], - completion_logprobs=[-0.7, -0.8], - overlong_prompt=False, - is_truncated=False, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - error=None, - sampling_args={"temperature": 1.0}, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache) - - assert rollouts is not None - assert len(rollouts) == 2 - - # Agent1: steps 0,1,3 merged → images from step 3 (A+C) - agent1 = rollouts[0] - assert agent1.prompt_ids == [1, 2] - assert agent1.completion_ids == [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - assert agent1.completion_mask == [True, True, False, False, True, True, False, False, True, True] - assert _decode_pixels(agent1.pixel_values, agent1.pixel_values_shape) == [[1.0], [3.0]] - assert agent1.image_grid_thw == [[1, 2, 2], [1, 3, 3]] - - # Agent2: 
step 2 alone → images from step 2 (B) - agent2 = rollouts[1] - assert agent2.prompt_ids == [100, 101] - assert agent2.completion_ids == [102, 103] - assert agent2.completion_mask == [True, True] - assert _decode_pixels(agent2.pixel_values, agent2.pixel_values_shape) == [[9.0]] - assert agent2.image_grid_thw == [[1, 5, 5]] - - -def test_build_vlm_image_cache_handles_divergent_rollouts(): - """Test that build_vlm_image_cache keys images per rollout when trajectories diverge.""" - import torch - - red_url = _create_test_image("red") - blue_url = _create_test_image("blue") - green_url = _create_test_image("green") - - rollout_a = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[_create_image_message(red_url, "What color?")], - completion=[{"role": "assistant", "content": "Red"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollout_b = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[_create_image_message(blue_url, "What color?")], - completion=[{"role": "assistant", "content": "Blue"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - vf.TrajectoryStep( - prompt=[ - _create_image_message(blue_url, "What color?"), - {"role": "assistant", "content": "Blue"}, - _create_image_message(green_url, "And this one?"), - ], - completion=[{"role": "assistant", "content": "Green"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - # Mock processor that returns predictable tensors - mock_processor = MagicMock() - mock_processor.image_processor = MagicMock( - side_effect=lambda images, return_tensors: { - "pixel_values": torch.arange(len(images), dtype=torch.float32).view(-1, 1), - "image_grid_thw": torch.tensor([[1, 1, 1]] * len(images)), - } - ) - - rollouts = [rollout_a, rollout_b] - cache = build_vlm_image_cache(rollouts, mock_processor) - - assert cache.num_unique_examples == 1 - - pv, shape, grid = cache.get_for_step(0, 0) - assert _decode_pixels(pv, shape) == [[0.0]] - assert grid == [[1, 1, 1]] - - pv, shape, grid = cache.get_for_step(1, 0) - assert _decode_pixels(pv, shape) == [[1.0]] - assert grid == [[1, 1, 1]] - - pv, shape, grid = cache.get_for_step(1, 1) - assert _decode_pixels(pv, shape) == [[1.0], [2.0]] - assert grid == [[1, 1, 1], [1, 1, 1]] - - -def test_build_vlm_image_cache_no_images(): - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Hello"}], - completion=[{"role": "assistant", "content": "Hi"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ) - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - cache = build_vlm_image_cache([output], MagicMock()) - - pv, shape, grid = cache.get_for_step(0, 0) - assert pv is None - assert shape is None - assert grid is None - - def test_align_routed_experts_none(): assert _align_routed_experts(None, 10) is None @@ -2042,154 +987,6 @@ def test_interleave_rollout_none_routed_experts_stays_none(): # ============================================================================= -def test_image_store_assemble(): - 
"""_ImageStore.assemble joins per-image bytes and computes correct shape/grids.""" - # 2 images: image 0 has 3 patches, image 1 has 2 patches, patch_dim=4 - patch_dim = 4 - img0 = np.arange(3 * patch_dim, dtype=np.float32).tobytes() - img1 = np.arange(2 * patch_dim, dtype=np.float32).tobytes() - - store = _ImageStore( - image_bytes=[img0, img1], - image_num_patches=[3, 2], - patch_dim=patch_dim, - image_grids=[[1, 1, 3], [1, 1, 2]], - ) - - # Assemble both images - pixel_bytes, shape, grids = store.assemble([0, 1]) - assert shape == [5, 4] - assert grids == [[1, 1, 3], [1, 1, 2]] - assert pixel_bytes == img0 + img1 - - # Assemble single image - pixel_bytes, shape, grids = store.assemble([1]) - assert shape == [2, 4] - assert grids == [[1, 1, 2]] - assert pixel_bytes == img1 - - # Assemble in reverse order - pixel_bytes, shape, grids = store.assemble([1, 0]) - assert shape == [5, 4] - assert grids == [[1, 1, 2], [1, 1, 3]] - assert pixel_bytes == img1 + img0 - - -def test_vlm_image_cache_from_store(): - """VLMImageCache.from_store provides correct get_for_step/get_all via lazy assembly.""" - patch_dim = 2 - img0_data = np.array([[1.0, 2.0]], dtype=np.float32) - img1_data = np.array([[3.0, 4.0]], dtype=np.float32) - - store = _ImageStore( - image_bytes=[img0_data.tobytes(), img1_data.tobytes()], - image_num_patches=[1, 1], - patch_dim=patch_dim, - image_grids=[[1, 2, 3], [1, 4, 4]], - ) - - step_indices = { - 1: [[0], [0, 1]], # step 0: image 0; step 1: images 0+1 - } - - cache = VLMImageCache.from_store( - store=store, - step_indices=step_indices, - num_unique_examples=1, - num_unique_images=2, - extract_time=0.0, - preprocess_time=0.0, - ) - - # Step 0: just image 0 - pv, shape, grid = cache.get_for_step(1, 0) - assert _decode_pixels(pv, shape) == [[1.0, 2.0]] - assert grid == [[1, 2, 3]] - - # Step 1: images 0 + 1 - pv, shape, grid = cache.get_for_step(1, 1) - assert _decode_pixels(pv, shape) == [[1.0, 2.0], [3.0, 4.0]] - assert grid == [[1, 2, 3], [1, 4, 4]] - - # get_all returns last step - pv, shape, grid = cache.get_all(1) - assert _decode_pixels(pv, shape) == [[1.0, 2.0], [3.0, 4.0]] - assert grid == [[1, 2, 3], [1, 4, 4]] - - # Missing key - pv, shape, grid = cache.get_for_step(999, 0) - assert pv is None - - # Out of range step - pv, shape, grid = cache.get_for_step(1, 5) - assert pv is None - - -def test_vlm_image_cache_from_store_no_images(): - """from_store with store=None returns (None, None, None) for all queries.""" - step_indices = {0: [[], []]} # 2 steps with no images - - cache = VLMImageCache.from_store( - store=None, - step_indices=step_indices, - num_unique_examples=1, - num_unique_images=0, - extract_time=0.0, - preprocess_time=0.0, - ) - - pv, shape, grid = cache.get_for_step(0, 0) - assert pv is None - assert shape is None - assert grid is None - - -def test_build_vlm_image_cache_uses_store(): - """build_vlm_image_cache returns a store-backed cache.""" - import torch - - red_url = _create_test_image("red") - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[_create_image_message(red_url, "What color?")], - completion=[{"role": "assistant", "content": "Red"}], - response=MagicMock(), - tokens=MagicMock(), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - mock_processor = MagicMock() - mock_processor.image_processor = MagicMock( - side_effect=lambda images, return_tensors: { - "pixel_values": torch.arange(len(images), 
dtype=torch.float32).view(-1, 1), - "image_grid_thw": torch.tensor([[1, 1, 1]] * len(images)), - } - ) - - cache = build_vlm_image_cache([output], mock_processor) - - # Should be store-backed - assert cache._store is not None - assert cache._step_indices is not None - - # Should still work correctly - pv, shape, grid = cache.get_for_step(0, 0) - assert pv is not None - assert shape == [1, 1] - assert grid == [[1, 1, 1]] - - # ── Renderer-emitted multimodal data ─────────────────────────────────── @@ -2285,8 +1082,7 @@ def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): # Token 2 is the image placeholder, token 5 is the video placeholder. mm_mapping = {2: 1, 5: 2} - # No vlm_cache — the renderer sidecar should fully cover the path. - rollouts = interleave_rollout(output, vlm_cache=None, mm_token_type_ids_mapping=mm_mapping) + rollouts = interleave_rollout(output, mm_token_type_ids_mapping=mm_mapping) assert rollouts is not None and len(rollouts) == 1 sample = rollouts[0] @@ -2295,64 +1091,10 @@ def test_interleave_rollout_packs_pixels_from_renderer_mm_data(): assert sample.prompt_ids == [1, 2] assert sample.completion_ids == [3, 4, 5, 6, 7] # Pixel values packed from step 1's two items, concatenated. - assert _decode_pixels(sample.pixel_values, sample.pixel_values_shape) == [ + assert _decode_mm_pixels(sample) == [ [1.0, 2.0], [3.0, 4.0], ] - assert sample.image_grid_thw == [[1, 2, 3], [1, 4, 4]] + assert _decode_mm_thw(sample) == [[1, 2, 3], [1, 4, 4]] # mm_token_type_ids: image at token 2, video at token 5, rest 0. assert sample.mm_token_type_ids == [0, 1, 0, 0, 2, 0, 0] - - -def test_interleave_rollout_renderer_mm_data_wins_over_vlm_cache(): - """When both renderer mm_data AND vlm_cache are present, renderer - mm_data wins — the rollout came through a multimodal-aware renderer - so the placeholder offsets and processed tensors are authoritative.""" - import torch as _torch - from renderers.base import MultiModalData, PlaceholderRange - - renderer_pv = _torch.tensor([[7.0]], dtype=_torch.float32) - renderer_thw = _torch.tensor([[1, 9, 9]], dtype=_torch.int64) - mm = MultiModalData( - mm_hashes={"image": ["render"]}, - mm_placeholders={"image": [PlaceholderRange(offset=1, length=1)]}, - mm_items={"image": [{"pixel_values": renderer_pv, "image_grid_thw": renderer_thw}]}, - ) - - # VLM cache populated with a DIFFERENT image — we shouldn't see this. 
- cache_data = {1: [(*_pixels([[99.0]]), [[1, 2, 2]])]} - cache = VLMImageCache(cache_data, num_unique_examples=1, extract_time=0.0, preprocess_time=0.0) - - output = vf.RolloutOutput( - example_id=1, - trajectory=[ - vf.TrajectoryStep( - prompt=[{"role": "user", "content": "Turn 1"}], - completion=[{"role": "assistant", "content": "Response 1"}], - response=MagicMock(), - tokens=vf.TrajectoryStepTokens( - prompt_ids=[1, 2], - prompt_mask=[0, 0], - completion_ids=[3, 4], - completion_mask=[1, 1], - completion_logprobs=[-0.1, -0.2], - overlong_prompt=False, - is_truncated=False, - multi_modal_data=mm, - ), - reward=None, - advantage=None, - is_truncated=False, - trajectory_id="1", - extras={}, - ), - ], - sampling_args={"temperature": 1.0}, - error=None, - ) - - rollouts = interleave_rollout(output, vlm_cache=cache, mm_token_type_ids_mapping={2: 1}) - - assert rollouts is not None and len(rollouts) == 1 - assert _decode_pixels(rollouts[0].pixel_values, rollouts[0].pixel_values_shape) == [[7.0]] - assert rollouts[0].image_grid_thw == [[1, 9, 9]] From 24864bcfb3787d493345dc0ce5f56d2faf08e1c7 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 11 May 2026 22:36:26 +0000 Subject: [PATCH 3/6] fix: align with renderer-multimodal PR surface (configs + deps + tests) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI was failing for three reasons after the cherry-pick of the renderer multimodal commits onto current origin/main. Three fixes: 1. Multimodal configs missing the renderer flag pair. The orchestrator config validator (added in the same PR) requires ``use_renderer = true`` when ``model.vlm`` is set, AND it's mutually exclusive with ``use_token_client`` (default ``true``). Three configs (``rl_color_codeword_test.toml``, ``rl_color_codeword.toml``, ``ci/nightly/multimodal_color_codeword.toml``) needed both flags set explicitly in ``[orchestrator]``. Delete ``rl_color_codeword_main_mito.toml`` — that was the A/B reference for the legacy MITO path, which this PR rips out. With MITO gone the config is no longer runnable; the ``rl_color_codeword_feat_renderer`` counterpart already covers the new renderer-driven path. 2. ``test_model_forward.py`` was still calling ``forward(..., pixel_values=..., image_grid_thw=...)``. ``forward()`` now takes a generic ``mm_kwargs: dict`` (so adding new VLM families doesn't require touching the trainer signature) — update both tests to pass ``mm_kwargs={"pixel_values": ..., "image_grid_thw": ...}`` instead. 3. ``renderers`` / ``verifiers`` deps stale. The orchestrator imports ``MultiModalData`` from ``renderers.base`` (introduced in the companion renderers PR) and threads ``multi_modal_data`` end-to-end via the verifiers ``RendererClient`` changes. Pin both to their feature branches until the upstream PRs merge and PyPI / git rev-pins are bumped: - renderers @ feat/multimodal-vlm - verifiers @ feat/renderer-multimodal-passthrough Drops the ``renderers==0.1.6`` PyPI pin (the new symbols are post-0.1.6). 
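
For reference, item 2's call-site migration in one place (a sketch, not the
literal test code: ``model`` stands for the existing test double built in
``test_model_forward.py``, ``forward`` is the trainer forward wrapper that
``src/prime_rl/trainer/model.py`` exports, and the dummy tensors mirror the
ones those tests already construct):

    import torch

    # 4-token sequence with two image-placeholder tokens (id 10), as in the test.
    input_ids = torch.tensor([[1, 10, 10, 2]])
    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)

    # Before: model-specific keywords baked into the trainer signature.
    # forward(model, input_ids, position_ids,
    #         pixel_values=torch.ones(2, 3),
    #         image_grid_thw=torch.tensor([[1, 1, 2]]))

    # After: one generic dict, keyed by whatever the model's HF forward expects.
    forward(
        model,
        input_ids,
        position_ids,
        mm_kwargs={
            "pixel_values": torch.ones(2, 3),
            "image_grid_thw": torch.tensor([[1, 1, 2]]),
        },
    )

New VLM families only need their processor outputs placed under the right
kwarg names; the trainer signature stays untouched.
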
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../ci/nightly/multimodal_color_codeword.toml | 2 + configs/multimodal/rl_color_codeword.toml | 2 + .../rl_color_codeword_main_mito.toml | 61 ------------------- .../multimodal/rl_color_codeword_test.toml | 2 + pyproject.toml | 5 +- tests/unit/train/test_model_forward.py | 14 ++++- uv.lock | 16 ++--- 7 files changed, 27 insertions(+), 75 deletions(-) delete mode 100644 configs/multimodal/rl_color_codeword_main_mito.toml diff --git a/configs/ci/nightly/multimodal_color_codeword.toml b/configs/ci/nightly/multimodal_color_codeword.toml index a90fdfb454..aa0276edc5 100644 --- a/configs/ci/nightly/multimodal_color_codeword.toml +++ b/configs/ci/nightly/multimodal_color_codeword.toml @@ -16,6 +16,8 @@ language_model_attr = "model.language_model" [orchestrator] batch_size = 256 rollouts_per_example = 16 +use_token_client = false +use_renderer = true [orchestrator.train.sampling] max_completion_tokens = 64 diff --git a/configs/multimodal/rl_color_codeword.toml b/configs/multimodal/rl_color_codeword.toml index 46910ed8a0..0952460070 100644 --- a/configs/multimodal/rl_color_codeword.toml +++ b/configs/multimodal/rl_color_codeword.toml @@ -11,6 +11,8 @@ language_model_attr = "model.language_model" [orchestrator] batch_size = 256 rollouts_per_example = 16 +use_token_client = false +use_renderer = true [orchestrator.train.sampling] diff --git a/configs/multimodal/rl_color_codeword_main_mito.toml b/configs/multimodal/rl_color_codeword_main_mito.toml deleted file mode 100644 index 9ae481fb69..0000000000 --- a/configs/multimodal/rl_color_codeword_main_mito.toml +++ /dev/null @@ -1,61 +0,0 @@ -# 20-step Qwen3-VL-4B RL run on color-codeword using the existing MITO -# (chat-completions) inference path — baseline for the renderer A/B. -# -# Mirrors rl_color_codeword_feat_renderer.toml exactly except for the -# orchestrator client flags: this run goes through TITO/MITO so the -# inference server applies the chat template, runs the image processor -# server-side, and the orchestrator re-tokenizes locally. -# -# Run on ``main`` (renderers, verifiers, primerlmain). Compare metrics in -# W&B project ``multimodal-renderer`` against the feat run. - -max_steps = 20 -seq_len = 4096 -output_dir = "outputs/rl_color_codeword_main_mito" -clean_output_dir = true - -[model] -name = "Qwen/Qwen3-VL-4B-Instruct" - -[model.vlm] -vision_encoder_attr = "model.visual" -language_model_attr = "model.language_model" - -[deployment] -num_train_gpus = 1 -num_infer_gpus = 1 -gpus_per_node = 2 - -[orchestrator] -batch_size = 16 -rollouts_per_example = 4 -# MITO baseline: server-side chat templating + image processing. 
-use_renderer = false -use_token_client = true - -[orchestrator.train.sampling] -max_completion_tokens = 64 - -[[orchestrator.train.env]] -id = "color-codeword" -args = { images_per_turn = 1, max_turns = 2, num_examples = 100, seed = 42 } - -[trainer] - -[trainer.model] -optimization_dtype = "bfloat16" -reduce_dtype = "bfloat16" - -[trainer.optim] -lr = 3e-6 - -[inference] - -[inference.parallel] -dp = 1 -tp = 1 - -[wandb] -project = "multimodal-renderer" -name = "main-mito-20step" -tags = ["qwen3vl-4b", "color-codeword", "mito", "main-branch"] diff --git a/configs/multimodal/rl_color_codeword_test.toml b/configs/multimodal/rl_color_codeword_test.toml index 151bad0987..45408ca084 100644 --- a/configs/multimodal/rl_color_codeword_test.toml +++ b/configs/multimodal/rl_color_codeword_test.toml @@ -12,6 +12,8 @@ language_model_attr = "model.language_model" [orchestrator] batch_size = 16 rollouts_per_example = 2 +use_token_client = false +use_renderer = true [orchestrator.train.sampling] max_completion_tokens = 32 diff --git a/pyproject.toml b/pyproject.toml index 237eea7720..e208c813aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "uvloop>=0.21.0", "torchtitan", "verifiers", - "renderers==0.1.6", + "renderers", "dion", "tilelang>=0.1.8", "flash-linear-attention", @@ -166,7 +166,8 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "aa428f3" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", branch = "feat/renderer-multimodal-passthrough" } +renderers = { git = "https://github.com/PrimeIntellect-ai/renderers.git", branch = "feat/multimodal-vlm" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/tests/unit/train/test_model_forward.py b/tests/unit/train/test_model_forward.py index b62805bf3d..df7818c365 100644 --- a/tests/unit/train/test_model_forward.py +++ b/tests/unit/train/test_model_forward.py @@ -29,7 +29,12 @@ def test_forward_adds_qwen3_vl_mm_token_type_ids(): pixel_values = torch.ones(2, 3) image_grid_thw = torch.tensor([[1, 1, 2]]) - forward(model, input_ids, position_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw) + forward( + model, + input_ids, + position_ids, + mm_kwargs={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, + ) assert model.kwargs is not None assert "position_ids" not in model.kwargs @@ -43,7 +48,12 @@ def test_forward_skips_mm_token_type_ids_for_other_vlm_models(): input_ids = torch.tensor([[1, 10, 10, 2]]) position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0) - forward(model, input_ids, position_ids, pixel_values=torch.ones(2, 3), image_grid_thw=torch.tensor([[1, 1, 2]])) + forward( + model, + input_ids, + position_ids, + mm_kwargs={"pixel_values": torch.ones(2, 3), "image_grid_thw": torch.tensor([[1, 1, 2]])}, + ) assert model.kwargs is not None assert "position_ids" not in model.kwargs diff --git a/uv.lock b/uv.lock index 0b2d2268ee..107eba50d2 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. 
+exclude-newer = "2026-05-04T22:33:38.153203392Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2890,7 +2890,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=21.0.0" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.3.3" }, - { name = "renderers", specifier = "==0.1.6" }, + { name = "renderers", git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm" }, { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, @@ -2905,7 +2905,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=aa428f3" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.20.2" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl" }, @@ -3376,8 +3376,8 @@ wheels = [ [[package]] name = "renderers" -version = "0.1.6" -source = { registry = "https://pypi.org/simple" } +version = "0.1.7" +source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#28fd8a122409143f84efce156d7917cdb55d8d0f" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3386,10 +3386,6 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7c/a7/26162494dab2d7740ff02191cb87c30b68450fb154363c7f0a434e7f3ea9/renderers-0.1.6.tar.gz", hash = "sha256:b74bc3dc870bea3c37ff5b47826ace9b8dd608a4c1f56554c39be1b20b2c63dc", size = 163768, upload-time = "2026-05-07T14:12:36.634Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/ad/2cf218b9fafe2333fb3e80e123e3e2022d4923d9a61fa73ee6d79f39b563/renderers-0.1.6-py3-none-any.whl", hash = "sha256:90c626713239ec108716b7c9d194ba81ffcebe94dc003324f14fbd70e6793e89", size = 83348, upload-time = "2026-05-07T14:12:35.218Z" }, -] [[package]] name = "requests" @@ -4168,7 +4164,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.14" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=aa428f3#aa428f3941ae35a7cf7c0dad7e60c7eca525bac6" } +source = { git = 
"https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough#810a7275e189cfcb9ce99382877bd6d259427d3d" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 1097ed2da690500a2676543e4946f950b8cc9d6f Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 12 May 2026 00:03:28 +0000 Subject: [PATCH 4/6] chore: bump renderers + verifiers to latest feature-branch commits MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Picks up: - renderers c3feaa5: RendererPool implements Renderer protocol structurally (callers can drop pool unwrap + isinstance branching), size=1 fast path, is_multimodal helper with per-type cache. - verifiers 64f2555a: bugbot pass — multimodal dispatch fix (was broken for pooled renderers), tighter is_json_serializable, response-tokens mm_data strip on intermediate steps. Co-Authored-By: Claude Opus 4.7 (1M context) --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index 107eba50d2..cec01e740e 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-04T22:33:38.153203392Z" +exclude-newer = "2026-05-05T00:02:37.035558251Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -3377,7 +3377,7 @@ wheels = [ [[package]] name = "renderers" version = "0.1.7" -source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#28fd8a122409143f84efce156d7917cdb55d8d0f" } +source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#c3feaa5fef8c4bfb02e8031c22aeeb77b6563e02" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4164,7 +4164,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.14" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough#810a7275e189cfcb9ce99382877bd6d259427d3d" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough#64f2555a18bfeabf1ce61625cb0423551913c7dd" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From bf26a06fb8c7e674428268ccf2543fd1f5c41f93 Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 12 May 2026 00:07:25 +0000 Subject: [PATCH 5/6] chore: bump renderers + verifiers to pick up isinstance dispatch fix Co-Authored-By: Claude Opus 4.7 (1M context) --- uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uv.lock b/uv.lock index cec01e740e..cc4bb60171 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-05T00:02:37.035558251Z" +exclude-newer 
= "2026-05-05T00:07:23.83884381Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -3377,7 +3377,7 @@ wheels = [ [[package]] name = "renderers" version = "0.1.7" -source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#c3feaa5fef8c4bfb02e8031c22aeeb77b6563e02" } +source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#5353c9845596957e878051a01f2a129604b28e5c" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4164,7 +4164,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.14" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough#64f2555a18bfeabf1ce61625cb0423551913c7dd" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?branch=feat%2Frenderer-multimodal-passthrough#397b8aa5d27d1e84d78c7b0c6a8b46448b63268f" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 4c877be43c18352611dcf4c6cf2dbfc042a8648f Mon Sep 17 00:00:00 2001 From: hallerite Date: Tue, 12 May 2026 12:53:59 +0000 Subject: [PATCH 6/6] chore: drop renderers git source, pin to released 0.1.7 renderers v0.1.7 is now on PyPI with the multimodal feature set this branch consumes (RenderedTokens, MultiModalData, MultimodalRenderer protocol, mm_token_type_id_map, RendererPool delegation). verifiers still pinned to feat/renderer-multimodal-passthrough until that companion PR merges. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 3 +-- uv.lock | 10 +++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e208c813aa..dbcc8ddc05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "uvloop>=0.21.0", "torchtitan", "verifiers", - "renderers", + "renderers>=0.1.7", "dion", "tilelang>=0.1.8", "flash-linear-attention", @@ -167,7 +167,6 @@ torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", branch = "feat/renderer-multimodal-passthrough" } -renderers = { git = "https://github.com/PrimeIntellect-ai/renderers.git", branch = "feat/multimodal-vlm" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index cc4bb60171..5fc66a22aa 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-05T00:07:23.83884381Z" +exclude-newer = "2026-05-05T12:53:34.246685771Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2890,7 +2890,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=21.0.0" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.3.3" }, - { name = "renderers", git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm" }, + { name = "renderers", specifier = ">=0.1.7" }, { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, @@ -3377,7 +3377,7 @@ wheels = [ [[package]] name = "renderers" version = "0.1.7" -source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?branch=feat%2Fmultimodal-vlm#5353c9845596957e878051a01f2a129604b28e5c" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3386,6 +3386,10 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/b1/5f/fdd86253f562ffdecbdb56cb8591961f0b82914a9abc4ebf43976befb891/renderers-0.1.7.tar.gz", hash = "sha256:d17ccbc3813dd0ee3a6f7d8794308bd4300eb2a80dbc01e83c8f47513e9614f0", size = 209731, upload-time = "2026-05-12T12:51:07.559Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/71/af973d6fb806cc0f7e9f4d7cab3ed7265745e596298c8eab2185260c6979/renderers-0.1.7-py3-none-any.whl", hash = "sha256:64b119d13952df983462d72882aa97f1c7075fc4a7409a6b6b85ee88d83dfa52", size = 97934, upload-time = "2026-05-12T12:51:06.534Z" }, +] [[package]] name = "requests"