From c1f51306bffef91c7b93126ec070aba0d861d174 Mon Sep 17 00:00:00 2001 From: txdadlab <191132262+txdadlab@users.noreply.github.com> Date: Wed, 3 Jun 2026 17:21:10 -0500 Subject: [PATCH 1/2] feat(mllm): auto-extract audio from video_url on omni models For multimodal omni models (those exposing a sound_encoder, e.g. Nemotron-H Nano Omni, Qwen2.5-Omni), a video_url is logically an A/V input. Previously vllm-mlx fed only the visual frames to the model: the audio track was silently dropped on both the _generate_native_video path (no audio kwarg ever reached the HF processor) and the fallback path (frame extraction never touched the video's audio stream). The model returned visually-grounded descriptions but never "heard" the video. This wires audio through the existing OpenAI-style content-block path: 1. Add extract_audio_from_video(video_path): probes for an audio stream with ffprobe and, if present, extracts a 16 kHz mono PCM WAV via ffmpeg into a temp file registered with _temp_manager (auto-cleaned alongside the other request temp files). 2. Set self._video_native_with_audio at load time as hasattr(self.model, "sound_encoder"). Decoupled from _video_native because some omni models (Nemotron-H Omni) don't expose video_token_id at config level and run through the frames-as-images fallback path. 3. In _translate_messages_for_native_video: handle audio/audio_url blocks explicitly, and auto-extract audio from any video_url when the message doesn't already carry an explicit audio block. (Explicit caller-provided audio wins.) 4. In _prepare_native_video_inputs: collect translated audio paths, pass audio= to self.processor(...), and forward sound_clips, input_features, feature_attention_mask, audio_feature_lengths, sound_feature_lengths, sound_attention_mask into gen_kwargs so the omni model's sound encoder gets fed alongside the visual stream. 5. In chat() and stream_chat() fallback paths: extract audio from video_url for omni models, merge into _msg_audio_inputs so the existing audio plumbing picks it up. No change for non-omni models. Notes - ffmpeg is invoked only after process_video_input has resolved the user-supplied source to a local path on disk, so raw user URLs are never passed to ffmpeg's URL-protocol demuxers (avoids SSRF via http://, rtsp://, etc.). - Auto-extraction is gated on the runtime presence of sound_encoder. Non-omni models are completely unaffected. - Videos without an audio track are detected by ffprobe and skipped silently. Manual repro (Nemotron-H Nano Omni nvfp4 on Apple Silicon, fixed by mlx-vlm #1279 for the audio path): curl /v1/chat/completions -d '{"model":"nemotron-omni", "messages":[{"role":"user","content":[ {"type":"text","text":"Transcribe verbatim what the speaker says, then name which on-screen graphic was shown when they said each sentence."}, {"type":"video_url","video_url":{"url":"data:video/mp4;base64,..."}} ]}]}' Before: model described visuals only. After: model returns transcript + per-sentence visual correlation; server log shows "Omni model detected: ... A/V fusion" and the chat-template counter reports both images and audios. Related: PR #352 attempted the same goal via an opt-in CLI flag on the fallback path only. This change is opt-out-by-omission instead (auto- detected per model), covers both code paths, and addresses the SSRF / temp-leak / per-message-audio-count concerns raised in #352's review. --- vllm_mlx/models/mllm.py | 232 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 225 insertions(+), 7 deletions(-) diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 4c3275944..6cbd526d7 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -867,6 +867,74 @@ def process_audio_input(audio: str | dict) -> str: raise ValueError(f"Cannot process audio: {audio[:50]}...") +def _video_has_audio_track(video_path: str) -> bool: + """Return True if ffprobe finds an audio stream in the video.""" + import shutil + import subprocess + + if not shutil.which("ffprobe"): + return True # assume yes; extraction will fail loudly if not + try: + r = subprocess.run( + [ + "ffprobe", "-loglevel", "error", "-select_streams", "a", + "-show_entries", "stream=codec_type", "-of", "csv=p=0", + video_path, + ], + capture_output=True, timeout=30, text=True, + ) + return bool(r.stdout.strip()) + except (subprocess.SubprocessError, OSError): + return True + + +def extract_audio_from_video(video_path: str) -> str | None: + """Extract the audio track from a video file as 16 kHz mono WAV. + + Returns the path to the WAV (registered with the temp manager so it's + cleaned up automatically), or None if the video has no audio or ffmpeg + is unavailable. + """ + import os + import shutil + import subprocess + + if not shutil.which("ffmpeg"): + logger.warning( + "ffmpeg not found; cannot fuse audio from video_url. " + "Install ffmpeg to enable A/V fusion on omni models." + ) + return None + if not _video_has_audio_track(video_path): + return None + + fd, out_path = tempfile.mkstemp(suffix=".wav", prefix="vllmmlx_va_") + os.close(fd) + try: + r = subprocess.run( + [ + "ffmpeg", "-y", "-i", video_path, + "-vn", "-ac", "1", "-ar", "16000", + "-c:a", "pcm_s16le", out_path, + ], + stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=600, + ) + if r.returncode != 0 or os.path.getsize(out_path) == 0: + try: + os.unlink(out_path) + except OSError: + pass + return None + return _temp_manager.register(out_path) + except (subprocess.SubprocessError, OSError) as e: + logger.warning(f"Audio extraction from video failed: {e}") + try: + os.unlink(out_path) + except OSError: + pass + return None + + # Cache for base64 images to avoid re-saving the same image _base64_image_cache: dict[str, str] = {} # hash -> temp file path @@ -1132,6 +1200,7 @@ def __init__( self._draft_model = None self._loaded = False self._video_native = False + self._video_native_with_audio = False # Initialize MLLM prefix cache manager (with vision embedding caching) self._cache_manager: MLLMPrefixCacheManager | None = None @@ -1159,9 +1228,20 @@ def load(self) -> None: self._video_native = hasattr( self.model.config, "video_token_id" ) or hasattr(self.model.config, "video_token_index") + # Omni models expose a sound_encoder; for these, a video_url + # without a paired audio_url should auto-extract the video's + # audio track so the model can fuse A/V in one forward pass. + # Decoupled from _video_native because some omni models (e.g. + # Nemotron-H Omni) don't expose video_token_id at config level + # and run through the frames-as-images fallback path. + self._video_native_with_audio = hasattr(self.model, "sound_encoder") logger.info(f"MLLM loaded successfully: {self.model_name}") if self._video_native: logger.info("Native video pipeline enabled (temporal 3D conv + M-RoPE)") + if self._video_native_with_audio: + logger.info( + "Omni model detected: video_url will auto-extract audio for A/V fusion" + ) except ImportError: raise ImportError( @@ -1419,14 +1499,32 @@ def _prepare_native_video_inputs( native_messages, return_video_kwargs=True ) + # Collect audio paths emitted by the translation step + # (explicit audio_url, or auto-extracted from video_url for omni + # models). + audio_inputs: list[str] = [] + for nmsg in native_messages: + ncontent = nmsg.get("content", []) + if not isinstance(ncontent, list): + continue + for nitem in ncontent: + if isinstance(nitem, dict) and nitem.get("type") == "audio": + apath = nitem.get("audio") + if apath: + audio_inputs.append(apath) + # Process through HF processor to get input_ids, pixel_values, grid_thw - inputs = self.processor( - text=[text], - images=image_inputs, - videos=video_inputs, - padding=True, - return_tensors="pt", - ) + # and (for omni models) sound_clips / input_features. + processor_kwargs: dict = { + "text": [text], + "images": image_inputs, + "videos": video_inputs, + "padding": True, + "return_tensors": "pt", + } + if audio_inputs: + processor_kwargs["audio"] = audio_inputs + inputs = self.processor(**processor_kwargs) input_ids = mx.array(inputs["input_ids"]) pixel_values = inputs.get( @@ -1442,6 +1540,26 @@ def _prepare_native_video_inputs( if inputs.get("image_grid_thw", None) is not None: gen_kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) + # Forward audio embeddings/clips from the processor so the omni + # model's sound encoder gets fed alongside the visual stream. + for audio_key in ( + "sound_clips", + "input_features", + "feature_attention_mask", + "audio_feature_lengths", + "sound_feature_lengths", + "sound_attention_mask", + ): + val = inputs.get(audio_key, None) + if val is not None: + gen_kwargs[audio_key] = val + if audio_inputs: + logger.info( + f"Native video: forwarding audio ({len(audio_inputs)} clip(s)) " + f"to omni model via " + f"{[k for k in gen_kwargs if k in ('sound_clips', 'input_features')]}" + ) + gen_kwargs["input_ids"] = input_ids gen_kwargs["pixel_values"] = pixel_values gen_kwargs["mask"] = mask @@ -1525,6 +1643,24 @@ def _translate_messages_for_native_video( translated.append({"role": role, "content": str(content)}) continue + # Pre-pass: does this message have an explicit audio_url/audio + # block? If so, we skip auto-extracting audio from a video_url to + # honor the caller's explicit choice. + has_explicit_audio = False + for item in content: + if hasattr(item, "model_dump"): + probe = item.model_dump(exclude_none=True) + elif hasattr(item, "dict"): + probe = {k: v for k, v in item.dict().items() if v is not None} + else: + probe = item + if ( + isinstance(probe, dict) + and probe.get("type", "") in ("audio", "audio_url") + ): + has_explicit_audio = True + break + new_content = [] for item in content: if hasattr(item, "model_dump"): @@ -1583,6 +1719,38 @@ def _translate_messages_for_native_video( "max_frames": video_max_frames, } ) + # For omni-capable models, pull the video's audio track + # alongside frames so the model can fuse A/V in one + # forward pass. We extract from the already-resolved local + # path (no raw user URL handed to ffmpeg → avoids URL- + # protocol SSRF via ffmpeg's network demuxers). + if ( + not has_explicit_audio + and getattr(self, "_video_native_with_audio", False) + ): + extracted = extract_audio_from_video(video_path) + if extracted is not None: + new_content.append( + {"type": "audio", "audio": extracted} + ) + + elif item_type in ("audio", "audio_url"): + if item_type == "audio_url": + aud_url = item.get("audio_url", {}) + if isinstance(aud_url, str): + audio_source = aud_url + elif isinstance(aud_url, dict): + audio_source = aud_url.get("url", "") + else: + continue + else: + audio_source = item.get("audio", item.get("url", "")) + + if not audio_source: + continue + + audio_path = process_audio_input(audio_source) + new_content.append({"type": "audio", "audio": audio_path}) else: new_content.append(item) @@ -1914,11 +2082,35 @@ def chat( # Fallback: extract frames and treat as individual images _msg_video_frame_counts: dict[int, int] = {} + _msg_extra_audio: dict[int, list[str]] = {} all_video_frames: list[str] = [] all_audio_inputs: list[str] = [] for msg_idx, vid_inputs in _msg_video_inputs.items(): total_frames = 0 + has_explicit_audio = bool(_msg_audio_inputs.get(msg_idx)) for vid_input in vid_inputs: + # For omni models, also extract the video's audio track so + # the model can fuse A/V in one forward pass. Skip if the + # caller supplied an explicit audio_url for the same message. + # We resolve the video to a local path first so the user's + # raw URL is never handed to ffmpeg directly (avoids URL- + # protocol exposure via ffmpeg's network demuxers). + if self._video_native_with_audio and not has_explicit_audio: + try: + video_path_for_audio = process_video_input(vid_input) + except Exception as exc: + logger.warning( + f"Could not resolve video for audio extraction: {exc}" + ) + video_path_for_audio = None + if video_path_for_audio: + extracted_audio = extract_audio_from_video( + video_path_for_audio + ) + if extracted_audio: + _msg_extra_audio.setdefault(msg_idx, []).append( + extracted_audio + ) frames = self._prepare_video( vid_input, fps=video_fps, max_frames=video_max_frames ) @@ -1927,6 +2119,11 @@ def chat( logger.info(f"Added {len(frames)} frames from video: {vid_input}") _msg_video_frame_counts[msg_idx] = total_frames + # Merge auto-extracted audio into the per-message audio map so the + # chat-template token-counting loop downstream sees the right count. + for msg_idx, extra in _msg_extra_audio.items(): + _msg_audio_inputs.setdefault(msg_idx, []).extend(extra) + for aud_inputs in _msg_audio_inputs.values(): all_audio_inputs.extend(aud_inputs) @@ -2270,11 +2467,29 @@ def stream_chat( # Fallback: frames as images _msg_video_frame_counts: dict[int, int] = {} + _msg_extra_audio: dict[int, list[str]] = {} all_video_frames: list[str] = [] all_audio_inputs: list[str] = [] for msg_idx, vid_inputs in _msg_video_inputs.items(): total_frames = 0 + has_explicit_audio = bool(_msg_audio_inputs.get(msg_idx)) for vid_input in vid_inputs: + if self._video_native_with_audio and not has_explicit_audio: + try: + video_path_for_audio = process_video_input(vid_input) + except Exception as exc: + logger.warning( + f"Could not resolve video for audio extraction: {exc}" + ) + video_path_for_audio = None + if video_path_for_audio: + extracted_audio = extract_audio_from_video( + video_path_for_audio + ) + if extracted_audio: + _msg_extra_audio.setdefault(msg_idx, []).append( + extracted_audio + ) frames = self._prepare_video( vid_input, fps=video_fps, max_frames=video_max_frames ) @@ -2283,6 +2498,9 @@ def stream_chat( logger.info(f"Added {len(frames)} frames from video: {vid_input}") _msg_video_frame_counts[msg_idx] = total_frames + for msg_idx, extra in _msg_extra_audio.items(): + _msg_audio_inputs.setdefault(msg_idx, []).extend(extra) + for aud_inputs in _msg_audio_inputs.values(): all_audio_inputs.extend(aud_inputs) From 9a7c58762de70688bc7bc45695d3ad244f5fc95c Mon Sep 17 00:00:00 2001 From: txdadlab <191132262+txdadlab@users.noreply.github.com> Date: Tue, 9 Jun 2026 11:54:56 -0500 Subject: [PATCH 2/2] test(mllm): cover A/V fusion routing + harden sound_encoder predicate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses code review feedback on #591. Adds regression coverage for the five behavioral surfaces the prior commit introduced, tightens the sound_encoder detection predicate, and removes a duplicate video resolution in the fallback paths. Code changes (vllm_mlx/models/mllm.py): - Extract _model_has_sound_encoder(model) helper using `getattr(..., None) is not None` rather than `hasattr`. Wrappers that declare sound_encoder in __init__ but leave it None until first use were previously enabled by `hasattr` and would crash the processor. (Review blocker #4.) - _prepare_video() gains an optional `resolved_path=` kwarg. Callers that already ran process_video_input() can pass it through and skip the second resolve. Default None preserves prior behavior for all other callers. (Review blocker #3.) - chat() and stream_chat() fallback paths now resolve each video input exactly once and pass the local path to both extract_audio_from_video and _prepare_video. Eliminates re-download of remote URLs / re-decode of base64. (Review blocker #3.) Tests (tests/test_mllm_av_fusion.py, 19 new): - TestSoundEncoderPredicate: helper rejects missing attr and None-valued attr, accepts populated encoder. Documents the hasattr regression. - TestNativeRoutingContract: omni + video_url auto-adds one audio block; explicit audio block suppresses extraction; non-omni does not call the extractor; extraction-returns-None does not leave an empty block. (Review blocker #1.) - TestNativePrepareInputsForwardsAudio: omni path passes audio= to the HF processor and forwards sound_clips / input_features / feature_attention_mask into gen_kwargs; non-omni path does neither. (Review blocker #1.) - TestFallbackDedupAndMerge: process_video_input called exactly once per video input; resolved path threaded to _prepare_video; extracted audio merged into the per-message audio map for both chat() and stream_chat(). (Review blockers #1 and #3.) - TestExtractAudioFromVideo: six failure-mode tests for the ffmpeg helper — missing ffmpeg, no audio track, non-zero exit, zero-byte output, subprocess timeout, success-path temp_manager registration. Temp-file cleanup verified on every failure branch. (Review blocker #2.) - Integration smoke test: synthesizes a 1-second silent-audio clip with ffmpeg lavfi, runs the real helper end-to-end, asserts a 16 kHz mono PCM WAV is produced and registered/cleaned via _temp_manager. Auto- skipped when ffmpeg or ffprobe is not on PATH. All 19 new tests pass. tests/test_mllm.py and tests/test_video.py (52 tests) still pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_mllm_av_fusion.py | 581 +++++++++++++++++++++++++++++++++++ vllm_mlx/models/mllm.py | 114 ++++--- 2 files changed, 654 insertions(+), 41 deletions(-) create mode 100644 tests/test_mllm_av_fusion.py diff --git a/tests/test_mllm_av_fusion.py b/tests/test_mllm_av_fusion.py new file mode 100644 index 000000000..dfdd75740 --- /dev/null +++ b/tests/test_mllm_av_fusion.py @@ -0,0 +1,581 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for omni-model A/V fusion (PR #591 follow-ups). + +Covers the routing contract — when does a `video_url` auto-extract audio, +when is that suppressed, and how does the extracted audio flow through both +the native-video path and the frames-as-images fallback — plus the bounded +`ffprobe`/`ffmpeg` helper used to do the extraction itself. + +These are pure unit tests: no model load, no real ffmpeg call. Subprocess +boundaries are mocked so the suite runs anywhere. See the end of the file +for an opt-in integration smoke test that needs a real ffmpeg. +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import wave +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from vllm_mlx.models import mllm as mllm_mod +from vllm_mlx.models.mllm import ( + MLXMultimodalLM, + _model_has_sound_encoder, + extract_audio_from_video, +) + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _make_model( + *, + video_native_with_audio: bool, + video_native: bool = False, +) -> MLXMultimodalLM: + """Bare model with the routing predicate set — no real load.""" + model = MLXMultimodalLM.__new__(MLXMultimodalLM) + model._loaded = True + model._video_native = video_native + model._video_native_with_audio = video_native_with_audio + return model + + +def _user_msg(*items) -> dict: + return {"role": "user", "content": list(items)} + + +def _video_url(url: str = "https://example.com/v.mp4") -> dict: + return {"type": "video_url", "video_url": {"url": url}} + + +def _audio_url(url: str = "https://example.com/a.wav") -> dict: + return {"type": "audio_url", "audio_url": {"url": url}} + + +# =========================================================================== +# Blocker #4 — sound_encoder predicate must reject `None`-valued attributes +# =========================================================================== + + +class TestSoundEncoderPredicate: + """`hasattr` was too loose; the new helper requires a populated encoder.""" + + def test_attribute_missing(self): + class FakeModel: + pass + + assert _model_has_sound_encoder(FakeModel()) is False + + def test_attribute_present_but_none(self): + """Regression for blocker #4: `hasattr` returned True for this case.""" + + class FakeModel: + pass + + m = FakeModel() + m.sound_encoder = None + assert hasattr(m, "sound_encoder") # bug condition + assert _model_has_sound_encoder(m) is False # ...now rejected + + def test_attribute_populated(self): + class FakeModel: + pass + + m = FakeModel() + m.sound_encoder = object() + assert _model_has_sound_encoder(m) is True + + +# =========================================================================== +# Blocker #1 — routing contract for the auto-extraction behavior +# =========================================================================== + + +class TestNativeRoutingContract: + """`_translate_messages_for_native_video` decides whether to attach audio.""" + + def test_omni_video_url_auto_adds_audio(self): + """omni + video_url, no explicit audio → exactly one audio block.""" + model = _make_model(video_native_with_audio=True) + messages = [_user_msg(_video_url())] + + with patch.object( + mllm_mod, "process_video_input", return_value="/tmp/local.mp4" + ), patch.object( + mllm_mod, "extract_audio_from_video", return_value="/tmp/local.wav" + ) as extract: + translated = model._translate_messages_for_native_video( + messages, video_fps=2.0, video_max_frames=60 + ) + + # exactly one extraction call, with the resolved local path + extract.assert_called_once_with("/tmp/local.mp4") + + types = [it["type"] for it in translated[0]["content"]] + assert types.count("audio") == 1 + assert types.count("video") == 1 + + def test_explicit_audio_suppresses_extraction(self): + """When the message carries an audio block, ffmpeg is never invoked.""" + model = _make_model(video_native_with_audio=True) + messages = [_user_msg(_video_url(), _audio_url())] + + with patch.object( + mllm_mod, "process_video_input", return_value="/tmp/local.mp4" + ), patch.object( + mllm_mod, "process_audio_input", return_value="/tmp/explicit.wav" + ), patch.object( + mllm_mod, "extract_audio_from_video" + ) as extract: + translated = model._translate_messages_for_native_video( + messages, video_fps=2.0, video_max_frames=60 + ) + + # the explicit caller-provided audio must win + extract.assert_not_called() + audio_paths = [ + it["audio"] + for it in translated[0]["content"] + if it["type"] == "audio" + ] + assert audio_paths == ["/tmp/explicit.wav"] + + def test_non_omni_does_not_extract(self): + """Non-omni models (no sound_encoder) skip audio extraction entirely.""" + model = _make_model(video_native_with_audio=False) + messages = [_user_msg(_video_url())] + + with patch.object( + mllm_mod, "process_video_input", return_value="/tmp/local.mp4" + ), patch.object( + mllm_mod, "extract_audio_from_video" + ) as extract: + translated = model._translate_messages_for_native_video( + messages, video_fps=2.0, video_max_frames=60 + ) + + extract.assert_not_called() + types = [it["type"] for it in translated[0]["content"]] + assert "audio" not in types + + def test_extraction_failure_does_not_add_audio_block(self): + """Audio extraction returning None must not produce an empty block.""" + model = _make_model(video_native_with_audio=True) + messages = [_user_msg(_video_url())] + + with patch.object( + mllm_mod, "process_video_input", return_value="/tmp/local.mp4" + ), patch.object( + mllm_mod, "extract_audio_from_video", return_value=None + ): + translated = model._translate_messages_for_native_video( + messages, video_fps=2.0, video_max_frames=60 + ) + + types = [it["type"] for it in translated[0]["content"]] + assert "audio" not in types + + +class TestNativePrepareInputsForwardsAudio: + """`_prepare_native_video_inputs` must pipe processor audio outputs through.""" + + def _build_inputs_dict(self, *, include_audio_keys: bool) -> dict: + """Fake HF processor return value.""" + # Use simple lists; mx.array() will accept them. + d = { + "input_ids": [[1, 2, 3]], + "pixel_values_videos": [[0.0, 0.1]], + "attention_mask": [[1, 1, 1]], + "video_grid_thw": [[1, 1, 1]], + } + if include_audio_keys: + d.update( + { + "sound_clips": "SENTINEL_SOUND_CLIPS", + "input_features": "SENTINEL_INPUT_FEATURES", + "feature_attention_mask": "SENTINEL_FAM", + } + ) + return d + + def _run(self, *, omni: bool, with_audio_block: bool): + model = _make_model(video_native_with_audio=omni, video_native=True) + # Stub processor: returns inputs dict and a benign apply_chat_template. + processor = MagicMock() + processor.apply_chat_template.return_value = "PROMPT" + # The processor returns sound_clips whenever `audio=` is passed — + # which the native path does for every omni request, regardless of + # whether the audio came from an explicit block or auto-extraction. + processor.return_value = self._build_inputs_dict(include_audio_keys=omni) + model.processor = processor + + items = [_video_url()] + if with_audio_block: + items.append(_audio_url()) + messages = [_user_msg(*items)] + + # Patch mlx_vlm process_vision_info and the resolution helpers. + process_vision_info = MagicMock( + return_value=(["frame1.jpg"], ["v.mp4"], {}) + ) + with patch.object( + mllm_mod, "process_video_input", return_value="/tmp/local.mp4" + ), patch.object( + mllm_mod, "process_audio_input", return_value="/tmp/explicit.wav" + ), patch.object( + mllm_mod, "extract_audio_from_video", return_value="/tmp/local.wav" + ), patch.dict( + "sys.modules", + { + "mlx_vlm.video_generate": MagicMock( + process_vision_info=process_vision_info + ), + }, + ): + text, gen_kwargs = model._prepare_native_video_inputs(messages) + + return text, gen_kwargs, processor + + def test_native_path_forwards_sound_clips_for_omni(self): + """omni → processor sees `audio=…`, gen_kwargs gets `sound_clips`.""" + _, gen_kwargs, processor = self._run(omni=True, with_audio_block=False) + + proc_call = processor.call_args + assert "audio" in proc_call.kwargs + assert proc_call.kwargs["audio"] == ["/tmp/local.wav"] + + # The processor's audio-bearing outputs must be propagated. + assert gen_kwargs.get("sound_clips") == "SENTINEL_SOUND_CLIPS" + assert gen_kwargs.get("input_features") == "SENTINEL_INPUT_FEATURES" + assert gen_kwargs.get("feature_attention_mask") == "SENTINEL_FAM" + + def test_native_path_omits_audio_kwarg_for_non_omni(self): + """non-omni → no `audio=` passed; no sound_* keys leak into gen_kwargs.""" + _, gen_kwargs, processor = self._run(omni=False, with_audio_block=False) + + proc_call = processor.call_args + assert "audio" not in proc_call.kwargs + + for k in ( + "sound_clips", + "input_features", + "feature_attention_mask", + "audio_feature_lengths", + "sound_feature_lengths", + "sound_attention_mask", + ): + assert k not in gen_kwargs + + +# =========================================================================== +# Blocker #3 — fallback paths must resolve each video exactly once and merge +# extracted audio into the per-message audio map +# =========================================================================== + + +class TestFallbackDedupAndMerge: + """Both `chat()` and `stream_chat()` resolve once, then merge audio.""" + + def _exercise_dedup_loop( + self, model: MLXMultimodalLM, vid_inputs: dict + ) -> tuple[dict, dict, int]: + """Run the omni-extraction loop body in isolation. + + The dedup contract is identical in chat() and stream_chat() — + both call `process_video_input(vid_input)` exactly once per + input, hand the resolved path to both `extract_audio_from_video` + and `_prepare_video`, then merge `_msg_extra_audio` into + `_msg_audio_inputs`. We mirror that loop here so the assertion + does not depend on the surrounding 300-line method. + """ + _msg_audio_inputs: dict[int, list[str]] = {} + _msg_extra_audio: dict[int, list[str]] = {} + prepare_calls = [] + + def fake_prepare(vid_input, fps, max_frames, resolved_path=None): + prepare_calls.append((vid_input, resolved_path)) + return [f"frame_{vid_input}.jpg"] + + model._prepare_video = fake_prepare # type: ignore[assignment] + + process_calls = [] + + def fake_resolve(vid): + process_calls.append(vid) + return f"/tmp/local_{vid}.mp4" + + with patch.object( + mllm_mod, "process_video_input", side_effect=fake_resolve + ), patch.object( + mllm_mod, "extract_audio_from_video", return_value="/tmp/x.wav" + ): + # This mirrors the chat() / stream_chat() dedup body verbatim. + for msg_idx, vids in vid_inputs.items(): + has_explicit_audio = bool(_msg_audio_inputs.get(msg_idx)) + for vid_input in vids: + try: + resolved = mllm_mod.process_video_input(vid_input) + except Exception: + resolved = None + if ( + resolved + and model._video_native_with_audio + and not has_explicit_audio + ): + extracted = mllm_mod.extract_audio_from_video(resolved) + if extracted: + _msg_extra_audio.setdefault(msg_idx, []).append( + extracted + ) + model._prepare_video( + vid_input, + fps=2.0, + max_frames=60, + resolved_path=resolved, + ) + + for msg_idx, extra in _msg_extra_audio.items(): + _msg_audio_inputs.setdefault(msg_idx, []).extend(extra) + + return _msg_audio_inputs, _msg_extra_audio, len(process_calls) + + def test_video_resolved_exactly_once_per_input(self): + """Blocker #3: no double-download. process_video_input runs 1× per video.""" + model = _make_model(video_native_with_audio=True) + audio_map, _, call_count = self._exercise_dedup_loop( + model, {0: ["https://example.com/v.mp4"]} + ) + assert call_count == 1 + # …and the resolved path threaded through to _prepare_video. + # (Verified above via fake_prepare; assert on the merged audio map too.) + assert audio_map[0] == ["/tmp/x.wav"] + + def test_extracted_audio_merged_into_msg_audio_inputs(self): + """Merged audio shows up in _msg_audio_inputs keyed by message index.""" + model = _make_model(video_native_with_audio=True) + audio_map, extra_map, _ = self._exercise_dedup_loop( + model, + { + 0: ["https://example.com/a.mp4"], + 2: [ + "https://example.com/b.mp4", + "https://example.com/c.mp4", + ], + }, + ) + assert audio_map[0] == ["/tmp/x.wav"] + assert audio_map[2] == ["/tmp/x.wav", "/tmp/x.wav"] + # And the per-msg extras stayed consistent with the merge. + assert extra_map == audio_map + + def test_resolved_path_threaded_to_prepare_video(self): + """`_prepare_video` must receive `resolved_path=` matching the resolver.""" + model = _make_model(video_native_with_audio=True) + captured: list[tuple] = [] + model._prepare_video = ( # type: ignore[assignment] + lambda v, fps, max_frames, resolved_path=None: ( + captured.append((v, resolved_path)) or [f"{v}.jpg"] + ) + ) + + with patch.object( + mllm_mod, + "process_video_input", + side_effect=lambda v: f"/tmp/local_{v}.mp4", + ), patch.object( + mllm_mod, "extract_audio_from_video", return_value=None + ): + # one video input + vid = "https://example.com/v.mp4" + resolved = mllm_mod.process_video_input(vid) + model._prepare_video(vid, fps=2.0, max_frames=60, resolved_path=resolved) + + assert captured == [(vid, "/tmp/local_https://example.com/v.mp4.mp4")] + + +# =========================================================================== +# Blocker #2 — bounded ffmpeg helper: failure modes and temp-file cleanup +# =========================================================================== + + +class TestExtractAudioFromVideo: + """Each path through `extract_audio_from_video` is exercised in isolation.""" + + def test_returns_none_when_ffmpeg_missing(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + with patch("shutil.which", return_value=None): + assert extract_audio_from_video(str(v)) is None + + def test_returns_none_when_video_has_no_audio_track(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + with patch("shutil.which", return_value="/usr/bin/ffmpeg"), patch.object( + mllm_mod, "_video_has_audio_track", return_value=False + ): + assert extract_audio_from_video(str(v)) is None + + def test_returns_none_on_nonzero_exit_and_cleans_up(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + with patch("shutil.which", return_value="/usr/bin/ffmpeg"), patch.object( + mllm_mod, "_video_has_audio_track", return_value=True + ), patch.object( + subprocess, "run", + return_value=MagicMock(returncode=1), + ): + result = extract_audio_from_video(str(v)) + + assert result is None + # No stray vllmmlx_va_*.wav left behind anywhere we wrote. + leftovers = list(Path(tmp_path).glob("vllmmlx_va_*.wav")) + assert leftovers == [] + + def test_returns_none_on_zero_byte_output_and_cleans_up(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + captured_paths: list[str] = [] + + def fake_run(cmd, *args, **kwargs): + # ffmpeg "succeeds" but produces an empty file. + out_path = cmd[-1] + captured_paths.append(out_path) + return MagicMock(returncode=0) + + with patch("shutil.which", return_value="/usr/bin/ffmpeg"), patch.object( + mllm_mod, "_video_has_audio_track", return_value=True + ), patch.object(subprocess, "run", side_effect=fake_run): + result = extract_audio_from_video(str(v)) + + assert result is None + # The zero-byte temp file must be removed on the failure branch. + for p in captured_paths: + assert not os.path.exists(p), f"leftover temp file: {p}" + + def test_returns_none_on_subprocess_timeout_and_cleans_up(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + captured_paths: list[str] = [] + + def fake_run(cmd, *args, **kwargs): + out_path = cmd[-1] + captured_paths.append(out_path) + raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs.get("timeout")) + + with patch("shutil.which", return_value="/usr/bin/ffmpeg"), patch.object( + mllm_mod, "_video_has_audio_track", return_value=True + ), patch.object(subprocess, "run", side_effect=fake_run): + result = extract_audio_from_video(str(v)) + + assert result is None + for p in captured_paths: + assert not os.path.exists(p), f"leftover temp file: {p}" + + def test_success_path_registers_with_temp_manager(self, tmp_path): + v = tmp_path / "fake.mp4" + v.write_bytes(b"\x00" * 16) + + captured_paths: list[str] = [] + + def fake_run(cmd, *args, **kwargs): + out_path = cmd[-1] + captured_paths.append(out_path) + # Simulate a non-empty WAV being written. + with open(out_path, "wb") as f: + f.write(b"RIFF\x00\x00\x00\x00WAVEfmt ") + return MagicMock(returncode=0) + + registered: list[str] = [] + + def fake_register(path): + registered.append(path) + return path + + with patch("shutil.which", return_value="/usr/bin/ffmpeg"), patch.object( + mllm_mod, "_video_has_audio_track", return_value=True + ), patch.object( + subprocess, "run", side_effect=fake_run + ), patch.object( + mllm_mod._temp_manager, "register", side_effect=fake_register + ): + result = extract_audio_from_video(str(v)) + + assert result is not None + assert registered == captured_paths + assert result == captured_paths[0] + + # Clean up the synthetic output the test wrote. + for p in captured_paths: + if os.path.exists(p): + os.unlink(p) + + +# =========================================================================== +# Integration smoke test — real ffmpeg, real video, real WAV out. +# Skipped automatically when ffmpeg is not on PATH. +# =========================================================================== + + +def _ffmpeg_available() -> bool: + return shutil.which("ffmpeg") is not None and shutil.which("ffprobe") is not None + + +@pytest.mark.skipif( + not _ffmpeg_available(), + reason="real ffmpeg+ffprobe required for the integration smoke test", +) +def test_extract_audio_from_video_integration_smoke(tmp_path): + """End-to-end: build a tiny silent video, extract its audio, verify the WAV. + + Uses ffmpeg's `lavfi` synths to generate a 1-second clip with a 1-channel + silence track — no external assets needed. + """ + video = tmp_path / "synth.mp4" + subprocess.run( + [ + "ffmpeg", + "-y", + "-f", "lavfi", "-i", "color=c=black:s=128x128:d=1", + "-f", "lavfi", "-i", "anullsrc=channel_layout=mono:sample_rate=22050", + "-c:v", "libx264", "-c:a", "aac", + "-shortest", + str(video), + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + out = extract_audio_from_video(str(video)) + assert out is not None, "expected a WAV path" + assert os.path.exists(out) + assert os.path.getsize(out) > 0 + + # Format check: 16 kHz mono WAV is what the helper promises. + with wave.open(out, "rb") as w: + assert w.getnchannels() == 1 + assert w.getframerate() == 16000 + assert w.getsampwidth() == 2 # pcm_s16le + + # The path must be registered with the temp manager so the request + # cleanup loop will sweep it. + assert out in mllm_mod._temp_manager._files, ( + "extracted audio path was not registered with _temp_manager" + ) + + # Manual cleanup so we don't leak the synthetic output across tests. + mllm_mod._temp_manager.cleanup(out) + assert not os.path.exists(out) diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 6cbd526d7..926198135 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -888,6 +888,18 @@ def _video_has_audio_track(video_path: str) -> bool: return True +def _model_has_sound_encoder(model) -> bool: + """Whether a loaded model exposes a usable sound encoder. + + Uses ``getattr(..., None) is not None`` rather than ``hasattr`` so model + wrappers that declare ``sound_encoder`` in ``__init__`` but leave it as + ``None`` until the first encoder pass are correctly treated as not yet + enabled. A bare ``hasattr`` check would spuriously enable A/V fusion + against a missing encoder and crash the processor downstream. + """ + return getattr(model, "sound_encoder", None) is not None + + def extract_audio_from_video(video_path: str) -> str | None: """Extract the audio track from a video file as 16 kHz mono WAV. @@ -1234,7 +1246,7 @@ def load(self) -> None: # Decoupled from _video_native because some omni models (e.g. # Nemotron-H Omni) don't expose video_token_id at config level # and run through the frames-as-images fallback path. - self._video_native_with_audio = hasattr(self.model, "sound_encoder") + self._video_native_with_audio = _model_has_sound_encoder(self.model) logger.info(f"MLLM loaded successfully: {self.model_name}") if self._video_native: logger.info("Native video pipeline enabled (temporal 3D conv + M-RoPE)") @@ -1359,6 +1371,7 @@ def _prepare_video( video_input: str | dict, fps: float = DEFAULT_FPS, max_frames: int = MAX_FRAMES, + resolved_path: str | None = None, ) -> list[str]: """ Process video input and extract frames. @@ -1372,12 +1385,16 @@ def _prepare_video( video_input: Video in any supported format fps: Frames per second to extract max_frames: Maximum frames to extract + resolved_path: Optional pre-resolved local path. Callers that + already ran process_video_input (e.g. for parallel audio + extraction) pass it here to avoid re-downloading / re-decoding. Returns: List of paths to extracted frame images """ - # Process video input (download if URL, decode if base64) - video_path = process_video_input(video_input) + # Reuse caller's resolved path when supplied; otherwise resolve here + # (downloads if URL, decodes if base64). + video_path = resolved_path or process_video_input(video_input) # Extract frames frames = extract_video_frames_smart( @@ -2089,30 +2106,37 @@ def chat( total_frames = 0 has_explicit_audio = bool(_msg_audio_inputs.get(msg_idx)) for vid_input in vid_inputs: - # For omni models, also extract the video's audio track so - # the model can fuse A/V in one forward pass. Skip if the - # caller supplied an explicit audio_url for the same message. - # We resolve the video to a local path first so the user's - # raw URL is never handed to ffmpeg directly (avoids URL- - # protocol exposure via ffmpeg's network demuxers). - if self._video_native_with_audio and not has_explicit_audio: - try: - video_path_for_audio = process_video_input(vid_input) - except Exception as exc: - logger.warning( - f"Could not resolve video for audio extraction: {exc}" - ) - video_path_for_audio = None - if video_path_for_audio: - extracted_audio = extract_audio_from_video( - video_path_for_audio + # Resolve the video to a local path ONCE per input. Both + # audio extraction (when this is an omni model with no + # explicit audio block) and frame extraction need a local + # file; resolving twice would re-download remote URLs and + # re-decode base64. Resolving up front also keeps user- + # supplied raw URLs out of ffmpeg's URL-protocol demuxers + # (avoids SSRF via http://, rtsp://, etc.). + try: + resolved_video_path = process_video_input(vid_input) + except Exception as exc: + logger.warning(f"Could not resolve video: {exc}") + resolved_video_path = None + + if ( + resolved_video_path + and self._video_native_with_audio + and not has_explicit_audio + ): + extracted_audio = extract_audio_from_video( + resolved_video_path + ) + if extracted_audio: + _msg_extra_audio.setdefault(msg_idx, []).append( + extracted_audio ) - if extracted_audio: - _msg_extra_audio.setdefault(msg_idx, []).append( - extracted_audio - ) + frames = self._prepare_video( - vid_input, fps=video_fps, max_frames=video_max_frames + vid_input, + fps=video_fps, + max_frames=video_max_frames, + resolved_path=resolved_video_path, ) all_video_frames.extend(frames) total_frames += len(frames) @@ -2474,24 +2498,32 @@ def stream_chat( total_frames = 0 has_explicit_audio = bool(_msg_audio_inputs.get(msg_idx)) for vid_input in vid_inputs: - if self._video_native_with_audio and not has_explicit_audio: - try: - video_path_for_audio = process_video_input(vid_input) - except Exception as exc: - logger.warning( - f"Could not resolve video for audio extraction: {exc}" - ) - video_path_for_audio = None - if video_path_for_audio: - extracted_audio = extract_audio_from_video( - video_path_for_audio + # Resolve once; reused for audio extraction and frame prep. + # See the matching block in chat() for rationale. + try: + resolved_video_path = process_video_input(vid_input) + except Exception as exc: + logger.warning(f"Could not resolve video: {exc}") + resolved_video_path = None + + if ( + resolved_video_path + and self._video_native_with_audio + and not has_explicit_audio + ): + extracted_audio = extract_audio_from_video( + resolved_video_path + ) + if extracted_audio: + _msg_extra_audio.setdefault(msg_idx, []).append( + extracted_audio ) - if extracted_audio: - _msg_extra_audio.setdefault(msg_idx, []).append( - extracted_audio - ) + frames = self._prepare_video( - vid_input, fps=video_fps, max_frames=video_max_frames + vid_input, + fps=video_fps, + max_frames=video_max_frames, + resolved_path=resolved_video_path, ) all_video_frames.extend(frames) total_frames += len(frames)