Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 83 additions & 82 deletions src/lmms_engine/models/qwen3_5/qwen3_5_vit_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@
----------------------------------
When SP is on, the dataloader still shards by ``dp_rank`` only, so the
``cp_rank`` axis sees the *same* frames duplicated. To actually cut ViT
memory under SP we run frame-balancing on the flat ``dp_cp_group`` (size
= dp_size × cp_size), but only let ``cp_rank == 0`` contribute frames so
the LPT load equals the real (de-duplicated) frame count. After the ViT
forward, features destined for a given dp_rank flow back to its ``cp_rank
== 0`` worker via the reverse all_to_all, then a CP-group all_reduce
(autograd-aware) broadcasts them to ``cp_rank > 0`` so every rank can do
its ``masked_scatter`` *before* the SP layer slices the seq.
memory under SP while keeping the autograd graph symmetric across CP ranks,
each CP rank first takes a deterministic shard of its duplicated local frames
(roughly ``num_frames / cp_size``). We then run the usual LPT balancing over
the flat ``dp_cp_group`` (size = dp_size × cp_size). After the ViT forward,
features flow back to the CP rank that owned each local frame shard; a
CP-group autograd-aware all-gather reconstructs the full local-dp feature set
on every CP rank so each rank can do its ``masked_scatter`` before the SP layer
slices the seq.

Communication:
* Metadata (per-rank token / frame counts) goes through ``all_gather_object``.
* ``hidden_states`` uses ``all_to_all_single_autograd`` so gradients route
back to the originating rank.
* ``grid_thw`` uses plain ``all_to_all_single`` (no grad needed).
* Optional CP broadcast uses autograd-aware ``all_reduce`` (sum), which
back-props as another sum — equivalent to broadcasting forward and
summing gradients on the source rank.
* Optional CP gather uses autograd-aware ``all_gather_tensor_autograd``;
gradients route back to the CP rank that owned each frame shard.
"""

from typing import Any, Dict, Optional, Tuple

import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import (
all_reduce,
all_gather_tensor_autograd,
all_to_all_single,
all_to_all_single_autograd,
)
Expand All @@ -54,6 +54,39 @@ def _patches_per_row(grid_thw: torch.Tensor) -> torch.Tensor:
return grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]


def _cp_frame_range(num_frames: int, cp_rank: int, cp_size: int) -> Tuple[int, int]:
"""Contiguous frame shard for this CP rank.

CP ranks hold duplicated dataloader frames. Sharding those frames before
the dp×cp LPT removes the old source/receiver asymmetry where cp_rank==0
sent real frames and cp_rank>0 sent zeros.
"""
per_rank = (num_frames + cp_size - 1) // cp_size
start = min(cp_rank * per_rank, num_frames)
end = min(start + per_rank, num_frames)
return start, end


def _all_gather_variable_dim0(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
"""Autograd-aware all-gather for variable dim-0 tensors.

Pads local tensors to the CP-group max length, gathers with autograd, then
removes padding and concatenates ranks in group order.
"""
world_size = dist.get_world_size(group=group)
local_len = x.shape[0]
lengths = [None for _ in range(world_size)]
dist.all_gather_object(lengths, local_len, group=group)
max_len = max(lengths)
if local_len < max_len:
pad_shape = list(x.shape)
pad_shape[0] = max_len - local_len
x = torch.cat([x, x.new_zeros(pad_shape)], dim=0)
gathered = all_gather_tensor_autograd(x, gather_dim=0, group=group)
chunks = gathered.split(max_len, dim=0)
return torch.cat([chunk[:length] for chunk, length in zip(chunks, lengths)], dim=0)


def input_dispatch(
self,
hidden_states: torch.Tensor,
Expand All @@ -65,30 +98,28 @@ def input_dispatch(
) -> Tuple[Tuple, Dict[str, Any], Dict[str, Any]]:
"""Dispatch frames across ``group`` ahead of the ViT forward.

When ``cp_group`` is provided (SP enabled), only ``cp_rank == 0`` workers
contribute their frames to the LPT pool; ``cp_rank > 0`` workers report
zero frames so they are pure receivers. This matches the dataloader's
dp-only sharding (cp ranks within the same dp_rank hold duplicated input).
When ``cp_group`` is provided (SP enabled), cp ranks hold duplicated
dataloader frames. Each cp rank contributes a deterministic local shard of
those frames before the dp×cp LPT, so all cp ranks have a real source path.

Returns ``(new_args, new_kwargs, ctx)`` for ``wrap_vit_forward``.
"""
world_size = dist.get_world_size(group=group)
my_rank = dist.get_rank(group=group)
device = hidden_states.device

# Determine whether this rank is the "source" (contributes frames) within
# its cp group. cp_rank > 0 holds duplicated input, so it should not push
# frames into the LPT pool.
cp_rank = dist.get_rank(group=cp_group) if cp_group is not None else 0
is_source = cp_rank == 0
cp_size = dist.get_world_size(group=cp_group) if cp_group is not None else 1

frame_start, frame_end = _cp_frame_range(grid_thw.shape[0], cp_rank, cp_size)
local_grid_thw = grid_thw[frame_start:frame_end].contiguous()
patch_start = 0 if frame_start == 0 else int(_patches_per_row(grid_thw[:frame_start]).sum().item())
local_num_patches = _patches_per_row(local_grid_thw).sum().item() if local_grid_thw.numel() > 0 else 0
local_hidden_states = hidden_states[patch_start : patch_start + local_num_patches].contiguous()

# ---- 1) gather per-rank token/frame counts ----
if is_source:
num_tokens = grid_thw.prod(-1).tolist()
num_frames = grid_thw.shape[0]
else:
num_tokens = []
num_frames = 0
num_tokens = local_grid_thw.prod(-1).tolist()
num_frames = local_grid_thw.shape[0]
total_tokens = [None for _ in range(world_size)]
total_frames = [None for _ in range(world_size)]
dist.all_gather_object(total_tokens, num_tokens, group=group)
Expand All @@ -106,10 +137,9 @@ def input_dispatch(

input_splits = [0] * world_size # tokens I send to each dst
input_frames = [0] * world_size # frames I send to each dst
if is_source:
for tokens, dst in zip(num_tokens, my_assignment):
input_splits[dst] += tokens
input_frames[dst] += 1
for tokens, dst in zip(num_tokens, my_assignment):
input_splits[dst] += tokens
input_frames[dst] += 1

# ---- 4) src-view output splits (what I receive from each src) ----
output_splits = [0] * world_size # tokens I receive from each src
Expand All @@ -125,20 +155,19 @@ def input_dispatch(

# ---- 5) permute local tensors so frames are grouped by destination ----
# all_to_all_single splits the input row-wise in tensor order, so we must
# rearrange local frames into [dst=0 block, dst=1 block, ...] first. Only
# source ranks have real input to permute; cp_rank>0 sends an empty tensor.
if is_source and num_frames > 0:
# rearrange local frames into [dst=0 block, dst=1 block, ...] first.
if num_frames > 0:
send_order = torch.argsort(
torch.tensor(my_assignment, dtype=torch.long, device=device),
stable=True,
)
patches_per_local = grid_thw.prod(-1)
patches_per_local = local_grid_thw.prod(-1)
local_starts = torch.cat([torch.zeros(1, dtype=torch.long, device=device), patches_per_local.cumsum(0)])
patch_perm = torch.cat(
[torch.arange(local_starts[i], local_starts[i + 1], device=device) for i in send_order.tolist()]
)
send_hidden = hidden_states[patch_perm].contiguous()
send_grid = grid_thw[send_order].contiguous()
send_hidden = local_hidden_states[patch_perm].contiguous()
send_grid = local_grid_thw[send_order].contiguous()
else:
send_order = torch.empty(0, dtype=torch.long, device=device)
patches_per_local = torch.empty(0, dtype=torch.long, device=device)
Expand Down Expand Up @@ -167,10 +196,9 @@ def input_dispatch(
"input_splits": output_splits,
"output_splits": input_splits,
# Inverse permutation for un-shuffling features back to local-original
# frame order (only meaningful on source ranks).
# frame-shard order.
"send_order": send_order,
"patches_per_local": patches_per_local,
"is_source": is_source,
}
return (self, recv_hidden), {"grid_thw": recv_grid, **kwargs}, ctx

Expand All @@ -183,22 +211,20 @@ def output_dispatch(out, ctx):
``last_hidden_state`` (patch-level) and ``pooler_output`` (merger-reduced)
are shipped — the latter uses splits rescaled by the merger factor.

After the reverse all_to_all, features are laid out on the originating
``cp_rank == 0`` worker only (since cp_rank > 0 was a pure receiver in the
forward dispatch). We then broadcast within the cp group via an
autograd-aware all_reduce so every cp rank can run ``masked_scatter``
before the SP layer slices the seq.
After the reverse all_to_all, each CP rank holds the features for the
local frame shard it contributed. We then broadcast/sum within the cp group
so every cp rank reconstructs the full local-dp feature set and can run
``masked_scatter`` before the SP layer slices the seq.

Source ranks (cp_rank == 0) finally undo the dst-sorted permutation so the
LLM's ``masked_scatter`` sees frames in the original local order.
Each rank first undoes the dst-sorted permutation for its local frame
shard so CP all_reduce(SUM) reconstructs the original per-dp frame order.
"""
in_splits = ctx["input_splits"]
out_splits = ctx["output_splits"]
group = ctx["group"]
cp_group: Optional[dist.ProcessGroup] = ctx["cp_group"]
send_order: torch.Tensor = ctx["send_order"]
patches_per_local: torch.Tensor = ctx["patches_per_local"]
is_source: bool = ctx["is_source"]
device = out.last_hidden_state.device

# last_hidden_state: same scale as patches, use splits as-is.
Expand All @@ -210,10 +236,9 @@ def output_dispatch(out, ctx):
)

# pooler_output: patches // spatial_merge_size**2, infer scale from tensor.
# Source ranks (cp_rank == 0) always have real patches; non-source ranks
# have zero-sized inputs and outputs everywhere, so any ``scale`` works
# (all the all_to_all splits below are 0). The cp broadcast at the bottom
# restores the real shape.
# Ranks with an empty local shard have zero-sized inputs and outputs, so
# any ``scale`` works (all the all_to_all splits below are 0). The cp
# broadcast at the bottom restores the full local-dp shape.
n_tokens = out.pooler_output.shape[0]
n_patches = sum(in_splits)
if n_tokens > 0 and n_patches > 0:
Expand All @@ -231,9 +256,9 @@ def output_dispatch(out, ctx):
group=group,
)

# ---- unpermute back to local-original frame order (source ranks only) ----
# ---- unpermute back to local frame-shard order ----
n_local = send_order.numel()
if is_source and n_local > 0:
if n_local > 0:
# Inverse permutation on frame index.
inv_order = torch.empty_like(send_order)
inv_order[send_order] = torch.arange(n_local, device=device)
Expand All @@ -255,38 +280,14 @@ def output_dispatch(out, ctx):
)
pooler = pooler[pool_perm]

# ---- CP broadcast: source rank holds the real features, others hold
# zero-sized tensors. We need every cp rank to end up with an identical
# full-shape copy so ``masked_scatter`` works before SP slicing.
# ---- CP gather+broadcast: each cp rank holds a disjoint contiguous frame
# shard from the same dp sample. Gather shards in cp-rank order so every cp
# rank reconstructs the same full local-dp feature sequence before SP
# slicing. This keeps all cp ranks as real sources in autograd (no
# cp_rank==0-only source path).
if cp_group is not None and dist.get_world_size(group=cp_group) > 1:
cp_size = dist.get_world_size(group=cp_group)

# Agree on the canonical shape across the cp group (source rank knows
# it; others have shape[0]==0). We pick the max along dim 0.
local_shape = torch.tensor([last_hidden.shape[0], pooler.shape[0]], dtype=torch.long, device=device)
max_shape = local_shape.clone()
dist.all_reduce(max_shape, op=dist.ReduceOp.MAX, group=cp_group)
n_last, n_pool = int(max_shape[0].item()), int(max_shape[1].item())

hidden_dim = last_hidden.shape[1] if last_hidden.shape[0] > 0 else None
pooler_dim = pooler.shape[1] if pooler.shape[0] > 0 else None
# Hidden_dim must agree across cp ranks: gather it via max as well.
dim_tensor = torch.tensor([hidden_dim or 0, pooler_dim or 0], dtype=torch.long, device=device)
dist.all_reduce(dim_tensor, op=dist.ReduceOp.MAX, group=cp_group)
hidden_dim = int(dim_tensor[0].item())
pooler_dim = int(dim_tensor[1].item())

# Pad non-source ranks up to (n_last, hidden_dim) / (n_pool, pooler_dim)
# with zeros so all_reduce(SUM) reproduces the source values.
if last_hidden.shape[0] != n_last:
last_hidden = last_hidden.new_zeros((n_last, hidden_dim))
if pooler.shape[0] != n_pool:
pooler = pooler.new_zeros((n_pool, pooler_dim))

# autograd-aware: backward through all_reduce(SUM) is itself a SUM,
# which on the source rank receives the summed grads from all cp ranks.
last_hidden = all_reduce(last_hidden, reduceOp="sum", group=cp_group)
pooler = all_reduce(pooler, reduceOp="sum", group=cp_group)
last_hidden = _all_gather_variable_dim0(last_hidden, cp_group)
pooler = _all_gather_variable_dim0(pooler, cp_group)

out.last_hidden_state = last_hidden
out.pooler_output = pooler
Expand Down
Loading