diff --git a/docs/reference.md b/docs/reference.md index 0586d9af3..5a4da7c88 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -193,6 +193,7 @@ class TrajectoryStepTokens(TypedDict): overlong_prompt: bool is_truncated: bool routed_experts: list[list[list[int]]] | None # [seq_len, layers, topk] to enable router replay + multi_modal_data: NotRequired[Any] # renderers.MultiModalData sidecar (pixel_values, placeholder ranges) — set only on multimodal rollouts ``` Token-level data for training. diff --git a/pyproject.toml b/pyproject.toml index d3bdf19bf..1b6e2e078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dev = [ "aiohttp>=3.9.0", "python-dotenv>=1.0.0", "nltk", - "renderers>=0.1.6", + "renderers>=0.1.8.dev0", ] [project.optional-dependencies] @@ -94,7 +94,7 @@ browser = [ "python-dotenv>=1.0.0", ] renderers = [ - "renderers>=0.1.6", + "renderers>=0.1.8.dev0", ] rl = [ "torch>=2.8.0,<2.9.0", @@ -125,12 +125,6 @@ prime-tunnel = false prime-sandboxes = false renderers = false -[tool.uv.sources] -# Pinned to renderers main until the next PyPI release lands; drop after. -# fe67f9f = renderers main: PR #4 squash-merge — construction-time -# preserve_*_thinking flags on create_renderer / create_renderer_pool. -renderers = { git = "https://github.com/PrimeIntellect-ai/renderers.git", rev = "fe67f9f" } - [tool.uv.extra-build-dependencies] flash-attn = [{ requirement = "torch", match-runtime = true }] diff --git a/tests/test_renderer_client.py b/tests/test_renderer_client.py index cba7f86df..9608c50a5 100644 --- a/tests/test_renderer_client.py +++ b/tests/test_renderer_client.py @@ -5,7 +5,7 @@ import verifiers as vf from renderers import RendererPool -from renderers.base import ParsedResponse, create_renderer +from renderers.base import ParsedResponse, RenderedTokens, create_renderer from verifiers.clients.renderer_client import ( RendererClient, _attach_tool_call_names, @@ -280,11 +280,13 @@ def bridge_to_next_turn( stop_idx = len(self.bridge_base) - 1 trailing = list(self.bridge_base[stop_idx + 1 :]) extension = list(self.bridge_full[len(self.bridge_base) :]) - return ( - list(previous_prompt_ids) - + list(previous_completion_ids) - + trailing - + extension + return RenderedTokens( + token_ids=( + list(previous_prompt_ids) + + list(previous_completion_ids) + + trailing + + extension + ) ) def parse_response(self, token_ids): @@ -345,7 +347,8 @@ async def test_get_incremental_prompt_ids_matches_tool_tail_without_rerendering_ renderer=renderer, prompt=prompt, state=state, tools=None ) - assert result == [1, 2, 3, 99, 30, 40] + assert result is not None + assert result.token_ids == [1, 2, 3, 99, 30, 40] # The bridge stitches over the completion without re-rendering it — # one bridge call, zero render_ids calls (older diff-based bridges # called render_ids twice). 
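The assertion rewrites in this test file all track one contract change: `bridge_to_next_turn` now returns `RenderedTokens | None` rather than a bare `list[int] | None`. A minimal caller sketch, assuming only what this diff shows (`token_ids` always present, `multi_modal_data` an optional sidecar; `prev_prompt_ids`, `prev_completion_ids`, and `tail` are placeholder names):

```python
from renderers import RenderedTokens

bridged: RenderedTokens | None = renderer.bridge_to_next_turn(
    prev_prompt_ids, prev_completion_ids, tail, tools=None
)
if bridged is None:
    prompt_ids = None  # no anchor lined up; caller falls back to a full re-render
else:
    prompt_ids = bridged.token_ids        # stitched prompt for the next turn
    mm_data = bridged.multi_modal_data    # None on text-only rollouts
```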
@@ -387,7 +390,8 @@ async def test_get_incremental_prompt_ids_accepts_tool_then_user_tail(): renderer=renderer, prompt=prompt, state=state, tools=None ) - assert result == [1, 2, 3, 99, 40, 50] + assert result is not None + assert result.token_ids == [1, 2, 3, 99, 40, 50] @pytest.mark.asyncio @@ -446,7 +450,8 @@ async def test_get_incremental_prompt_ids_accepts_multimodal_tool_user_tail(): renderer=renderer, prompt=prompt, state=state, tools=None ) - assert result == [1, 2, 3, 99, 40, 50] + assert result is not None + assert result.token_ids == [1, 2, 3, 99, 40, 50] # ── Parity across real renderers: truncated most-recent step ────────── @@ -478,7 +483,7 @@ async def test_get_incremental_prompt_ids_accepts_multimodal_tool_user_tail(): "auto", id="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", ), - pytest.param("openai/gpt-oss-20b", "gpt_oss", id="openai/gpt-oss-20b"), + pytest.param("openai/gpt-oss-20b", "gpt-oss", id="openai/gpt-oss-20b"), ] @@ -552,11 +557,12 @@ async def test_get_incremental_prompt_ids_bridges_over_truncated_step( prefix = list(prev_prompt_ids) + list(prev_completion_ids) assert result is not None, f"{model_id}: bridge returned None on truncated anchor" - assert result[: len(prefix)] == prefix, ( + result_ids = result.token_ids + assert result_ids[: len(prefix)] == prefix, ( f"{model_id}: bridge result does not prefix-preserve " f"prev_prompt + prev_completion" ) - assert len(result) > len(prefix), ( + assert len(result_ids) > len(prefix), ( f"{model_id}: bridge produced no tail tokens for the new user turn" ) diff --git a/tests/test_save_utils.py b/tests/test_save_utils.py index 6ad7d25f2..92f9cd977 100644 --- a/tests/test_save_utils.py +++ b/tests/test_save_utils.py @@ -27,6 +27,7 @@ ) from verifiers.utils.save_utils import ( GenerateOutputsBuilder, + _delta_intermediate_mm_data, extract_usage_tokens, load_outputs, make_serializable, @@ -897,3 +898,257 @@ def test_correctness_threshold_boundary(self): ) pass_at_k, _ = m.compute() assert pass_at_k["1"] == pytest.approx(0.5) + + +class TestDeltaIntermediateMmData: + """Verify per-step delta encoding of trajectory mm_data sidecars. + + Renderer bridge_to_next_turn emits cumulative mm_data on every + step. The transport-layer delta strips items whose mm_hash already + appeared in the prior step, so the per-window TrainingSample + assembler can recover its window's images by unioning step-deltas. 
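+
+    Worked example (mirrored by the linear-extension test below):
+    cumulative hashes ["A"], ["A", "B"], ["A", "B", "C"] delta-encode
+    to ["A"], ["B"], ["C"], so a window spanning steps 0-1 unions its
+    deltas back to {A, B} without ever shipping A's bytes twice.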
+ """ + + @staticmethod + def _mm(*hashes: str): + """Build a renderers.MultiModalData with one image item per hash.""" + from renderers.base import MultiModalData, PlaceholderRange + + return MultiModalData( + mm_hashes={"image": list(hashes)}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=i * 10, length=4) + for i in range(len(hashes)) + ] + }, + mm_items={"image": [{"pixel_values": f"px-{h}"} for h in hashes]}, + ) + + def _step(self, mm): + return {"tokens": {"multi_modal_data": mm}} + + def test_none_and_single_step_passthrough(self): + assert _delta_intermediate_mm_data(None) is None + assert _delta_intermediate_mm_data([]) == [] + only = [self._step(self._mm("A"))] + assert _delta_intermediate_mm_data(only) is only + + def test_linear_extension_keeps_only_new_items_per_step(self): + traj = [ + self._step(self._mm("A")), + self._step(self._mm("A", "B")), + self._step(self._mm("A", "B", "C")), + ] + out = _delta_intermediate_mm_data(traj) + + assert out[0]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["A"]} + assert out[1]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["B"]} + assert out[2]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["C"]} + # Items and placeholders are reindexed in lockstep with hashes. + assert out[1]["tokens"]["multi_modal_data"].mm_items["image"] == [ + {"pixel_values": "px-B"} + ] + assert ( + out[2]["tokens"]["multi_modal_data"].mm_placeholders["image"][0].offset + == 20 + ) + + def test_compaction_two_training_samples_assemble_correctly(self): + """Rollout with one compaction event → two TrainingSamples. + + Models the prime-rl compaction flow: a single rollout produces + multiple ``TrainingSample`` objects, one per compaction window. + The pre-compaction sample's images are no longer in the + post-compaction step's cumulative ``mm_data`` — the previous + "keep last" strategy would have silently dropped them. With + delta encoding, each per-window assembler recovers exactly the + images its tokens reference: no leakage in either direction. + """ + from renderers.base import MultiModalData, PlaceholderRange + + def step(*hashes: str, offsets: list[int]): + return { + "tokens": { + "multi_modal_data": MultiModalData( + mm_hashes={"image": list(hashes)}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=o, length=4) for o in offsets + ] + }, + mm_items={ + "image": [{"pixel_values": f"px-{h}"} for h in hashes] + }, + ) + } + } + + # Turn 1: image A. Cumulative {A}. + # Turn 2: image B. Cumulative {A, B}. + # ── compaction event: turns 1+2 summarized in text, images dropped ── + # Turn 3: image C. Cumulative {C} (offsets reset against the + # post-compaction prompt). + # Turn 4: image D. Cumulative {C, D}. + traj = [ + step("A", offsets=[10]), + step("A", "B", offsets=[10, 50]), + step("C", offsets=[8]), + step("C", "D", offsets=[8, 40]), + ] + out = _delta_intermediate_mm_data(traj) + + # Per-step deltas keep only what's new since the immediately prior step. 
+ deltas = [s["tokens"]["multi_modal_data"].mm_hashes for s in out] + assert deltas == [ + {"image": ["A"]}, + {"image": ["B"]}, + {"image": ["C"]}, + {"image": ["D"]}, + ] + + def assemble(steps): + hashes: list[str] = [] + items: list[dict] = [] + placeholders: list[PlaceholderRange] = [] + for s in steps: + mm = s["tokens"]["multi_modal_data"] + hashes += mm.mm_hashes.get("image", []) + items += mm.mm_items.get("image", []) + placeholders += mm.mm_placeholders.get("image", []) + return hashes, items, placeholders + + ts1_hashes, ts1_items, ts1_phs = assemble(out[0:2]) # pre-compaction + ts2_hashes, ts2_items, ts2_phs = assemble(out[2:4]) # post-compaction + + assert ts1_hashes == ["A", "B"] + assert ts2_hashes == ["C", "D"] + # The invariant the previous "keep last" broke: pre-compaction TS + # does not see post-compaction images, and vice versa. + assert set(ts1_hashes).isdisjoint(set(ts2_hashes)) + + # Items / placeholders are reindexed lock-step with hashes (no + # off-by-one or cross-contamination during reindex). + assert ts1_items == [{"pixel_values": "px-A"}, {"pixel_values": "px-B"}] + assert ts2_items == [{"pixel_values": "px-C"}, {"pixel_values": "px-D"}] + + # Placeholder offsets travel verbatim per step; the assembler is + # responsible for shifting them into each window's local frame. + assert [p.offset for p in ts1_phs] == [10, 50] + assert [p.offset for p in ts2_phs] == [8, 40] + + def test_same_image_rendered_in_two_turns_uses_multiset_diff(self): + """Same image hash appearing N times must keep the right N-prior occurrences. + + The renderer doesn't dedupe by hash: ``emit_image`` appends to + the parallel lists every time an image content part is rendered. + So if image A is shown in turn 1 *and* turn 3, the cumulative + ``mm_hashes`` is ``["A", "A"]`` with two distinct placeholder + offsets, and ``mm_items`` is ``[pixA, pixA]`` (literally the + same payload twice). Both placeholder runs need their own item + — set-based diff would drop both as "already seen" and orphan + the second placeholder. Multiset diff drops only the first. + """ + from renderers.base import MultiModalData, PlaceholderRange + + def step(hashes, offsets): + return { + "tokens": { + "multi_modal_data": MultiModalData( + mm_hashes={"image": list(hashes)}, + mm_placeholders={ + "image": [ + PlaceholderRange(offset=o, length=4) for o in offsets + ] + }, + mm_items={ + "image": [{"pixel_values": f"px-{h}"} for h in hashes] + }, + ) + } + } + + # Turn 1: image A at offset 10. Cumulative ["A"]. + # Turn 2: no image. Cumulative unchanged ["A"]. + # Turn 3: image A re-rendered at offset 200. Cumulative ["A", "A"]. + traj = [ + step(["A"], offsets=[10]), + step(["A"], offsets=[10]), + step(["A", "A"], offsets=[10, 200]), + ] + out = _delta_intermediate_mm_data(traj) + + # Step 0 keeps everything (no prior). + assert out[0]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["A"]} + assert [ + p.offset + for p in out[0]["tokens"]["multi_modal_data"].mm_placeholders["image"] + ] == [10] + + # Step 1 introduced no new image (cumulative unchanged). + assert out[1]["tokens"]["multi_modal_data"].mm_hashes == {"image": []} + + # Step 2: prior was ["A"], current is ["A", "A"]. Multiset budget + # consumes the first A; the *second* A (the new one at offset + # 200) survives the diff with its pixel_values intact. Set-based + # diff would have produced []. 
+ step2_mm = out[2]["tokens"]["multi_modal_data"] + assert step2_mm.mm_hashes == {"image": ["A"]} + assert step2_mm.mm_items == {"image": [{"pixel_values": "px-A"}]} + assert [p.offset for p in step2_mm.mm_placeholders["image"]] == [200] + + # End-to-end: assembling the single TrainingSample (no + # compaction) recovers both placeholder runs with matching + # pixel_values, so the trainer can satisfy both image-pad + # token runs in the prompt. + all_hashes: list[str] = [] + all_phs: list[PlaceholderRange] = [] + for s in out: + mm = s["tokens"]["multi_modal_data"] + all_hashes += mm.mm_hashes.get("image", []) + all_phs += mm.mm_placeholders.get("image", []) + assert all_hashes == ["A", "A"] + assert [p.offset for p in all_phs] == [10, 200] + + def test_image_reintroduction_after_compaction(self): + """A hash dropped at compaction and re-rendered later is re-transmitted. + + The delta is computed against the *immediately prior step's* + cumulative, not a global seen-set. If image A appears in turn + 1, is compacted away (step 2's cumulative is empty), and is + re-rendered in turn 3, A shows up in step 0's delta *and* step + 2's delta — necessary so the post-compaction TrainingSample + also receives A's bytes. + """ + traj = [ + self._step(self._mm("A")), + self._step(self._mm()), + self._step(self._mm("A")), + ] + out = _delta_intermediate_mm_data(traj) + + assert out[0]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["A"]} + assert out[1]["tokens"]["multi_modal_data"].mm_hashes == {"image": []} + # A re-emerges in step 2's delta — its absence from step 1's + # cumulative means it counts as "new" again. + assert out[2]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["A"]} + + def test_steps_with_no_new_items_collapse_to_empty_delta(self): + # Step 2's cumulative equals step 1's — no new items. + traj = [ + self._step(self._mm("A", "B")), + self._step(self._mm("A", "B")), + self._step(self._mm("A", "B", "C")), + ] + out = _delta_intermediate_mm_data(traj) + + assert out[1]["tokens"]["multi_modal_data"].mm_hashes == {"image": []} + assert out[1]["tokens"]["multi_modal_data"].mm_items == {"image": []} + assert out[2]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["C"]} + + def test_non_mapping_steps_pass_through(self): + traj = [self._step(self._mm("A")), "not-a-dict", self._step(self._mm("A", "B"))] + out = _delta_intermediate_mm_data(traj) + assert out[1] == "not-a-dict" + # Delta of step 2 still computed against step 0 (last seen cumulative). 
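+        # (prior_hashes carries across the pass-through entry, so only
+        # B counts as new here.)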
+ assert out[2]["tokens"]["multi_modal_data"].mm_hashes == {"image": ["B"]} diff --git a/uv.lock b/uv.lock index 2fbe0a6ab..6c2e16ec5 100644 --- a/uv.lock +++ b/uv.lock @@ -4845,8 +4845,8 @@ wheels = [ [[package]] name = "renderers" -version = "0.1.8" -source = { git = "https://github.com/PrimeIntellect-ai/renderers.git?rev=fe67f9f#fe67f9f16412074c3207ce69ba1763b97958a9db" } +version = "0.1.8.dev0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2" }, { name = "numpy" }, @@ -4855,6 +4855,10 @@ dependencies = [ { name = "tiktoken" }, { name = "transformers" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/50/de/a445036157af3367c6a962c13333427c83c08926934c541886eb87f9dcdf/renderers-0.1.8.dev0.tar.gz", hash = "sha256:71eef7bfa3d3f5849ba070d38cd89a1f6387ca7710824f2e50d8c05c9b1048b9", size = 210667, upload-time = "2026-05-12T17:48:45.352Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/33/936a38c7f20fbe096b751842ffc6ef254c9eb2223153aa860a122ce9a834/renderers-0.1.8.dev0-py3-none-any.whl", hash = "sha256:09bb35233f67599519c0ff6edfad469f0836a55a6b78e039cd8e7b5e527bdcb3", size = 98617, upload-time = "2026-05-12T17:48:44.222Z" }, +] [[package]] name = "requests" @@ -6209,7 +6213,7 @@ requires-dist = [ { name = "pyzmq", specifier = ">=27.1.0" }, { name = "reasoning-gym", marker = "extra == 'rg'" }, { name = "regex", specifier = "<2026.4.4" }, - { name = "renderers", marker = "extra == 'renderers'", git = "https://github.com/PrimeIntellect-ai/renderers.git?rev=fe67f9f" }, + { name = "renderers", marker = "extra == 'renderers'", specifier = ">=0.1.8.dev0" }, { name = "requests" }, { name = "requests", marker = "extra == 'rl'" }, { name = "rich" }, @@ -6242,7 +6246,7 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.8.0" }, { name = "python-dotenv", specifier = ">=1.0.0" }, { name = "reasoning-gym" }, - { name = "renderers", git = "https://github.com/PrimeIntellect-ai/renderers.git?rev=fe67f9f" }, + { name = "renderers", specifier = ">=0.1.8.dev0" }, { name = "ruff" }, { name = "stagehand", specifier = ">=3.0.0" }, { name = "textarena" }, diff --git a/verifiers/clients/renderer_client.py b/verifiers/clients/renderer_client.py index d793a3b67..ad7644357 100644 --- a/verifiers/clients/renderer_client.py +++ b/verifiers/clients/renderer_client.py @@ -20,10 +20,13 @@ from renderers import Message as RendererMessage from renderers import ( + MultimodalRenderer, + RenderedTokens, Renderer, RendererPool, ToolSpec, create_renderer_pool, + is_multimodal, ) from renderers import ToolCall as RendererToolCall from renderers import ToolCallFunction @@ -94,15 +97,15 @@ def _record_bridge(success: bool) -> None: # ── Helpers ───────────────────────────────────────────────────────── -async def _run_with_renderer(renderer: Renderer | RendererPool, fn): - if isinstance(renderer, RendererPool): - - def _work(): - with renderer.checkout() as r: - return fn(r) +async def _maybe_offload(renderer: Renderer | RendererPool, fn): + """Run sync renderer work on a thread iff ``renderer`` is a pool. - return await asyncio.to_thread(_work) - return fn(renderer) + Pool methods can block on the internal queue/lock; we offload to keep + the event loop responsive. A bare ``Renderer`` runs inline. 
+ """ + if isinstance(renderer, RendererPool): + return await asyncio.to_thread(fn) + return fn() def _get_value(obj: Any, key: str, default: Any = None) -> Any: @@ -295,6 +298,28 @@ def _step_token_ids(step: Any) -> tuple[list[int], list[int]] | None: return list(prompt_ids), list(completion_ids) +def _step_multi_modal_data(step: Any): + """Recover the previous turn's ``MultiModalData`` for bridging. + + Mirrors :func:`_step_token_ids`: prefer ``step.tokens.multi_modal_data`` + (post-parse_response_tokens), fall back to ``step.response.message.tokens``. + Returns ``None`` when no multimodal sidecar was emitted (text-only + rollouts) — the bridge handles that branch transparently. + """ + tokens = _get_value(step, "tokens") + if tokens is not None: + mm = _get_value(tokens, "multi_modal_data") + if mm is not None: + return mm + + response = _get_value(step, "response") + message = _get_value(response, "message") + raw_tokens = _get_value(message, "tokens") + if raw_tokens is None: + return None + return _get_value(raw_tokens, "multi_modal_data") + + def _step_rendered_messages(step: Any) -> list[RendererMessage]: prompt = list(_get_value(step, "prompt", []) or []) completion = list(_get_value(step, "completion", []) or []) @@ -309,7 +334,13 @@ async def _get_incremental_prompt_ids( prompt: list[RendererMessage], state: Any, tools: list[ToolSpec] | None, -) -> list[int] | None: +) -> "RenderedTokens | None": + """Return the bridged prompt for the next turn as ``RenderedTokens``. + + Returns ``None`` when no prior trajectory step lines up with the new + prompt's prefix or the renderer's ``bridge_to_next_turn`` can't extend + — both cases fall back to a full re-render in :func:`generate`. + """ if not state: return None @@ -342,15 +373,32 @@ async def _get_incremental_prompt_ids( continue previous_prompt_ids, previous_completion_ids = token_ids - bridged = await _run_with_renderer( - renderer, - lambda r: r.bridge_to_next_turn( + previous_mm_data = _step_multi_modal_data(step) + # Multimodal renderers' bridge accepts ``previous_multi_modal_data`` + # so earlier-turn images carry forward into the new prompt's + # ``mm_placeholders``. Without that carry-forward, vLLM sees + # placeholder counts that don't match the combined token sequence + # and silently falls back to hash-cache lookup (or errors). + # Text-only renderers' bridge signature doesn't include that + # kwarg. ``is_multimodal`` is type-cached so this dispatch is a + # dict lookup, not a runtime_checkable Protocol walk. + if is_multimodal(renderer): + mm_renderer = cast(MultimodalRenderer, renderer) + bridge = lambda: mm_renderer.bridge_to_next_turn( # noqa: E731 previous_prompt_ids, previous_completion_ids, tail, tools=tools, - ), - ) + previous_multi_modal_data=previous_mm_data, + ) + else: + bridge = lambda: renderer.bridge_to_next_turn( # noqa: E731 + previous_prompt_ids, + previous_completion_ids, + tail, + tools=tools, + ) + bridged = await _maybe_offload(renderer, bridge) _record_bridge(success=bridged is not None) return bridged @@ -514,12 +562,21 @@ async def get_native_response( if args.get("prompt_logprobs"): sampling_params["prompt_logprobs"] = 1 - prompt_ids = await _get_incremental_prompt_ids( + bridged = await _get_incremental_prompt_ids( renderer=renderer, prompt=prompt, state=kwargs.get("state"), tools=tools, ) + # ``bridged`` is RenderedTokens | None. Unpack token_ids + mm_data + # so multimodal renderers thread per-image features through to + # /inference/v1/generate without re-rendering the whole turn. 
+ if bridged is not None: + prompt_ids = bridged.token_ids + multi_modal_data = bridged.multi_modal_data + else: + prompt_ids = None + multi_modal_data = None return await generate( client=self.client, @@ -527,6 +584,7 @@ async def get_native_response( messages=prompt, model=model, prompt_ids=prompt_ids, + multi_modal_data=multi_modal_data, tools=tools, sampling_params=sampling_params, cache_salt=args.get("cache_salt") @@ -580,6 +638,7 @@ async def from_native_response(self, response: dict[str, Any]) -> Response: completion_mask=[1] * len(completion_ids), completion_logprobs=completion_logprobs, routed_experts=response.get("routed_experts"), + multi_modal_data=response.get("multi_modal_data"), ) # /inference/v1/generate doesn't return usage; reconstruct from tokens. diff --git a/verifiers/types.py b/verifiers/types.py index f0dc4ac55..25a4b8732 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -179,6 +179,12 @@ class ResponseTokens(CustomBaseModel): completion_mask: list[int] completion_logprobs: list[float] routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk] + # Renderer-emitted multimodal sidecar (renderers.base.MultiModalData) + # carrying processed pixel_values / placeholder ranges per modality. + # Populated by the renderer client when the rollout went through a + # multimodal-aware renderer; ``None`` otherwise. Stored as ``Any`` to + # avoid a hard import dependency on ``renderers`` at this layer. + multi_modal_data: Any | None = None FinishReason = Literal["stop", "length", "tool_calls"] | None @@ -216,6 +222,11 @@ class TrajectoryStepTokens(TypedDict): overlong_prompt: bool is_truncated: bool routed_experts: list[list[list[int]]] | None # [seq_len, layers, topk] + # Renderer-emitted multimodal sidecar (renderers.base.MultiModalData) + # carrying processed pixel_values / placeholder ranges per modality. + # ``NotRequired`` because text-only rollouts (and non-renderer client + # types) never populate it. + multi_modal_data: NotRequired[Any] class TokenUsage(TypedDict): diff --git a/verifiers/utils/response_utils.py b/verifiers/utils/response_utils.py index 4e7c8b480..9bbb38ad8 100644 --- a/verifiers/utils/response_utils.py +++ b/verifiers/utils/response_utils.py @@ -35,6 +35,7 @@ async def parse_response_tokens( completion_mask = tokens.completion_mask completion_logprobs = tokens.completion_logprobs routed_experts = tokens.routed_experts + multi_modal_data = tokens.multi_modal_data if max_seq_len is not None: prompt_len = len(prompt_ids) @@ -61,7 +62,7 @@ async def parse_response_tokens( overlong_prompt = False is_truncated = False - return TrajectoryStepTokens( + out = TrajectoryStepTokens( prompt_ids=prompt_ids, prompt_mask=prompt_mask, completion_ids=completion_ids, @@ -71,3 +72,10 @@ async def parse_response_tokens( is_truncated=is_truncated, routed_experts=routed_experts, ) + if multi_modal_data is not None: + out["multi_modal_data"] = multi_modal_data + # Move (not copy) the sidecar to its canonical home on the parsed + # step. Leaving it on ``response.message.tokens`` too means every + # downstream pass (msgpack, save) has to dedupe the duplicate. 
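+        # (The ``multi_modal_data`` local captured above keeps the
+        # value alive across the clear that follows.)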
+ tokens.multi_modal_data = None + return out diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index d9aa889e7..78e4490f9 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -53,6 +53,15 @@ def is_json_serializable(value: object) -> bool: Returns True for JSON primitives, lists/dicts of primitives, Pydantic models, datetime/date, Path, and exceptions. + + Note: renderer multimodal sidecars (``MultiModalData``, + ``PlaceholderRange``, numpy arrays) intentionally return False + here — they are not JSON-native and ``make_serializable`` has no + handler for them (it would stringify to ``"array(...)"`` garbage). + They reach the trainer via msgpack with a custom encoder, and the + JSONL save path excludes the carrying column (``trajectory``) at + the orchestrator boundary, so this gate is bypassed for that + column in ``state_to_output``. """ if value is None: return True @@ -263,6 +272,27 @@ def state_to_output( # add state columns (must be serializable) for col in state_columns or []: value = state.get(col) + if col == "trajectory": + # Renderer multimodal rollouts accumulate mm_data on every step + # (bridge_to_next_turn merges previous_multi_modal_data into the + # new turn). Naively shipping cumulative mm_data on every step + # duplicates every image O(N²) bytes for an N-turn rollout. + # Replace each step's cumulative mm_data with its delta against + # the prior step (items keyed by mm_hash) so any per-window + # TrainingSample assembler — including compaction, where a + # single rollout produces multiple samples and the pre-compaction + # sample's images aren't in the final cumulative set — can + # recover its window's images by unioning step-deltas. + value = _delta_intermediate_mm_data(value) + # Trajectory may carry numpy arrays / renderer dataclasses on + # ``tokens.multi_modal_data`` — these are not JSON-native and + # ``is_json_serializable`` would (correctly) reject them. They + # are transported to the trainer via msgpack with a custom + # encoder, and the JSONL save path excludes ``trajectory`` at + # the orchestrator boundary, so the JSON gate doesn't apply + # here. + output[col] = value + continue if not is_json_serializable(value): raise ValueError( f"state_columns value for '{col}' is not JSON-serializable: " @@ -273,6 +303,166 @@ def state_to_output( return output +def _delta_intermediate_mm_data(trajectory: object) -> object: + """Replace each step's cumulative ``multi_modal_data`` with its delta. + + The renderer's ``bridge_to_next_turn`` merges ``previous_multi_modal_data`` + into the new turn, so each step carries the cumulative set of every + image rendered so far in the trajectory. For each step after the + first, drop items whose ``mm_hash`` already appeared in the immediately + prior step. The first step is left as-is (all items are new). + + ``parse_response_tokens`` moves the sidecar onto ``step["tokens"]`` + and clears the duplicate on ``response.message.tokens``, so only one + location needs rewriting here. + + Each unique image's bytes travel exactly once across the trajectory + (no O(N²) duplication). Per-window ``TrainingSample`` assemblers — + including compaction, where a single rollout produces multiple + samples and the pre-compaction sample's images aren't in the final + cumulative set — recover any window's images by unioning the + step-deltas in that window. Placeholder offsets stay relative to the + step's own cumulative token sequence; the assembler shifts them. 
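+
+    Compaction example (mirrored in the tests): cumulative hashes
+    ["A"], ["A", "B"], ["C"], ["C", "D"] across four steps become
+    deltas ["A"], ["B"], ["C"], ["D"]; the pre-compaction window
+    unions steps 0-1 to {A, B}, the post-compaction window unions
+    steps 2-3 to {C, D}.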
+ + Returns a new list of step dicts (shallow copies for rewritten + entries) so the input state isn't mutated. Non-list inputs and + empty / single-step trajectories pass through unchanged. + """ + if not isinstance(trajectory, list) or len(trajectory) <= 1: + return trajectory + + out: list = [] + prior_hashes: dict[str, list[str]] = {} + + for idx, raw_step in enumerate(trajectory): + if not isinstance(raw_step, Mapping): + out.append(raw_step) + continue + step = cast(Mapping[str, Any], raw_step) + tokens = step.get("tokens") + step_mm = ( + tokens.get("multi_modal_data") if isinstance(tokens, Mapping) else None + ) + current_hashes = _read_mm_hashes(step_mm) + + if idx == 0: + out.append(step) + prior_hashes = current_hashes + continue + + if isinstance(tokens, Mapping) and step_mm is not None: + delta = _diff_mm_data(step_mm, prior_hashes) + if delta is not step_mm: + new_step: dict[str, Any] = dict(step) + new_step["tokens"] = {**tokens, "multi_modal_data": delta} + out.append(new_step) + prior_hashes = current_hashes + continue + + out.append(step) + prior_hashes = current_hashes + return out + + +def _read_mm_hashes(mm: object) -> dict[str, list[str]]: + """Per-modality list of ``mm_hashes`` from a ``MultiModalData``-like object. + + Returns a list (not a set) so multiplicity is preserved: the same + image rendered N times appears N times in the list, with each + occurrence corresponding to a separate placeholder run in the token + stream. The diff uses multiset semantics so each prior occurrence + "consumes" one matching current occurrence and the *remaining* + current occurrences are kept as new. + """ + if mm is None: + return {} + hashes = getattr(mm, "mm_hashes", None) + if not isinstance(hashes, dict): + return {} + return { + modality: list(hs) for modality, hs in hashes.items() if isinstance(hs, list) + } + + +def _diff_mm_data(mm: object, prior_hashes: dict[str, list[str]]) -> object: + """Return ``mm`` with items the prior step already covered removed. + + Uses **multiset** semantics: each prior-step occurrence of a given + hash consumes one matching current-step occurrence, and only the + *surplus* current occurrences are kept. Necessary because the + renderer doesn't dedupe by hash — if the same image is rendered in + two turns, cumulative ``mm_hashes`` contains the hash twice (each + with its own placeholder offset), and both occurrences need their + ``pixel_values`` to reach the trainer. Set-based diff would drop + both as "already seen" and leave the second placeholder run + orphaned. + + Returns the input unchanged if nothing is dropped (cheap fast-path + for steps that introduced no new items). Returns a new instance of + the same class with the delta items otherwise. Mirrors the + ``MultiModalData`` shape: three parallel per-modality lists + (``mm_hashes``, ``mm_items``, ``mm_placeholders``) reindexed by the + surviving item positions. 
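+
+    Multiset example: prior ["A"] vs. current ["A", "A"] keeps only
+    the second occurrence, with its own placeholder and item; prior
+    ["A", "B"] vs. an identical current collapses to an empty delta.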
+ """ + hashes = getattr(mm, "mm_hashes", None) + items = getattr(mm, "mm_items", None) + placeholders = getattr(mm, "mm_placeholders", None) + if ( + not isinstance(hashes, dict) + or not isinstance(items, dict) + or not isinstance(placeholders, dict) + ): + return mm + + new_hashes: dict[str, list[str]] = {} + new_items: dict[str, list[Any]] = {} + new_placeholders: dict[str, list[Any]] = {} + any_dropped = False + + for modality, mod_hashes in hashes.items(): + if not isinstance(mod_hashes, list): + new_hashes[modality] = mod_hashes + new_items[modality] = items.get(modality, []) + new_placeholders[modality] = placeholders.get(modality, []) + continue + mod_items = items.get(modality) or [] + mod_placeholders = placeholders.get(modality) or [] + # Multiset budget: each prior occurrence of a hash can consume + # one matching current occurrence. Walk current left-to-right + # and keep an item only after the budget for its hash is gone. + remaining: dict[str, int] = {} + for h in prior_hashes.get(modality, []): + remaining[h] = remaining.get(h, 0) + 1 + keep_idx: list[int] = [] + for i, h in enumerate(mod_hashes): + if remaining.get(h, 0) > 0: + remaining[h] -= 1 + else: + keep_idx.append(i) + if len(keep_idx) != len(mod_hashes): + any_dropped = True + # Trust the renderer's parallel-list invariant + # (``emit_image`` appends to all three together). If it's + # broken on input, indexing fails loudly here rather than + # silently producing mismatched output lists. + new_hashes[modality] = [mod_hashes[i] for i in keep_idx] + new_items[modality] = [mod_items[i] for i in keep_idx] + new_placeholders[modality] = [mod_placeholders[i] for i in keep_idx] + + if not any_dropped: + return mm + + cls = type(mm) + try: + return cls( + mm_hashes=new_hashes, + mm_placeholders=new_placeholders, + mm_items=new_items, + ) + except TypeError: + return mm + + def serialize_timing(timing: object) -> dict[str, Any]: model_dump = getattr(timing, "model_dump", None) if callable(model_dump): diff --git a/verifiers/utils/serve_utils.py b/verifiers/utils/serve_utils.py index 9a1aa26ce..67891e5e3 100644 --- a/verifiers/utils/serve_utils.py +++ b/verifiers/utils/serve_utils.py @@ -1,8 +1,11 @@ +import dataclasses import logging import socket +import sys from datetime import date, datetime from enum import Enum from pathlib import Path +from typing import Any from uuid import UUID import numpy as np @@ -10,6 +13,20 @@ logger = logging.getLogger(__name__) +# Marker key inside the encoded payload so the decoder can recognize a +# tensor round-trip without disturbing arbitrary user dicts. +TENSOR_TAG = "__torch_tensor__" + + +def _encode_array_like(arr: "np.ndarray") -> dict: + return { + TENSOR_TAG: True, + "dtype": str(arr.dtype), + "shape": list(arr.shape), + "data": arr.tobytes(), + } + + def msgpack_encoder(obj): """ Custom encoder for non-standard types. @@ -18,7 +35,11 @@ def msgpack_encoder(obj): is ONLY called for types msgpack doesn't recognize. This avoids the massive performance penalty of recursing through millions of tokens in Python. - Handles: Path, UUID, Enum, datetime, Pydantic models, numpy scalars. + Handles: Path, UUID, Enum, datetime, Pydantic models, numpy scalars, + numpy arrays, torch tensors, and dataclasses (e.g. renderers' + ``MultiModalData`` / ``PlaceholderRange``). Tensors and ndarrays are + encoded as ``{__torch_tensor__: True, dtype, shape, data}`` so the + receiving side can rehydrate them via ``decode_tensor_payload``. 
Does NOT handle: lists, dicts, basic types (msgpack does this natively in C). """ if isinstance(obj, (Path, UUID)): @@ -29,13 +50,73 @@ def msgpack_encoder(obj): return obj.isoformat() elif isinstance(obj, (np.integer, np.floating)): return obj.item() + elif isinstance(obj, np.ndarray): + return _encode_array_like(obj) + elif (_torch := sys.modules.get("torch")) is not None and isinstance( + obj, _torch.Tensor + ): + # Read torch off ``sys.modules`` instead of importing: text-only + # consumers never load torch, so this branch stays cold for + # them. Any tensor reaching the encoder implies torch is + # already in the process (you can't construct one otherwise). + # ``isinstance`` is precise — the previous string-module check + # also matched non-tensor torch objects (``torch.dtype``, + # ``torch.device``, ``torchvision.*``) and crashed on + # ``.detach()``. + arr = obj.detach().cpu().contiguous().numpy() + return _encode_array_like(arr) elif hasattr(obj, "model_dump"): return obj.model_dump() + elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return dataclasses.asdict(obj) else: # raise on unknown types to make issues visible raise TypeError(f"Object of type {type(obj)} is not msgpack serializable") +def decode_tensor_payload(obj: Any, *, to_torch: bool = True): + """Rehydrate a tensor encoded by :func:`msgpack_encoder`. + + Accepts either the encoded dict shape (``{__torch_tensor__: True, + dtype, shape, data}``) or an already-rehydrated tensor/ndarray and + returns a torch tensor (or numpy ndarray if ``to_torch=False``). + """ + if obj is None: + return None + if isinstance(obj, dict) and obj.get(TENSOR_TAG): + arr = np.frombuffer(obj["data"], dtype=np.dtype(obj["dtype"])).reshape( + obj["shape"] + ) + if to_torch: + # importlib (not ``import torch``) so static type checkers in + # downstream consumers without torch installed don't fail on + # unresolved-import. Torch is a soft runtime dep here: callers + # that pass ``to_torch=True`` are expected to have it. + import importlib + + torch = importlib.import_module("torch") + return torch.from_numpy(arr.copy()) + return arr.copy() + # Already a tensor / ndarray — pass through. + return obj + + +def walk_decode_tensors(obj: Any, *, to_torch: bool = True): + """Recursively decode any tensor payloads inside nested dicts/lists. + + Used by the orchestrator after msgpack-decoding a multimodal sidecar + so downstream code sees real tensors without each consumer threading + the decode call manually. + """ + if isinstance(obj, dict): + if obj.get(TENSOR_TAG): + return decode_tensor_payload(obj, to_torch=to_torch) + return {k: walk_decode_tensors(v, to_torch=to_torch) for k, v in obj.items()} + if isinstance(obj, list): + return [walk_decode_tensors(v, to_torch=to_torch) for v in obj] + return obj + + def make_ipc_address(session_id: str, name: str) -> str: """Build an IPC address for inter-process communication.""" return f"ipc:///tmp/vf-{session_id}-{name.replace('/', '--')}"
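A minimal round-trip sketch of the encode/decode pair above, calling the hook directly for illustration (in production msgpack invokes ``msgpack_encoder`` as its default hook for unrecognized types, and the orchestrator runs ``walk_decode_tensors`` after decoding):

```python
import numpy as np

from verifiers.utils.serve_utils import (
    decode_tensor_payload,
    msgpack_encoder,
    walk_decode_tensors,
)

# Encode: ndarray -> tagged dict carrying dtype, shape, and raw bytes.
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
payload = msgpack_encoder(arr)
assert payload["__torch_tensor__"] and payload["shape"] == [2, 3]

# Decode a single payload back to an ndarray (to_torch=True would
# return a torch tensor instead, if torch is installed).
back = decode_tensor_payload(payload, to_torch=False)
assert (back == arr).all() and back.dtype == arr.dtype

# Nested sidecar-shaped structures decode in one recursive pass.
nested = {"image": [{"pixel_values": payload}]}
decoded = walk_decode_tensors(nested, to_torch=False)
assert (decoded["image"][0]["pixel_values"] == arr).all()
```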