33 changes: 17 additions & 16 deletions README.md
@@ -86,24 +86,25 @@ python -m lmms_engine.launch.cli config_yaml=examples/qwen3_vl/example_config.ya

## 🔥 Featured Examples

| Model | Quick Start | FSDP2 | USP | Muon | Liger | Packing | NSA | EP | Highlights |
|-------|-------------|-------|-----|------|-------|---------|-----|----|------------------|
| **[BAGEL](src/lmms_engine/models/bagel)** | [run.sh](examples/bagel/run.sh) | ✅ | TBD | ✅ | ❌ | ✅ | ✅ | ❌ | Unified visual understanding & generation |
| **[Qwen2.5](src/lmms_engine/models/qwen2)** | [run.sh](examples/qwen2_5_llm/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Large Language Model |
| **[Qwen2.5-VL](src/lmms_engine/models/qwen2_5_vl/)** | [run.sh](examples/qwen2_5_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Multimodal Model |
| **[Qwen2.5-Omni](examples/qwen2_5_omni)** | [run.sh](examples/qwen2_5_omni/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Unified multimodal (image, audio, text) |
| **[Qwen3-VL](examples/qwen3_vl)** | [run.sh](examples/qwen3_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Native-resolution, long context (10K+ tokens) |
| **[Qwen3-VL MoE](examples/qwen3_vl_moe)** | [run.sh](examples/qwen3_vl_moe/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | Vision-Language MoE with EP (image, video, text) |
| **[Qwen3-MoE](examples/qwen3_moe)** | [run.sh](examples/qwen3_moe/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Mixture-of-Experts, Expert Parallelism |
| **[Qwen3-Omni MoE](examples/qwen3_omni_moe)** | [config](examples/qwen3_omni_moe_ep2.yaml) | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Multimodal MoE with EP (image, audio, text) |
| **[WanVideo](examples/wanvideo)** | [run.sh](examples/wanvideo/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | T2V/I2V/V2V generation (1.3B/14B) |
| **[FLA models](examples/dgn)** | [run.sh](examples/dgn/run.sh) | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | Efficient architecture, FineWeb-Edu pretraining |
| **[dLLM (Qwen3)](examples/diffusion_language_model)** | [run.sh](examples/diffusion_language_model/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Masked diffusion language model |
| **[RAE-SigLip](examples/representation_autoencoder)** | [run.sh](examples/representation_autoencoder/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Representation AutoEncoder, LPIPS, EMA |
| **[SiT](examples/scalable_interpolant_transformer)** | [run.sh](examples/scalable_interpolant_transformer/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Interpolant Transformer, CFG, ImageNet-1K |
| Model | Quick Start | FSDP2 | TP | USP | Muon | Liger | Packing | NSA | EP | Highlights |
|-------|-------------|-------|----|-----|------|-------|---------|-----|----|------------------|
| **[BAGEL](src/lmms_engine/models/bagel)** | [run.sh](examples/bagel/run.sh) | ✅ | ❌ | TBD | ✅ | ❌ | ✅ | ✅ | ❌ | Unified visual understanding & generation |
| **[Qwen2.5](src/lmms_engine/models/qwen2)** | [run.sh](examples/qwen2_5_llm/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Large Language Model |
| **[Qwen2.5-VL](src/lmms_engine/models/qwen2_5_vl/)** | [run.sh](examples/qwen2_5_vl/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Multimodal Model |
| **[Qwen2.5-Omni](examples/qwen2_5_omni)** | [run.sh](examples/qwen2_5_omni/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Unified multimodal (image, audio, text) |
| **[Qwen3-VL](examples/qwen3_vl)** | [run.sh](examples/qwen3_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Native-resolution, long context (10K+ tokens) |
| **[Qwen3-VL MoE](examples/qwen3_vl_moe)** | [run.sh](examples/qwen3_vl_moe/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | Vision-Language MoE with EP (image, video, text) |
| **[Qwen3-MoE](examples/qwen3_moe)** | [run.sh](examples/qwen3_moe/run.sh) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Mixture-of-Experts, Expert Parallelism |
| **[Qwen3-Omni MoE](examples/qwen3_omni_moe)** | [config](examples/qwen3_omni_moe_ep2.yaml) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Multimodal MoE with EP (image, audio, text) |
| **[WanVideo](examples/wanvideo)** | [run.sh](examples/wanvideo/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | T2V/I2V/V2V generation (1.3B/14B) |
| **[FLA models](examples/dgn)** | [run.sh](examples/dgn/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | Efficient architecture, FineWeb-Edu pretraining |
| **[dLLM (Qwen3)](examples/diffusion_language_model)** | [run.sh](examples/diffusion_language_model/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Masked diffusion language model |
| **[RAE-SigLip](examples/representation_autoencoder)** | [run.sh](examples/representation_autoencoder/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Representation AutoEncoder, LPIPS, EMA |
| **[SiT](examples/scalable_interpolant_transformer)** | [run.sh](examples/scalable_interpolant_transformer/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Interpolant Transformer, CFG, ImageNet-1K |

**Optimization Legend:**
- **FSDP2**: Fully Sharded Data Parallel v2 for distributed training
- **TP**: Tensor Parallelism for sharding model compute across GPUs
- **USP**: Ulysses Sequence Parallel for long contexts
- **Muon**: Advanced optimizer with Newton-Schulz orthogonalization
- **Liger**: Triton fused kernels (CrossEntropy, RMSNorm, RoPE, SwiGLU) for 30% memory reduction
@@ -148,7 +149,7 @@ Production-grade efficiency from distributed training to kernel fusion.

- **Ulysses Sequence Parallel** - Splits sequence dimension across GPUs for ultra-long contexts. Critical for vision-language models like Qwen3-VL with 10K+ visual tokens.

- **Multi-dimensional Parallelism** - Compose TP x PP × DP meshes for cluster-scale training.
- **Multi-dimensional Parallelism** - Compose TP × Ulysses SP/CP × DP meshes for cluster-scale training.

### Memory & Compute Optimizations

38 changes: 36 additions & 2 deletions docs/models/qwenvl.md
@@ -17,7 +17,7 @@ Qwen-VL models are state-of-the-art multimodal models that support image and vid
- **Unique Feature**: DeepStack - multi-layer visual embeddings fused into early language model layers
- **Modalities**: Image and Video understanding (optimized for long videos)
- **Context Length**: 256K tokens (native), extendable to 1M tokens
- **Key Features**: DeepStack fusion, Interleaved 3D M-RoPE, Long video support (>1 hour), Flash Attention 2, Sequence Parallelism
- **Key Features**: DeepStack fusion, Interleaved 3D M-RoPE, Long video support (>1 hour), Flash Attention 2, Sequence Parallelism, text-decoder Tensor Parallelism

## Prerequisites

@@ -217,6 +217,7 @@ Create a YAML configuration file for your model.

# Optional: Sequence parallelism
sp_ulysses_degree: 1
tp_degree: 1 # Set to 2, 4, 8 for Qwen3-VL text decoder TP
```

## Key Configuration Parameters
@@ -274,6 +275,40 @@ trainer_args:
- `use_rmpad: true` recommended
- Number of attention heads must be divisible by `sp_ulysses_degree`

### Tensor Parallelism

Plain Qwen3-VL supports Tensor Parallelism (TP) for the text decoder. TP shards
the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and MLP
projections (`gate_proj`, `up_proj`, `down_proj`) across the TP mesh. The vision
tower remains FSDP2-sharded and can still use ViT frame parallelism when enabled.

```yaml
trainer_args:
  tp_degree: 2 # Text decoder tensor parallel degree
  sp_ulysses_degree: 1 # Can be >1 when head divisibility requirements are met
```
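
The column-wise and row-wise split described above can be illustrated without any distributed runtime. The following is a minimal sketch in plain Python, assuming a simplified two-layer MLP with a ReLU in place of the gated activation used by the real decoder; all function names are hypothetical and not part of `lmms_engine`. It shows that sharding the first projection by columns and the second by rows, then summing the per-rank partial outputs (the role an all-reduce plays in TP), reproduces the unsharded result.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix product for small integer matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m: Matrix) -> Matrix:
    """Element-wise ReLU (a simplification of the real gated activation)."""
    return [[max(v, 0) for v in row] for row in m]


def split_columns(m: Matrix, parts: int) -> List[Matrix]:
    """Column-wise shards, the layout used for up/gate/q/k/v-style projections."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    """Row-wise shards, the layout used for down/o-style projections."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]


def tp_mlp(x: Matrix, w_up: Matrix, w_down: Matrix, tp_degree: int) -> Matrix:
    """Simulate the MLP sharded over `tp_degree` ranks on a single process."""
    partials = []
    shards = zip(split_columns(w_up, tp_degree), split_rows(w_down, tp_degree))
    for up_shard, down_shard in shards:
        hidden = relu(matmul(x, up_shard))           # local to one rank
        partials.append(matmul(hidden, down_shard))  # partial sum on that rank
    # The element-wise sum across ranks stands in for the all-reduce.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, -2, 3], [0, 4, -1]]
w_up = [[1, 0, -1, 2], [2, 1, 0, -1], [-1, 3, 1, 0]]
w_down = [[1, 2], [0, -1], [3, 1], [-2, 0]]

reference = matmul(relu(matmul(x, w_up)), w_down)
sharded = tp_mlp(x, w_up, w_down, tp_degree=2)
print(sharded == reference)
```

Because the activation is element-wise, each rank can apply it to its own column shard without communication, which is why only one reduction is needed per MLP block in this layout.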

Launch with a Hydra override:

```bash
torchrun --nproc_per_node=8 -m lmms_engine.launch.cli \
--config-path examples/qwen3_vl \
--config-name example_config \
trainer_args.tp_degree=2
```

Requirements and notes:

- `world_size` must be divisible by `tp_degree * sp_ulysses_degree`.
- `hidden_size`, `intermediate_size`, `num_attention_heads`, and
`num_key_value_heads` must be divisible by `tp_degree`.
- When `sp_ulysses_degree > 1`, `num_attention_heads / tp_degree` must also be
divisible by `sp_ulysses_degree`.
- `ep_degree > 1` is not used for plain Qwen3-VL; use Qwen3-VL MoE configs for
expert parallelism.
- `lm_head`, vocabulary loss parallelism, and PyTorch DTensor SequenceParallel
are not enabled by this TP path.
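
The divisibility requirements listed above can be checked before launching a job. The sketch below is a standalone approximation of those rules, not the project's own validator: `TextConfig` and `tp_config_errors` are hypothetical names, and the configuration values are illustrative rather than taken from a real checkpoint.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TextConfig:
    """Hypothetical stand-in for the text-decoder config fields named above."""
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int


def tp_config_errors(
    cfg: TextConfig,
    world_size: int,
    tp_degree: int,
    sp_ulysses_degree: int = 1,
) -> List[str]:
    """Return a list of violated requirements (empty when the setup is valid)."""
    if tp_degree < 1:
        return [f"tp_degree must be >= 1, got {tp_degree}"]

    errors = []
    if world_size % (tp_degree * sp_ulysses_degree) != 0:
        errors.append(
            f"world_size ({world_size}) must be divisible by "
            f"tp_degree * sp_ulysses_degree ({tp_degree * sp_ulysses_degree})"
        )

    for name in ("hidden_size", "intermediate_size", "num_attention_heads", "num_key_value_heads"):
        value = getattr(cfg, name)
        if value % tp_degree != 0:
            errors.append(f"{name} ({value}) must be divisible by tp_degree ({tp_degree})")

    # Ulysses SP splits the heads that remain on each TP rank.
    local_heads = cfg.num_attention_heads // tp_degree
    if sp_ulysses_degree > 1 and local_heads % sp_ulysses_degree != 0:
        errors.append(
            f"num_attention_heads / tp_degree ({local_heads}) must be divisible by "
            f"sp_ulysses_degree ({sp_ulysses_degree})"
        )
    return errors


# Illustrative values only; substitute the fields of the model being trained.
cfg = TextConfig(hidden_size=4096, intermediate_size=12288, num_attention_heads=32, num_key_value_heads=8)
for problem in tp_config_errors(cfg, world_size=16, tp_degree=16):
    print(problem)
```

Collecting every violation instead of raising on the first one is a design choice for this sketch; it lets a single dry run report all constraints that a proposed `tp_degree` breaks.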

### Liger Kernel

[Liger Kernel](https://github.com/linkedin/Liger-Kernel) provides fused kernels for efficient training:
@@ -602,4 +637,3 @@ trainer_args:
### Community Resources
- [LMMS Engine GitHub](https://github.com/EvolvingLMMs-Lab/lmms-engine)
- [Qwen GitHub](https://github.com/QwenLM/Qwen2-VL)

1 change: 1 addition & 0 deletions examples/qwen3_vl/example_config.yaml
@@ -174,6 +174,7 @@ trainer_args:
use_rmpad: true
fsdp2: true
sp_ulysses_degree: 1
tp_degree: 1 # Set to 2, 4, 8 for Qwen3-VL text decoder Tensor Parallelism
reduce_dtype: bfloat16
output_dtype: bfloat16
print_batch_input_steps: 5
1 change: 1 addition & 0 deletions examples/qwen3_vl/vit_frame_parallel_sp.yaml
@@ -63,6 +63,7 @@ trainer_args:
reshard_after_forward: false
min_num_params: 0
sp_ulysses_degree: 2
tp_degree: 1 # Can be combined with SP; world_size must be divisible by sp_ulysses_degree * tp_degree
reduce_dtype: bfloat16
output_dtype: bfloat16
optim: adamw_torch_fused
20 changes: 17 additions & 3 deletions src/lmms_engine/launch/cli.py
@@ -55,11 +55,19 @@ def create_train_task(config):

trainer_args = config.get("trainer_args")
sp_degree = trainer_args.get("sp_ulysses_degree", 1)
tp_degree = trainer_args.get("tp_degree", 1)
ep_degree = trainer_args.get("ep_degree", 1)
if tp_degree < 1:
raise ValueError(f"tp_degree must be >= 1, got {tp_degree}")
if world_size % (sp_degree * tp_degree) != 0:
raise ValueError(
f"World size ({world_size}) must be divisible by "
f"sp_ulysses_degree ({sp_degree}) * tp_degree ({tp_degree})"
)
# DP size actually will not be affected by ep_degree, but kept for initialization here
dp_size = world_size // sp_degree
dp_size = world_size // (sp_degree * tp_degree)

# For now, we haven't implement the tp and pp
# For now, we haven't implemented pp.
use_cpu = trainer_args.get("use_cpu", False)
backend = "gloo" if use_cpu else "nccl"
# If the process group is already initialized, don't initialize it again
@@ -80,7 +88,13 @@ def create_train_task(config):
init_method=f"env://",
timeout=datetime.timedelta(seconds=ddp_timeout),
)
setup_process_group_manager(tp_size=1, cp_size=sp_degree, pp_size=1, dp_size=dp_size, ep_size=ep_degree)
setup_process_group_manager(
tp_size=tp_degree,
cp_size=sp_degree,
pp_size=1,
dp_size=dp_size,
ep_size=ep_degree,
)

trainer_args = config.pop("trainer_args")
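
The data-parallel size computed in the hunk above follows directly from the world size and the two parallel degrees. A minimal sketch of that arithmetic, assuming a hypothetical helper name `data_parallel_size` and an extra `sp_degree >= 1` guard that the diff itself does not include:

```python
def data_parallel_size(world_size: int, sp_degree: int = 1, tp_degree: int = 1) -> int:
    """Number of data-parallel replicas left after SP and TP groups are carved out."""
    if tp_degree < 1:
        raise ValueError(f"tp_degree must be >= 1, got {tp_degree}")
    if sp_degree < 1:  # assumption: guard added for this sketch only
        raise ValueError(f"sp_ulysses_degree must be >= 1, got {sp_degree}")
    group_size = sp_degree * tp_degree
    if world_size % group_size != 0:
        raise ValueError(
            f"World size ({world_size}) must be divisible by "
            f"sp_ulysses_degree ({sp_degree}) * tp_degree ({tp_degree})"
        )
    return world_size // group_size


print(data_parallel_size(8, sp_degree=1, tp_degree=2))  # 4
print(data_parallel_size(8, sp_degree=2, tp_degree=2))  # 2
```

With 8 GPUs, `tp_degree=2` alone leaves 4 data-parallel replicas, and adding `sp_ulysses_degree=2` leaves 2. A `tp_degree` of 3 on 8 GPUs is rejected because 8 is not divisible by 3.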

1 change: 1 addition & 0 deletions src/lmms_engine/launch/config/default_config.yaml
@@ -184,6 +184,7 @@ trainer_args:
profiler_config: null
ep_degree: 1
sp_ulysses_degree: 1
tp_degree: 1
ema_enabled: false
ema_decay: 0.9999
ema_update_every: 1
2 changes: 2 additions & 0 deletions src/lmms_engine/parallel/parallelize.py
@@ -6,13 +6,15 @@
from .qwen3_5_moe.parallelize import apply_qwen3_5_moe_parallelize_fn
from .qwen3_moe.parallelize import apply_qwen3_moe_parallelize_fn
from .qwen3_omni_moe.parallelize import apply_qwen3_omni_moe_parallelize_fn
from .qwen3_vl.parallelize import apply_qwen3_vl_parallelize_fn
from .qwen3_vl_moe.parallelize import apply_qwen3_vl_moe_parallelize_fn

MODEL_TO_PARALLEL_METHOD = {
"qwen3_moe": apply_qwen3_moe_parallelize_fn,
"qwen3_5_moe": apply_qwen3_5_moe_parallelize_fn,
"qwen3_omni_moe": apply_qwen3_omni_moe_parallelize_fn,
"qwen3_omni_moe_thinker": apply_qwen3_omni_moe_parallelize_fn,
"qwen3_vl": apply_qwen3_vl_parallelize_fn,
"qwen3_vl_moe": apply_qwen3_vl_moe_parallelize_fn,
}

11 changes: 11 additions & 0 deletions src/lmms_engine/parallel/qwen3_vl/__init__.py
@@ -0,0 +1,11 @@
from .parallelize import (
    apply_qwen3_vl_fsdp2,
    apply_qwen3_vl_parallel,
    apply_qwen3_vl_parallelize_fn,
)

__all__ = [
    "apply_qwen3_vl_fsdp2",
    "apply_qwen3_vl_parallel",
    "apply_qwen3_vl_parallelize_fn",
]
125 changes: 125 additions & 0 deletions src/lmms_engine/parallel/qwen3_vl/parallelize.py
@@ -0,0 +1,125 @@
from typing import TYPE_CHECKING

import torch
from loguru import logger
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLForConditionalGeneration,
)

import lmms_engine.parallel.process_group_manager as pgm
from lmms_engine.utils.fsdp2_utils import fsdp2_load_full_state_dict

if TYPE_CHECKING:
    from lmms_engine.train.config import TrainingArguments


def _check_divisible(name: str, value: int, degree: int) -> None:
    if value % degree != 0:
        raise ValueError(f"{name} ({value}) must be divisible by tp_degree ({degree})")


def _validate_qwen3_vl_tp_config(model: Qwen3VLForConditionalGeneration, train_args: "TrainingArguments") -> None:
    tp_degree = pgm.process_group_manager.tp_world_size
    sp_degree = pgm.process_group_manager.cp_world_size

    if tp_degree < 1:
        raise ValueError(f"tp_degree must be >= 1, got {tp_degree}")
    if train_args.ep_degree > 1:
        raise ValueError("ep_degree > 1 is not supported for plain qwen3_vl")
    if tp_degree == 1:
        return

    text_config = model.config.text_config
    _check_divisible("hidden_size", text_config.hidden_size, tp_degree)
    _check_divisible("intermediate_size", text_config.intermediate_size, tp_degree)
    _check_divisible("num_attention_heads", text_config.num_attention_heads, tp_degree)
    _check_divisible("num_key_value_heads", text_config.num_key_value_heads, tp_degree)

    local_attention_heads = text_config.num_attention_heads // tp_degree
    if sp_degree > 1 and local_attention_heads % sp_degree != 0:
        raise ValueError(
            f"num_attention_heads / tp_degree ({local_attention_heads}) must be divisible by "
            f"sp_ulysses_degree ({sp_degree})"
        )


def apply_qwen3_vl_parallel(
    model: Qwen3VLForConditionalGeneration,
    tp_mesh: DeviceMesh,
    **kwargs,
) -> None:
    tp_plan = {
        "self_attn.q_proj": ColwiseParallel(use_local_output=True),
        "self_attn.k_proj": ColwiseParallel(use_local_output=True),
        "self_attn.v_proj": ColwiseParallel(use_local_output=True),
        "self_attn.o_proj": RowwiseParallel(use_local_output=True),
        "mlp.gate_proj": ColwiseParallel(use_local_output=True),
        "mlp.up_proj": ColwiseParallel(use_local_output=True),
        "mlp.down_proj": RowwiseParallel(use_local_output=True),
    }

    for decoder_layer in model.model.language_model.layers:
        parallelize_module(decoder_layer, device_mesh=tp_mesh, parallelize_plan=tp_plan)

    logger.info(f"Applied Qwen3-VL text TP to {len(model.model.language_model.layers)} decoder layers")


def apply_qwen3_vl_fsdp2(
    model: Qwen3VLForConditionalGeneration,
    train_args: "TrainingArguments",
    **kwargs,
) -> None:
    # Warn only when the option is actually set, since modules are wrapped explicitly below.
    if train_args.fsdp_config.get("transformer_layer_cls_to_wrap", None):
        logger.warning("transformer_layer_cls_to_wrap ignored; qwen3_vl wraps modules explicitly.")

    param_dtype = torch.bfloat16 if train_args.bf16 else torch.float16

    if train_args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    reduce_dtype = getattr(torch, train_args.reduce_dtype)
    output_dtype = getattr(torch, train_args.output_dtype)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=output_dtype,
    )

    fsdp_kwargs = {
        "reshard_after_forward": getattr(train_args, "fsdp_config", {}).get("reshard_after_forward", True),
        "mp_policy": mp_policy,
        "mesh": pgm.process_group_manager.device_mesh["fsdp"],
    }

    # Shard the vision tower as one unit when the model has one.
    if hasattr(model.model, "visual") and model.model.visual is not None:
        fully_shard(model.model.visual, **fsdp_kwargs)

    # Shard attention and MLP separately within each decoder layer.
    for decoder_layer in model.model.language_model.layers:
        fully_shard(decoder_layer.self_attn, **fsdp_kwargs)
        fully_shard(decoder_layer.mlp, **fsdp_kwargs)

    fully_shard(model.model.language_model.embed_tokens, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)


def apply_qwen3_vl_parallelize_fn(
    model: Qwen3VLForConditionalGeneration,
    train_args: "TrainingArguments",
    **kwargs,
) -> None:
    _validate_qwen3_vl_tp_config(model, train_args)

    full_state_dict = model.state_dict()
    if pgm.process_group_manager.tp_world_size > 1:
        tp_mesh = pgm.process_group_manager.device_mesh["tp"]
        apply_qwen3_vl_parallel(model, tp_mesh=tp_mesh, **kwargs)

    apply_qwen3_vl_fsdp2(model, train_args, **kwargs)
    fsdp2_load_full_state_dict(model, full_state_dict)
1 change: 1 addition & 0 deletions src/lmms_engine/train/config.py
@@ -27,6 +27,7 @@ class TrainingArguments(transformers.TrainingArguments):
# Parallelism
ep_degree: Optional[int] = 1
sp_ulysses_degree: Optional[int] = 1
tp_degree: Optional[int] = 1

# --- EMA (Exponential Moving Average) ---
ema_enabled: Optional[bool] = False
16 changes: 9 additions & 7 deletions src/lmms_engine/train/fsdp2/fsdp2_trainer.py
@@ -392,11 +392,13 @@ def train(self, resume_from_checkpoint: bool = False):
device = self.fsdp2_model.device
flops_tensor = torch.tensor(flops, device=device)
sp_size = pgm.process_group_manager.cp_world_size
tp_size = pgm.process_group_manager.tp_world_size
parallel_size = sp_size * tp_size

# Calculate training metrics (MFU, token stats, throughput)
perf_metrics, self.total_tokens = self.calculate_training_metrics(
flops_tensor=flops_tensor,
sp_size=sp_size,
parallel_size=parallel_size,
promised_flops=promised_flops,
device=device,
seq_len=seq_len,
@@ -633,7 +635,7 @@ def print_batch_input(self, batch):
@staticmethod
def calculate_training_metrics(
flops_tensor: torch.Tensor,
sp_size: int,
parallel_size: int,
promised_flops: float,
device: torch.device,
seq_len: list,
@@ -646,7 +648,7 @@ def calculate_training_metrics(

Args:
flops_tensor: Tensor containing FLOPs count
sp_size: Sequence parallel size
parallel_size: Product of sequence and tensor parallel sizes
promised_flops: Promised FLOPs capacity
device: Device to perform computations on
seq_len: List of sequence lengths per batch
@@ -660,16 +662,16 @@ def calculate_training_metrics(
metrics = {}

# Calculate mfu per rank
# Divide by sp size because attention mask we use to calculate are unsplitted
mfu = flops_tensor.item() / sp_size / promised_flops
# Divide by parallel size because seq_len/flops are estimated before SP/TP sharding.
mfu = flops_tensor.item() / parallel_size / promised_flops
mfu = torch.tensor(mfu, device=device)
torch.distributed.all_reduce(mfu, op=torch.distributed.ReduceOp.AVG)
mfu = mfu.item()

# Calculating token stats
seq_len = torch.tensor(seq_len, device=device, dtype=torch.float32)
# Divide total seq len by sp size if sp is enabled since we split the seq len
total_seq_len = seq_len.sum() / sp_size
# Divide by parallel size to avoid counting replicated SP/TP batches multiple times.
total_seq_len = seq_len.sum() / parallel_size
torch.distributed.all_reduce(total_seq_len, op=torch.distributed.ReduceOp.SUM)
# Avg seq len won't be affected by sp since we perform all reduce
# across world size
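
The normalization in this hunk can be worked through with small numbers. The sketch below is a standalone illustration, not the trainer's code: the helper names are hypothetical, `achieved_flops` is assumed to be in the same unit as `promised_flops`, and the values are made up. It shows why dividing by `sp_size * tp_size` keeps a batch that is replicated across an SP × TP group from being counted more than once when the per-rank values are summed.

```python
from typing import Sequence


def per_rank_mfu(achieved_flops: float, parallel_size: int, promised_flops: float) -> float:
    """MFU for one rank when FLOPs were estimated on the unsharded batch."""
    return achieved_flops / parallel_size / promised_flops


def per_rank_token_share(seq_len: Sequence[int], parallel_size: int) -> float:
    """Tokens attributed to one rank before the SUM all-reduce."""
    return sum(seq_len) / parallel_size


sp_size, tp_size = 2, 2
parallel_size = sp_size * tp_size

seq_len = [4096, 2048]
share = per_rank_token_share(seq_len, parallel_size)

# The four ranks of one SP x TP group all see the same batch, so summing
# their shares (as a SUM all-reduce would) recovers the token count once.
group_total = sum(share for _ in range(parallel_size))

mfu = per_rank_mfu(achieved_flops=8e14, parallel_size=parallel_size, promised_flops=1e15)
print(share, group_total, mfu)
```

Here each rank contributes 1536 tokens, the group total comes back to the 6144 tokens in the batch, and the per-rank MFU is 0.2 rather than the 0.8 that the unsharded estimate would suggest.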