diff --git a/examples/qwen3_vl/run_vit_frame_parallel_sp.sh b/examples/qwen3_vl/run_vit_frame_parallel_sp.sh new file mode 100755 index 00000000..4e45f455 --- /dev/null +++ b/examples/qwen3_vl/run_vit_frame_parallel_sp.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +NGPUS=${NGPUS:-8} +MASTER_ADDR=${MASTER_ADDR:-127.0.0.1} +MASTER_PORT=${MASTER_PORT:-12355} + +torchrun --nproc_per_node=${NGPUS} \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr=${MASTER_ADDR} \ + --master_port=${MASTER_PORT} \ + -m lmms_engine.launch.cli \ + --config-path examples/qwen3_vl \ + --config-name vit_frame_parallel_sp diff --git a/examples/qwen3_vl/vit_frame_parallel_sp.yaml b/examples/qwen3_vl/vit_frame_parallel_sp.yaml new file mode 100644 index 00000000..86f506dd --- /dev/null +++ b/examples/qwen3_vl/vit_frame_parallel_sp.yaml @@ -0,0 +1,77 @@ +# Qwen3-VL training example with ViT frame parallelism + Ulysses SP. +# +# ``vit_frame_parallel`` balances visual frames across the flattened DPxCP +# group. With ``sp_ulysses_degree > 1``, each CP rank owns a shard of the local +# duplicated frames for ViT, then gathers visual features back before LM SP +# slicing. + +trainer_type: fsdp2_trainer + +dataset_config: + dataset_type: qwen3_vl_iterable + dataset_format: yaml + dataset_path: data/video/debug.yaml + processor_config: + processor_name: Qwen/Qwen3-VL-8B-Instruct + processor_type: qwen3_vl + shuffle: true + object_storage: none + packing: true + packing_strategy: balanced + packing_length: 51200 + filter_overlong: true + filter_overlong_workers: 8 + video_sampling_strategy: fps + video_max_pixels: 50176 + video_max_frames: 512 + frame_num: 64 + fps: 1 + video_backend: qwen_vl_utils + extra_kwargs: + packing_kwargs: + num_buckets: 2 + +trainer_args: + output_dir: ./output/qwen3_vl_vit_fp_sp + do_train: true + do_eval: false + per_device_train_batch_size: 1 + gradient_accumulation_steps: 1 + learning_rate: 2.0e-04 + weight_decay: 0.0 + num_train_epochs: 1 + max_steps: 1000 + lr_scheduler_type: cosine + warmup_ratio: 0.1 + logging_steps: 1 + save_strategy: steps + save_steps: 1000 + save_total_limit: 1 + bf16: true + tf32: true + dataloader_drop_last: true + dataloader_num_workers: 4 + dataloader_prefetch_factor: 2 + remove_unused_columns: false + gradient_checkpointing: true + use_liger_kernel: true + use_rmpad: true + fsdp2: true + fsdp_config: + transformer_layer_cls_to_wrap: + - Qwen3VLTextDecoderLayer + reshard_after_forward: false + min_num_params: 0 + sp_ulysses_degree: 2 + reduce_dtype: bfloat16 + output_dtype: bfloat16 + optim: adamw_torch_fused + seed: 42 + run_name: qwen3_vl_vit_fp_sp + +model_config: + load_from_pretrained_path: Qwen/Qwen3-VL-8B-Instruct + attn_implementation: flash_attention_2 + torch_dtype: bfloat16 + monkey_patch_kwargs: + patch_type: ["liger", "vit_frame_parallel"] diff --git a/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py b/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py index 5820e9c2..f44fe873 100644 --- a/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py +++ b/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py @@ -127,7 +127,7 @@ def input_dispatch( loads = [token for tokens in total_tokens for token in tokens] # ---- 2) LPT ---- - assignment_list, _ = lpt_balance(loads, num_ranks=world_size) + assignment_list, _ = lpt_balance(loads, num_ranks=world_size, frames_per_rank=total_frames) # ---- 3) src-view input splits (what I send to each dst) ---- # Slice out the segment of `assignment_list` corresponding to my local frames. diff --git a/src/lmms_engine/models/qwen3_vl/monkey_patch.py b/src/lmms_engine/models/qwen3_vl/monkey_patch.py index 780d1d3e..1e9a10a4 100644 --- a/src/lmms_engine/models/qwen3_vl/monkey_patch.py +++ b/src/lmms_engine/models/qwen3_vl/monkey_patch.py @@ -2,13 +2,16 @@ from functools import partial, wraps from typing import Callable +from loguru import logger from packaging import version from transformers import PreTrainedModel, Qwen3VLTextModel +import lmms_engine.parallel.process_group_manager as pgm from lmms_engine.parallel.sequence_parallel.ulysses import ( get_ulysses_sequence_parallel_world_size, patch_vlm_for_ulysses_input_slicing, ) +from lmms_engine.parallel.vit_parallel.frame_parallel import wrap_vit_forward try: from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss @@ -145,3 +148,35 @@ def wrapper(*args, **kwargs): for vision_block in vision_model.blocks: _patch_layer_norm_module(vision_block.norm1) _patch_layer_norm_module(vision_block.norm2) + + +@MONKEY_PATCHER.register("qwen3_vl", "vit_frame_parallel") +def apply_vit_frame_parallel_to_qwen3_vl(model: PreTrainedModel = None, **kwargs) -> None: + """Wrap ``Qwen3VLVisionModel.forward`` with DPxCP frame-parallel dispatch.""" + from transformers.models.qwen3_vl import modeling_qwen3_vl + + from .qwen3_vl_vit_ops import input_dispatch, output_dispatch + + if pgm.process_group_manager is None: + logger.info("vit_frame_parallel: process_group_manager not initialized, skipping ViT wrap") + return + + dp_cp_world_size = pgm.process_group_manager.dp_cp_world_size + if dp_cp_world_size <= 1: + logger.info("vit_frame_parallel: dp_cp_world_size <= 1, skipping ViT wrap") + return + + dp_cp_group = pgm.process_group_manager.dp_cp_group + cp_group = pgm.process_group_manager.cp_group if pgm.process_group_manager.cp_world_size > 1 else None + orig_forward = modeling_qwen3_vl.Qwen3VLVisionModel.forward + + wrapped = wrap_vit_forward( + input_dispatch=partial(input_dispatch, group=dp_cp_group, cp_group=cp_group), + orig_forward=orig_forward, + output_dispatch=output_dispatch, + ) + modeling_qwen3_vl.Qwen3VLVisionModel.forward = wrapped + logger.info( + f"vit_frame_parallel: wrapped Qwen3VLVisionModel.forward " + f"(dp_cp_size={dp_cp_world_size}, cp_size={pgm.process_group_manager.cp_world_size})" + ) diff --git a/src/lmms_engine/models/qwen3_vl/qwen3_vl_vit_ops.py b/src/lmms_engine/models/qwen3_vl/qwen3_vl_vit_ops.py new file mode 100644 index 00000000..0aa791e1 --- /dev/null +++ b/src/lmms_engine/models/qwen3_vl/qwen3_vl_vit_ops.py @@ -0,0 +1,242 @@ +"""Frame-parallel dispatch for Qwen3-VL ``Qwen3VLVisionModel.forward``. + +Qwen3-VL's vision output includes patch-level ``last_hidden_state``, merged +``pooler_output``, and merged ``deepstack_features``. The dispatch mirrors the +Qwen3.5 DPxCP frame-parallel path and additionally routes deepstack tensors +with the same split scale as ``pooler_output``. +""" + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.distributed as dist +from torch.distributed._functional_collectives import ( + all_gather_tensor_autograd, + all_to_all_single, + all_to_all_single_autograd, +) + +from lmms_engine.parallel.vit_parallel.balance import lpt_balance + + +def _patches_per_row(grid_thw: torch.Tensor) -> torch.Tensor: + return grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2] + + +def _cp_frame_range(num_frames: int, cp_rank: int, cp_size: int) -> Tuple[int, int]: + per_rank = (num_frames + cp_size - 1) // cp_size + start = min(cp_rank * per_rank, num_frames) + end = min(start + per_rank, num_frames) + return start, end + + +def _all_gather_variable_dim0(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + world_size = dist.get_world_size(group=group) + local_len = x.shape[0] + lengths = [None for _ in range(world_size)] + dist.all_gather_object(lengths, local_len, group=group) + max_len = max(lengths) + if local_len < max_len: + pad_shape = list(x.shape) + pad_shape[0] = max_len - local_len + x = torch.cat([x, x.new_zeros(pad_shape)], dim=0) + gathered = all_gather_tensor_autograd(x, gather_dim=0, group=group) + chunks = gathered.split(max_len, dim=0) + return torch.cat([chunk[:length] for chunk, length in zip(chunks, lengths)], dim=0) + + +def input_dispatch( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + *, + group: dist.ProcessGroup, + cp_group: Optional[dist.ProcessGroup] = None, + **kwargs, +) -> Tuple[Tuple, Dict[str, Any], Dict[str, Any]]: + world_size = dist.get_world_size(group=group) + my_rank = dist.get_rank(group=group) + device = hidden_states.device + + cp_rank = dist.get_rank(group=cp_group) if cp_group is not None else 0 + cp_size = dist.get_world_size(group=cp_group) if cp_group is not None else 1 + + frame_start, frame_end = _cp_frame_range(grid_thw.shape[0], cp_rank, cp_size) + local_grid_thw = grid_thw[frame_start:frame_end].contiguous() + patch_start = 0 if frame_start == 0 else int(_patches_per_row(grid_thw[:frame_start]).sum().item()) + local_num_patches = int(_patches_per_row(local_grid_thw).sum().item()) if local_grid_thw.numel() > 0 else 0 + local_hidden_states = hidden_states[patch_start : patch_start + local_num_patches].contiguous() + + num_tokens = local_grid_thw.prod(-1).tolist() + num_frames = local_grid_thw.shape[0] + total_tokens = [None for _ in range(world_size)] + total_frames = [None for _ in range(world_size)] + dist.all_gather_object(total_tokens, num_tokens, group=group) + dist.all_gather_object(total_frames, num_frames, group=group) + loads = [token for tokens in total_tokens for token in tokens] + + assignment_list, _ = lpt_balance(loads, num_ranks=world_size, frames_per_rank=total_frames) + + my_start = sum(total_frames[:my_rank]) + my_end = my_start + num_frames + my_assignment = assignment_list[my_start:my_end] + + input_splits = [0] * world_size + input_frames = [0] * world_size + for tokens, dst in zip(num_tokens, my_assignment): + input_splits[dst] += tokens + input_frames[dst] += 1 + + output_splits = [0] * world_size + output_frames = [0] * world_size + cursor = 0 + for src in range(world_size): + n = total_frames[src] + for k in range(cursor, cursor + n): + if assignment_list[k] == my_rank: + output_splits[src] += loads[k] + output_frames[src] += 1 + cursor += n + + if num_frames > 0: + send_order = torch.argsort( + torch.tensor(my_assignment, dtype=torch.long, device=device), + stable=True, + ) + patches_per_local = local_grid_thw.prod(-1) + local_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=device), patches_per_local.cumsum(0)]) + patch_perm = torch.cat( + [torch.arange(local_starts[i], local_starts[i + 1], device=device) for i in send_order.tolist()] + ) + send_hidden = local_hidden_states[patch_perm].contiguous() + send_grid = local_grid_thw[send_order].contiguous() + else: + send_order = torch.empty(0, dtype=torch.long, device=device) + patches_per_local = torch.empty(0, dtype=torch.long, device=device) + send_hidden = hidden_states.new_zeros((0, hidden_states.shape[1])) + send_grid = grid_thw.new_zeros((0, grid_thw.shape[1])) + + recv_hidden = all_to_all_single_autograd( + send_hidden, + output_split_sizes=output_splits, + input_split_sizes=input_splits, + group=group, + ) + recv_grid = all_to_all_single( + send_grid, + output_split_sizes=output_frames, + input_split_sizes=input_frames, + group=group, + ) + + ctx = { + "group": group, + "cp_group": cp_group, + "input_splits": output_splits, + "output_splits": input_splits, + "send_order": send_order, + "patches_per_local": patches_per_local, + } + return (self, recv_hidden), {"grid_thw": recv_grid, **kwargs}, ctx + + +def _dispatch_merged_tensor( + x: torch.Tensor, + *, + in_splits: list[int], + out_splits: list[int], + scale: int, + send_order: torch.Tensor, + patches_per_local: torch.Tensor, + group: dist.ProcessGroup, + cp_group: Optional[dist.ProcessGroup], +) -> torch.Tensor: + merged_in = [s // scale for s in in_splits] + merged_out = [s // scale for s in out_splits] + x = all_to_all_single_autograd( + x, + output_split_sizes=merged_out, + input_split_sizes=merged_in, + group=group, + ) + + n_local = send_order.numel() + if n_local > 0: + device = x.device + inv_order = torch.empty_like(send_order) + inv_order[send_order] = torch.arange(n_local, device=device) + per_frame = patches_per_local[send_order] // scale + starts = torch.cat([torch.zeros(1, dtype=torch.long, device=device), per_frame.cumsum(0)]) + perm = torch.cat([torch.arange(starts[i], starts[i + 1], device=device) for i in inv_order.tolist()]) + x = x[perm] + + if cp_group is not None and dist.get_world_size(group=cp_group) > 1: + x = _all_gather_variable_dim0(x, cp_group) + return x + + +def output_dispatch(out, ctx): + in_splits = ctx["input_splits"] + out_splits = ctx["output_splits"] + group = ctx["group"] + cp_group: Optional[dist.ProcessGroup] = ctx["cp_group"] + send_order: torch.Tensor = ctx["send_order"] + patches_per_local: torch.Tensor = ctx["patches_per_local"] + device = out.last_hidden_state.device + + last_hidden = all_to_all_single_autograd( + out.last_hidden_state, + output_split_sizes=out_splits, + input_split_sizes=in_splits, + group=group, + ) + + n_tokens = out.pooler_output.shape[0] + n_patches = sum(in_splits) + if n_tokens > 0 and n_patches > 0: + assert n_patches % n_tokens == 0, f"pooler_output tokens ({n_tokens}) doesn't divide patch total ({n_patches})" + scale = n_patches // n_tokens + else: + scale = 1 + + n_local = send_order.numel() + if n_local > 0: + inv_order = torch.empty_like(send_order) + inv_order[send_order] = torch.arange(n_local, device=device) + starts_full = torch.cat( + [torch.zeros(1, dtype=torch.long, device=device), patches_per_local[send_order].cumsum(0)] + ) + full_perm = torch.cat( + [torch.arange(starts_full[i], starts_full[i + 1], device=device) for i in inv_order.tolist()] + ) + last_hidden = last_hidden[full_perm] + + if cp_group is not None and dist.get_world_size(group=cp_group) > 1: + last_hidden = _all_gather_variable_dim0(last_hidden, cp_group) + + out.last_hidden_state = last_hidden + out.pooler_output = _dispatch_merged_tensor( + out.pooler_output, + in_splits=in_splits, + out_splits=out_splits, + scale=scale, + send_order=send_order, + patches_per_local=patches_per_local, + group=group, + cp_group=cp_group, + ) + if out.deepstack_features is not None: + out.deepstack_features = [ + _dispatch_merged_tensor( + feature, + in_splits=in_splits, + out_splits=out_splits, + scale=scale, + send_order=send_order, + patches_per_local=patches_per_local, + group=group, + cp_group=cp_group, + ) + for feature in out.deepstack_features + ] + return out diff --git a/src/lmms_engine/parallel/vit_parallel/balance.py b/src/lmms_engine/parallel/vit_parallel/balance.py index 51326f52..6779493d 100644 --- a/src/lmms_engine/parallel/vit_parallel/balance.py +++ b/src/lmms_engine/parallel/vit_parallel/balance.py @@ -1,4 +1,4 @@ -"""LPT balancing for ViT frame/image parallelism. +"""Balancing for ViT frame/image parallelism. Distributes a global set of frames (or images) across DP ranks so that the total ViT compute load (e.g. patch count) per rank is as balanced as possible. @@ -8,10 +8,14 @@ the same assignment, so no communication is needed to agree on the plan once loads have been gathered. -Algorithm: Longest Processing Time (LPT) greedy. Worst-case ratio 4/3, in -practice usually within ~1% of optimal on typical multimodal batches. +When each frame's source rank is known, the planner is locality-aware: frames +start on their source rank and only overloaded ranks spill frames to +underloaded ranks. This avoids the communication-heavy behavior of global LPT, +where a rank that is already near average can still have all local frames +replaced by frames from other ranks. """ +import math from typing import List, Optional, Sequence, Tuple @@ -19,8 +23,9 @@ def lpt_balance( loads: Sequence[int], num_ranks: Optional[int] = None, rank_idx: Optional[int] = None, + frames_per_rank: Optional[Sequence[int]] = None, ) -> Tuple[List[int], List[int]]: - """Compute an LPT assignment of frames to ranks. + """Compute a deterministic assignment of frames to ranks. Args: loads: per-frame load (e.g. ``T * H * W`` patches). Frames are @@ -32,6 +37,9 @@ def lpt_balance( rank_idx: if provided, also return the list of frame indices assigned to this rank as a convenience. If ``None``, the caller can derive it from the returned ``assignment``. + frames_per_rank: source frame counts for each rank in the flattened + ``loads`` list. When provided, enables locality-aware balancing: + frames stay on their source rank unless that rank is overloaded. Returns: ``(assignment, load_per_rank)`` where @@ -57,6 +65,9 @@ def lpt_balance( if num_ranks <= 0: raise ValueError(f"num_ranks must be positive, got {num_ranks}") + if frames_per_rank is not None: + return _locality_aware_balance(loads, frames_per_rank, num_ranks) + n_frames = len(loads) assignment: List[int] = [-1] * n_frames load_per_rank: List[int] = [0] * num_ranks @@ -78,6 +89,104 @@ def lpt_balance( return assignment, load_per_rank +def _source_ranks(frames_per_rank: Sequence[int], num_frames: int, num_ranks: int) -> List[int]: + if len(frames_per_rank) != num_ranks: + raise ValueError(f"frames_per_rank must have {num_ranks} entries, got {len(frames_per_rank)}") + if sum(frames_per_rank) != num_frames: + raise ValueError(f"sum(frames_per_rank) must equal {num_frames}, got {sum(frames_per_rank)}") + + source_ranks: List[int] = [] + for rank, num_rank_frames in enumerate(frames_per_rank): + if num_rank_frames < 0: + raise ValueError(f"frames_per_rank entries must be non-negative, got {num_rank_frames}") + source_ranks.extend([rank] * num_rank_frames) + return source_ranks + + +def _pick_spill_frame( + frame_indices: Sequence[int], + loads: Sequence[int], + donor_load: int, + receiver_load: int, + target_load: int, +) -> Optional[int]: + receiver_deficit = target_load - receiver_load + + fitting = [idx for idx in frame_indices if loads[idx] <= receiver_deficit] + if fitting: + return max(fitting, key=lambda idx: (loads[idx], -idx)) + + improving = [idx for idx in frame_indices if max(donor_load - loads[idx], receiver_load + loads[idx]) < donor_load] + if improving: + return min(improving, key=lambda idx: (loads[idx], idx)) + + return None + + +def _locality_aware_balance( + loads: Sequence[int], + frames_per_rank: Sequence[int], + num_ranks: int, +) -> Tuple[List[int], List[int]]: + """Balance by spilling only overloaded ranks' local frames.""" + n_frames = len(loads) + source_ranks = _source_ranks(frames_per_rank, n_frames, num_ranks) + assignment = source_ranks[:] + load_per_rank = [0] * num_ranks + frames_by_rank: List[List[int]] = [[] for _ in range(num_ranks)] + + for idx, (load, rank) in enumerate(zip(loads, source_ranks)): + load_per_rank[rank] += load + frames_by_rank[rank].append(idx) + + total_load = sum(loads) + if total_load == 0: + return assignment, load_per_rank + + target_load = math.ceil(total_load / num_ranks) + + while True: + donors = [rank for rank, load in enumerate(load_per_rank) if load > target_load] + receivers = [rank for rank, load in enumerate(load_per_rank) if load < target_load] + if not donors or not receivers: + break + + donors.sort(key=lambda rank: (-(load_per_rank[rank] - target_load), rank)) + receivers.sort(key=lambda rank: (-(target_load - load_per_rank[rank]), rank)) + + moved = False + for donor in donors: + donor_frames = [idx for idx in frames_by_rank[donor] if assignment[idx] == donor] + donor_frames.sort(key=lambda idx: (-loads[idx], idx)) + if not donor_frames: + continue + + for receiver in receivers: + frame_idx = _pick_spill_frame( + donor_frames, + loads, + load_per_rank[donor], + load_per_rank[receiver], + target_load, + ) + if frame_idx is None: + continue + + assignment[frame_idx] = receiver + load_per_rank[donor] -= loads[frame_idx] + load_per_rank[receiver] += loads[frame_idx] + moved = True + break + + if moved: + break + + if not moved: + break + + return assignment, load_per_rank + + def frames_for_rank(assignment: Sequence[int], rank_idx: int) -> List[int]: """Return the global frame indices assigned to ``rank_idx``.