fix: add defensive fallback for Qwen VL mrope get_rope_index shape mismatch #10589
Ottohere-Mourn wants to merge 2 commits into
…smatch
When video samples have inconsistent grid metadata between the tokenizer and the forward pass (e.g. due to odd-frame padding logic in _get_qwen_video_grid_metadata vs. the native video_processor), the placeholder token count in input_ids disagrees with the feature count, causing a RuntimeError in get_rope_index:
"value tensor of shape [3, A] cannot be broadcast to indexing result of shape [3, B]"
This affects a tiny fraction of samples (~0.001% in a 9.16M sample mix) and is normally invisible in smoke tests with small max_samples.
The fix catches this specific shape mismatch at the collator level, degrades to a flat 3-axis positional encoding for the offending batch (allowing training to continue without interruption), and prints a one-line diagnostic so the user can later locate and clean those samples at the data level.
Verified on a 9.16M-sample multimodal training run. No false positives detected.
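A minimal sketch of the guard pattern described in the commit message, for readers who want to see its shape without opening the diff. This is not the PR's actual code: `guarded_rope_index`, `rope_fn` and `fallback_fn` are hypothetical names, and the matched substring is taken from the error text quoted above.

```python
from typing import Any, Callable, Tuple

# Substring of the RuntimeError quoted in the commit message (assumed stable).
_KNOWN_MISMATCH = "cannot be broadcast to indexing result of shape"


def guarded_rope_index(
    rope_fn: Callable[..., Tuple[Any, Any]],
    fallback_fn: Callable[..., Tuple[Any, Any]],
    **features: Any,
) -> Tuple[Any, Any]:
    """Call rope_fn; on the known shape mismatch only, warn and use fallback_fn."""
    try:
        return rope_fn(**features)
    except RuntimeError as err:
        if _KNOWN_MISMATCH not in str(err):
            raise  # unrelated failure: do not mask it
        # One-line diagnostic so the offending batch can be traced later.
        print(f"[mrope fallback] get_rope_index shape mismatch, using flat positions: {err}")
        return fallback_fn(features["input_ids"], features["attention_mask"])
```

Matching on the message rather than catching every `RuntimeError` is what keeps unrelated failures (for example out-of-memory errors) visible.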
Code Review
This pull request introduces a defensive fallback mechanism in the data collator for Qwen VL models to handle mrope shape mismatch errors, preventing training runs from crashing. The review feedback highlights a performance bottleneck in the fallback method due to GPU-CPU synchronization and loop overhead, and provides a vectorized PyTorch implementation to resolve it.
def _fallback_rope_position_ids(input_ids: "torch.Tensor", attention_mask: "torch.Tensor"):
    r"""Flat positional fallback: every non-pad token gets a monotonic position; all 3 mrope axes identical."""
    bsz, seq_len = input_ids.shape
    position_ids = torch.zeros(3, bsz, seq_len, dtype=input_ids.dtype, device=input_ids.device)
    rope_deltas = torch.zeros(bsz, dtype=input_ids.dtype, device=input_ids.device)
    for batch_idx in range(bsz):
        valid_len = int(attention_mask[batch_idx].bool().sum().item())
        positions = torch.arange(seq_len, device=input_ids.device)
        position_ids[:, batch_idx] = positions.view(1, -1).expand(3, -1)
        rope_deltas[batch_idx] = positions.max() + 1 - valid_len if valid_len > 0 else 0
    return position_ids, rope_deltas
Performance Bottleneck: GPU-CPU Synchronization & Loop Overhead
In the current implementation of _fallback_rope_position_ids, there are two major performance issues that can slow down training throughput:
- GPU-CPU Synchronization: Calling `.item()` on a GPU tensor (inside `attention_mask[batch_idx].bool().sum().item()`) forces the CPU to block and wait for the GPU to finish its computation. Since this is executed inside a loop for every sample in the batch in the data collator, it can severely bottleneck the training pipeline and lower GPU utilization.
- Redundant Tensor Creation & Loop Overhead: Re-creating the `positions` tensor `bsz` times inside a Python loop is inefficient.
Solution
We can completely vectorize this method to run loop-free and avoid any GPU-CPU synchronization by leveraging PyTorch's native tensor operations (expand, torch.where, and sum along dimensions).
@staticmethod
def _fallback_rope_position_ids(input_ids: "torch.Tensor", attention_mask: "torch.Tensor"):
    r"""Flat positional fallback: every non-pad token gets a monotonic position; all 3 mrope axes identical."""
    bsz, seq_len = input_ids.shape
    positions = torch.arange(seq_len, device=input_ids.device, dtype=input_ids.dtype)
    position_ids = positions.view(1, 1, seq_len).expand(3, bsz, seq_len).contiguous()
    valid_len = attention_mask.bool().sum(dim=-1)
    rope_deltas = torch.where(valid_len > 0, seq_len - valid_len, torch.zeros_like(valid_len))
    return position_ids, rope_deltas.to(dtype=input_ids.dtype)
Thanks, applied the vectorized version in 2e67cbb. The Python loop and .item() call have been removed.
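For anyone checking the vectorized version by hand, here is a plain-list sketch of what it appears to compute. `fallback_reference` is a hypothetical helper, not part of the PR. It uses nested lists instead of tensors, so it only illustrates the arithmetic: the same 0..seq_len-1 ramp on all three axes over the full padded length, and `rope_deltas = seq_len - valid_len` for rows with at least one valid token.

```python
from typing import List, Tuple


def fallback_reference(attention_mask: List[List[int]]) -> Tuple[List[List[List[int]]], List[int]]:
    """Plain-list reference of the flat fallback (illustration only)."""
    bsz = len(attention_mask)
    seq_len = len(attention_mask[0]) if bsz else 0
    ramp = list(range(seq_len))
    # position_ids[axis][batch][token]: the same ramp for every axis and sample
    position_ids = [[list(ramp) for _ in range(bsz)] for _ in range(3)]
    rope_deltas = []
    for mask in attention_mask:
        valid_len = sum(1 for m in mask if m)
        # Mirrors torch.where(valid_len > 0, seq_len - valid_len, 0)
        rope_deltas.append(seq_len - valid_len if valid_len > 0 else 0)
    return position_ids, rope_deltas


pos, deltas = fallback_reference([[1, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]])
print(deltas)  # [1, 0, 0]
```

The all-padding row is the case the `valid_len > 0` branch exists for: it yields a delta of 0 rather than `seq_len`.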
Description
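When training Qwen VL models with mixed multimodal data, `get_rope_index` can raise the RuntimeError quoted in the commit message ("value tensor of shape [3, A] cannot be broadcast to indexing result of shape [3, B]").
This is not a frame-count issue; it is caused by missing media files triggering inconsistent fallback paths.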
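Since the trigger is missing media files, a pre-flight scan of the dataset can surface the offending rows before training. The sketch below is an assumption-laden illustration, not part of this PR: `find_missing_media` is a hypothetical helper, and the sample schema (an `id` field plus `images` / `videos` lists of file paths) is assumed and may not match your data format.

```python
import os
from typing import Dict, Iterable, List


def find_missing_media(samples: Iterable[Dict], media_keys=("images", "videos")) -> List[Dict]:
    """Return one record per sample that references a media path not on disk."""
    report = []
    for sample in samples:
        missing = [
            path
            for key in media_keys
            for path in sample.get(key) or []  # tolerate absent or null fields
            if not os.path.isfile(path)
        ]
        if missing:
            report.append({"id": sample.get("id"), "missing": missing})
    return report
```

Running this once over the data mix and dropping the reported IDs addresses the problem at the data level, which the fallback itself does not do.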
Root cause chain
- … `video_grid_thw` with specific placeholder token count (e.g. 462 or 504)
- … `video_grid_thw` via transformers' native `video_processor`, a different code path → token count diverges
- … `input_ids` ≠ actual visual feature count → shape mismatch in `get_rope_index`
In a 9.16M-sample run, only 19 unique IDs triggered this (105 rows, 0.0011%), all referencing missing files in M3D_CT/MedMNIST.
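The divergence in the chain above can be checked per sample by comparing the placeholder tokens in `input_ids` with the count implied by the grid. This is a rough sketch under stated assumptions: it assumes each media item expands to `t * h * w / merge_size**2` placeholder tokens with `merge_size = 2`, which is my understanding of Qwen2-VL-style processors and should be verified for your model. `expected_placeholders`, `check_sample` and the placeholder id are hypothetical.

```python
from typing import Sequence, Tuple


def expected_placeholders(grid_thw: Sequence[int], merge_size: int = 2) -> int:
    """Tokens one media item should occupy: t * h * w // merge_size**2 (assumed rule)."""
    t, h, w = grid_thw
    return (t * h * w) // (merge_size**2)


def check_sample(
    input_ids: Sequence[int],
    placeholder_id: int,
    grids: Sequence[Sequence[int]],
    merge_size: int = 2,
) -> Tuple[bool, int, int]:
    """Compare placeholder tokens in input_ids with the count implied by the grids."""
    actual = sum(1 for tok in input_ids if tok == placeholder_id)
    expected = sum(expected_placeholders(g, merge_size) for g in grids)
    return actual == expected, actual, expected
```

For illustration only, hypothetical grids `(2, 28, 33)` and `(2, 28, 36)` give 462 and 504 tokens under this rule. A sample tokenized against one and processed against the other would fail the check in the same way as the reported mismatch.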
Fix
… `get_rope_func` (qwen vl branch), catching only the known shape mismatch
Validation