Skip to content

fix: add defensive fallback for Qwen VL mrope get_rope_index shape mismatch#10589

Open
Ottohere-Mourn wants to merge 2 commits into
hiyouga:mainfrom
Ottohere-Mourn:fix/qwen3-video-mrope-fallback
Open

fix: add defensive fallback for Qwen VL mrope get_rope_index shape mismatch#10589
Ottohere-Mourn wants to merge 2 commits into
hiyouga:mainfrom
Ottohere-Mourn:fix/qwen3-video-mrope-fallback

Conversation

@Ottohere-Mourn

@Ottohere-Mourn Ottohere-Mourn commented Jun 17, 2026

Copy link
Copy Markdown

Description

When training Qwen VL models with mixed multimodal data, get_rope_index can raise:

RuntimeError: "value tensor of shape [3, A] cannot be broadcast to indexing result of shape [3, B]"

This is not a frame-count issue — it is caused by missing media files triggering inconsistent fallback paths.

Root cause chain

  1. Some samples reference npy/video files that are missing from the dataset directory
  2. The custom video loader catches the load error and produces a fallback volume (e.g. a fixed-shape all-black array)
  3. Tokenize stage processes this fallback through custom grid metadata logic, producing a video_grid_thw with specific placeholder token count (e.g. 462 or 504)
  4. Forward stage re-loads the same missing file, also fallback, but computes video_grid_thw via transformers' native video_processor — a different code path → token count diverges
  5. Placeholder count in input_ids ≠ actual visual feature count → shape mismatch in get_rope_index

In a 9.16M-sample run, only 19 unique IDs triggered this (105 rows, 0.0011%), all referencing missing files in M3D_CT/MedMNIST.

Fix

  1. try/except around get_rope_func (qwen vl branch), catching only the known shape mismatch
  2. _fallback_rope_position_ids: flat 3-axis positional encoding, fully vectorized
  3. _log_rope_mismatch: one-line diagnostic for downstream data cleaning

Validation

  • 9.16M-sample Qwen3.5-27B run: 0 false positives, step 279+ stable
  • 105 affected samples identified via diagnostic log and cleaned at data level

…smatch

When video samples have inconsistent grid metadata between the tokenizer
and the forward pass (e.g. due to odd-frame padding logic in
_get_qwen_video_grid_metadata vs. the native video_processor), the
placeholder token count in input_ids disagrees with the feature count,
causing a RuntimeError in get_rope_index:

  "value tensor of shape [3, A] cannot be broadcast to indexing result
   of shape [3, B]"

This affects a tiny fraction of samples (~0.001% in a 9.16M sample
mix) and is normally invisible in smoke tests with small max_samples.

The fix catches this specific shape mismatch at the collator level,
degrades to a flat 3-axis positional encoding for the offending batch
(allowing training to continue without interruption), and prints a
one-line diagnostic so the user can later locate and clean those
samples at the data level.

Verified on a 9.16M-sample multimodal training run. No false positives
detected.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a defensive fallback mechanism in the data collator for Qwen VL models to handle mrope shape mismatch errors, preventing training runs from crashing. The review feedback highlights a performance bottleneck in the fallback method due to GPU-CPU synchronization and loop overhead, and provides a vectorized PyTorch implementation to resolve it.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread src/llamafactory/data/collator.py Outdated
Comment on lines +221 to +231
def _fallback_rope_position_ids(input_ids: "torch.Tensor", attention_mask: "torch.Tensor"):
r"""Flat positional fallback: every non-pad token gets a monotonic position; all 3 mrope axes identical."""
bsz, seq_len = input_ids.shape
position_ids = torch.zeros(3, bsz, seq_len, dtype=input_ids.dtype, device=input_ids.device)
rope_deltas = torch.zeros(bsz, dtype=input_ids.dtype, device=input_ids.device)
for batch_idx in range(bsz):
valid_len = int(attention_mask[batch_idx].bool().sum().item())
positions = torch.arange(seq_len, device=input_ids.device)
position_ids[:, batch_idx] = positions.view(1, -1).expand(3, -1)
rope_deltas[batch_idx] = positions.max() + 1 - valid_len if valid_len > 0 else 0
return position_ids, rope_deltas

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Performance Bottleneck: GPU-CPU Synchronization & Loop Overhead

In the current implementation of _fallback_rope_position_ids, there are two major performance issues that can slow down training throughput:

  1. GPU-CPU Synchronization: Calling .item() on a GPU tensor (inside attention_mask[batch_idx].bool().sum().item()) forces the CPU to block and wait for the GPU to finish its computation. Since this is executed inside a loop for every batch in the data collator, it can severely bottleneck the training pipeline and lower GPU utilization.
  2. Redundant Tensor Creation & Loop Overhead: Re-creating the positions tensor bsz times inside a Python loop is inefficient.

Solution

We can completely vectorize this method to run loop-free and avoid any GPU-CPU synchronization by leveraging PyTorch's native tensor operations (expand, torch.where, and sum along dimensions).

    @staticmethod
    def _fallback_rope_position_ids(input_ids: "torch.Tensor", attention_mask: "torch.Tensor"):
        r"""Flat positional fallback: every non-pad token gets a monotonic position; all 3 mrope axes identical."""
        bsz, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device, dtype=input_ids.dtype)
        position_ids = positions.view(1, 1, seq_len).expand(3, bsz, seq_len).contiguous()
        valid_len = attention_mask.bool().sum(dim=-1)
        rope_deltas = torch.where(valid_len > 0, seq_len - valid_len, torch.zeros_like(valid_len))
        return position_ids, rope_deltas.to(dtype=input_ids.dtype)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, applied the vectorized version in 2e67cbb. The Python loop and .item() call have been removed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant