diff --git a/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py b/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py index a17ac56b..5820e9c2 100644 --- a/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py +++ b/src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py @@ -18,22 +18,22 @@ ---------------------------------- When SP is on, the dataloader still shards by ``dp_rank`` only, so the ``cp_rank`` axis sees the *same* frames duplicated. To actually cut ViT -memory under SP we run frame-balancing on the flat ``dp_cp_group`` (size -= dp_size × cp_size), but only let ``cp_rank == 0`` contribute frames so -the LPT load equals the real (de-duplicated) frame count. After the ViT -forward, features destined for a given dp_rank flow back to its ``cp_rank -== 0`` worker via the reverse all_to_all, then a CP-group all_reduce -(autograd-aware) broadcasts them to ``cp_rank > 0`` so every rank can do -its ``masked_scatter`` *before* the SP layer slices the seq. +memory under SP while keeping the autograd graph symmetric across CP ranks, +each CP rank first takes a deterministic shard of its duplicated local frames +(roughly ``num_frames / cp_size``). We then run the usual LPT balancing over +the flat ``dp_cp_group`` (size = dp_size × cp_size). After the ViT forward, +features flow back to the CP rank that owned each local frame shard; a +CP-group autograd-aware all-gather reconstructs the full local-dp feature set +on every CP rank so each rank can do its ``masked_scatter`` before the SP layer +slices the seq. Communication: * Metadata (per-rank token / frame counts) goes through ``all_gather_object``. * ``hidden_states`` uses ``all_to_all_single_autograd`` so gradients route back to the originating rank. * ``grid_thw`` uses plain ``all_to_all_single`` (no grad needed). - * Optional CP broadcast uses autograd-aware ``all_reduce`` (sum), which - back-props as another sum — equivalent to broadcasting forward and - summing gradients on the source rank. + * Optional CP gather uses autograd-aware ``all_gather_tensor_autograd``; + gradients route back to the CP rank that owned each frame shard. """ from typing import Any, Dict, Optional, Tuple @@ -41,7 +41,7 @@ import torch import torch.distributed as dist from torch.distributed._functional_collectives import ( - all_reduce, + all_gather_tensor_autograd, all_to_all_single, all_to_all_single_autograd, ) @@ -54,6 +54,39 @@ def _patches_per_row(grid_thw: torch.Tensor) -> torch.Tensor: return grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2] +def _cp_frame_range(num_frames: int, cp_rank: int, cp_size: int) -> Tuple[int, int]: + """Contiguous frame shard for this CP rank. + + CP ranks hold duplicated dataloader frames. Sharding those frames before + the dp×cp LPT removes the old source/receiver asymmetry where cp_rank==0 + sent real frames and cp_rank>0 sent zeros. + """ + per_rank = (num_frames + cp_size - 1) // cp_size + start = min(cp_rank * per_rank, num_frames) + end = min(start + per_rank, num_frames) + return start, end + + +def _all_gather_variable_dim0(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor: + """Autograd-aware all-gather for variable dim-0 tensors. + + Pads local tensors to the CP-group max length, gathers with autograd, then + removes padding and concatenates ranks in group order. + """ + world_size = dist.get_world_size(group=group) + local_len = x.shape[0] + lengths = [None for _ in range(world_size)] + dist.all_gather_object(lengths, local_len, group=group) + max_len = max(lengths) + if local_len < max_len: + pad_shape = list(x.shape) + pad_shape[0] = max_len - local_len + x = torch.cat([x, x.new_zeros(pad_shape)], dim=0) + gathered = all_gather_tensor_autograd(x, gather_dim=0, group=group) + chunks = gathered.split(max_len, dim=0) + return torch.cat([chunk[:length] for chunk, length in zip(chunks, lengths)], dim=0) + + def input_dispatch( self, hidden_states: torch.Tensor, @@ -65,10 +98,9 @@ def input_dispatch( ) -> Tuple[Tuple, Dict[str, Any], Dict[str, Any]]: """Dispatch frames across ``group`` ahead of the ViT forward. - When ``cp_group`` is provided (SP enabled), only ``cp_rank == 0`` workers - contribute their frames to the LPT pool; ``cp_rank > 0`` workers report - zero frames so they are pure receivers. This matches the dataloader's - dp-only sharding (cp ranks within the same dp_rank hold duplicated input). + When ``cp_group`` is provided (SP enabled), cp ranks hold duplicated + dataloader frames. Each cp rank contributes a deterministic local shard of + those frames before the dp×cp LPT, so all cp ranks have a real source path. Returns ``(new_args, new_kwargs, ctx)`` for ``wrap_vit_forward``. """ @@ -76,19 +108,18 @@ def input_dispatch( my_rank = dist.get_rank(group=group) device = hidden_states.device - # Determine whether this rank is the "source" (contributes frames) within - # its cp group. cp_rank > 0 holds duplicated input, so it should not push - # frames into the LPT pool. cp_rank = dist.get_rank(group=cp_group) if cp_group is not None else 0 - is_source = cp_rank == 0 + cp_size = dist.get_world_size(group=cp_group) if cp_group is not None else 1 + + frame_start, frame_end = _cp_frame_range(grid_thw.shape[0], cp_rank, cp_size) + local_grid_thw = grid_thw[frame_start:frame_end].contiguous() + patch_start = 0 if frame_start == 0 else int(_patches_per_row(grid_thw[:frame_start]).sum().item()) + local_num_patches = _patches_per_row(local_grid_thw).sum().item() if local_grid_thw.numel() > 0 else 0 + local_hidden_states = hidden_states[patch_start : patch_start + local_num_patches].contiguous() # ---- 1) gather per-rank token/frame counts ---- - if is_source: - num_tokens = grid_thw.prod(-1).tolist() - num_frames = grid_thw.shape[0] - else: - num_tokens = [] - num_frames = 0 + num_tokens = local_grid_thw.prod(-1).tolist() + num_frames = local_grid_thw.shape[0] total_tokens = [None for _ in range(world_size)] total_frames = [None for _ in range(world_size)] dist.all_gather_object(total_tokens, num_tokens, group=group) @@ -106,10 +137,9 @@ def input_dispatch( input_splits = [0] * world_size # tokens I send to each dst input_frames = [0] * world_size # frames I send to each dst - if is_source: - for tokens, dst in zip(num_tokens, my_assignment): - input_splits[dst] += tokens - input_frames[dst] += 1 + for tokens, dst in zip(num_tokens, my_assignment): + input_splits[dst] += tokens + input_frames[dst] += 1 # ---- 4) src-view output splits (what I receive from each src) ---- output_splits = [0] * world_size # tokens I receive from each src @@ -125,20 +155,19 @@ def input_dispatch( # ---- 5) permute local tensors so frames are grouped by destination ---- # all_to_all_single splits the input row-wise in tensor order, so we must - # rearrange local frames into [dst=0 block, dst=1 block, ...] first. Only - # source ranks have real input to permute; cp_rank>0 sends an empty tensor. - if is_source and num_frames > 0: + # rearrange local frames into [dst=0 block, dst=1 block, ...] first. + if num_frames > 0: send_order = torch.argsort( torch.tensor(my_assignment, dtype=torch.long, device=device), stable=True, ) - patches_per_local = grid_thw.prod(-1) + patches_per_local = local_grid_thw.prod(-1) local_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=device), patches_per_local.cumsum(0)]) patch_perm = torch.cat( [torch.arange(local_starts[i], local_starts[i + 1], device=device) for i in send_order.tolist()] ) - send_hidden = hidden_states[patch_perm].contiguous() - send_grid = grid_thw[send_order].contiguous() + send_hidden = local_hidden_states[patch_perm].contiguous() + send_grid = local_grid_thw[send_order].contiguous() else: send_order = torch.empty(0, dtype=torch.long, device=device) patches_per_local = torch.empty(0, dtype=torch.long, device=device) @@ -167,10 +196,9 @@ def input_dispatch( "input_splits": output_splits, "output_splits": input_splits, # Inverse permutation for un-shuffling features back to local-original - # frame order (only meaningful on source ranks). + # frame-shard order. "send_order": send_order, "patches_per_local": patches_per_local, - "is_source": is_source, } return (self, recv_hidden), {"grid_thw": recv_grid, **kwargs}, ctx @@ -183,14 +211,13 @@ def output_dispatch(out, ctx): ``last_hidden_state`` (patch-level) and ``pooler_output`` (merger-reduced) are shipped — the latter uses splits rescaled by the merger factor. - After the reverse all_to_all, features are laid out on the originating - ``cp_rank == 0`` worker only (since cp_rank > 0 was a pure receiver in the - forward dispatch). We then broadcast within the cp group via an - autograd-aware all_reduce so every cp rank can run ``masked_scatter`` - before the SP layer slices the seq. + After the reverse all_to_all, each CP rank holds the features for the + local frame shard it contributed. We then broadcast/sum within the cp group + so every cp rank reconstructs the full local-dp feature set and can run + ``masked_scatter`` before the SP layer slices the seq. - Source ranks (cp_rank == 0) finally undo the dst-sorted permutation so the - LLM's ``masked_scatter`` sees frames in the original local order. + Each rank first undoes the dst-sorted permutation for its local frame + shard so CP all_reduce(SUM) reconstructs the original per-dp frame order. """ in_splits = ctx["input_splits"] out_splits = ctx["output_splits"] @@ -198,7 +225,6 @@ def output_dispatch(out, ctx): cp_group: Optional[dist.ProcessGroup] = ctx["cp_group"] send_order: torch.Tensor = ctx["send_order"] patches_per_local: torch.Tensor = ctx["patches_per_local"] - is_source: bool = ctx["is_source"] device = out.last_hidden_state.device # last_hidden_state: same scale as patches, use splits as-is. @@ -210,10 +236,9 @@ def output_dispatch(out, ctx): ) # pooler_output: patches // spatial_merge_size**2, infer scale from tensor. - # Source ranks (cp_rank == 0) always have real patches; non-source ranks - # have zero-sized inputs and outputs everywhere, so any ``scale`` works - # (all the all_to_all splits below are 0). The cp broadcast at the bottom - # restores the real shape. + # Ranks with an empty local shard have zero-sized inputs and outputs, so + # any ``scale`` works (all the all_to_all splits below are 0). The cp + # broadcast at the bottom restores the full local-dp shape. n_tokens = out.pooler_output.shape[0] n_patches = sum(in_splits) if n_tokens > 0 and n_patches > 0: @@ -231,9 +256,9 @@ def output_dispatch(out, ctx): group=group, ) - # ---- unpermute back to local-original frame order (source ranks only) ---- + # ---- unpermute back to local frame-shard order ---- n_local = send_order.numel() - if is_source and n_local > 0: + if n_local > 0: # Inverse permutation on frame index. inv_order = torch.empty_like(send_order) inv_order[send_order] = torch.arange(n_local, device=device) @@ -255,38 +280,14 @@ def output_dispatch(out, ctx): ) pooler = pooler[pool_perm] - # ---- CP broadcast: source rank holds the real features, others hold - # zero-sized tensors. We need every cp rank to end up with an identical - # full-shape copy so ``masked_scatter`` works before SP slicing. + # ---- CP gather+broadcast: each cp rank holds a disjoint contiguous frame + # shard from the same dp sample. Gather shards in cp-rank order so every cp + # rank reconstructs the same full local-dp feature sequence before SP + # slicing. This keeps all cp ranks as real sources in autograd (no + # cp_rank==0-only source path). if cp_group is not None and dist.get_world_size(group=cp_group) > 1: - cp_size = dist.get_world_size(group=cp_group) - - # Agree on the canonical shape across the cp group (source rank knows - # it; others have shape[0]==0). We pick the max along dim 0. - local_shape = torch.tensor([last_hidden.shape[0], pooler.shape[0]], dtype=torch.long, device=device) - max_shape = local_shape.clone() - dist.all_reduce(max_shape, op=dist.ReduceOp.MAX, group=cp_group) - n_last, n_pool = int(max_shape[0].item()), int(max_shape[1].item()) - - hidden_dim = last_hidden.shape[1] if last_hidden.shape[0] > 0 else None - pooler_dim = pooler.shape[1] if pooler.shape[0] > 0 else None - # Hidden_dim must agree across cp ranks: gather it via max as well. - dim_tensor = torch.tensor([hidden_dim or 0, pooler_dim or 0], dtype=torch.long, device=device) - dist.all_reduce(dim_tensor, op=dist.ReduceOp.MAX, group=cp_group) - hidden_dim = int(dim_tensor[0].item()) - pooler_dim = int(dim_tensor[1].item()) - - # Pad non-source ranks up to (n_last, hidden_dim) / (n_pool, pooler_dim) - # with zeros so all_reduce(SUM) reproduces the source values. - if last_hidden.shape[0] != n_last: - last_hidden = last_hidden.new_zeros((n_last, hidden_dim)) - if pooler.shape[0] != n_pool: - pooler = pooler.new_zeros((n_pool, pooler_dim)) - - # autograd-aware: backward through all_reduce(SUM) is itself a SUM, - # which on the source rank receives the summed grads from all cp ranks. - last_hidden = all_reduce(last_hidden, reduceOp="sum", group=cp_group) - pooler = all_reduce(pooler, reduceOp="sum", group=cp_group) + last_hidden = _all_gather_variable_dim0(last_hidden, cp_group) + pooler = _all_gather_variable_dim0(pooler, cp_group) out.last_hidden_state = last_hidden out.pooler_output = pooler