diff --git a/README.md b/README.md index e1f3a630..401393f6 100644 --- a/README.md +++ b/README.md @@ -86,24 +86,25 @@ python -m lmms_engine.launch.cli config_yaml=examples/qwen3_vl/example_config.ya ## 🔥 Featured Examples -| Model | Quick Start | FSDP2 | USP | Muon | Liger | Packing | NSA | EP | Highlights | -|-------|-------------|-------|-----|------|-------|---------|-----|----|------------------| -| **[BAGEL](src/lmms_engine/models/bagel)** | [run.sh](examples/bagel/run.sh) | ✅ | TBD | ✅ | ❌ | ✅ | ✅ | ❌ | Unified visual understanding & generation | -| **[Qwen2.5](src/lmms_engine/models/qwen2)** | [run.sh](examples/qwen2_5_llm/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Large Language Model | -| **[Qwen2.5-VL](src/lmms_engine/models/qwen2_5_vl/)** | [run.sh](examples/qwen2_5_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Multimodal Model | -| **[Qwen2.5-Omni](examples/qwen2_5_omni)** | [run.sh](examples/qwen2_5_omni/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Unified multimodal (image, audio, text) | -| **[Qwen3-VL](examples/qwen3_vl)** | [run.sh](examples/qwen3_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Native-resolution, long context (10K+ tokens) | -| **[Qwen3-VL MoE](examples/qwen3_vl_moe)** | [run.sh](examples/qwen3_vl_moe/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | Vision-Language MoE with EP (image, video, text) | -| **[Qwen3-MoE](examples/qwen3_moe)** | [run.sh](examples/qwen3_moe/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Mixture-of-Experts, Expert Parallelism | -| **[Qwen3-Omni MoE](examples/qwen3_omni_moe)** | [config](examples/qwen3_omni_moe_ep2.yaml) | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Multimodal MoE with EP (image, audio, text) | -| **[WanVideo](examples/wanvideo)** | [run.sh](examples/wanvideo/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | T2V/I2V/V2V generation (1.3B/14B) | -| **[FLA models](examples/dgn)** | [run.sh](examples/dgn/run.sh) | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | Efficient architecture, FineWeb-Edu pretraining | -| **[dLLM (Qwen3)](examples/diffusion_language_model)** | 
[run.sh](examples/diffusion_language_model/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Masked diffusion language model | -| **[RAE-SigLip](examples/representation_autoencoder)** | [run.sh](examples/representation_autoencoder/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Representation AutoEncoder, LPIPS, EMA | -| **[SiT](examples/scalable_interpolant_transformer)** | [run.sh](examples/scalable_interpolant_transformer/run.sh) | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Interpolant Transformer, CFG, ImageNet-1K | +| Model | Quick Start | FSDP2 | TP | USP | Muon | Liger | Packing | NSA | EP | Highlights | +|-------|-------------|-------|----|-----|------|-------|---------|-----|----|------------------| +| **[BAGEL](src/lmms_engine/models/bagel)** | [run.sh](examples/bagel/run.sh) | ✅ | ❌ | TBD | ✅ | ❌ | ✅ | ✅ | ❌ | Unified visual understanding & generation | +| **[Qwen2.5](src/lmms_engine/models/qwen2)** | [run.sh](examples/qwen2_5_llm/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Large Language Model | +| **[Qwen2.5-VL](src/lmms_engine/models/qwen2_5_vl/)** | [run.sh](examples/qwen2_5_vl/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Multimodal Model | +| **[Qwen2.5-Omni](examples/qwen2_5_omni)** | [run.sh](examples/qwen2_5_omni/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Unified multimodal (image, audio, text) | +| **[Qwen3-VL](examples/qwen3_vl)** | [run.sh](examples/qwen3_vl/run.sh) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | Native-resolution, long context (10K+ tokens) | +| **[Qwen3-VL MoE](examples/qwen3_vl_moe)** | [run.sh](examples/qwen3_vl_moe/run.sh) | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | Vision-Language MoE with EP (image, video, text) | +| **[Qwen3-MoE](examples/qwen3_moe)** | [run.sh](examples/qwen3_moe/run.sh) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Mixture-of-Experts, Expert Parallelism | +| **[Qwen3-Omni MoE](examples/qwen3_omni_moe)** | [config](examples/qwen3_omni_moe_ep2.yaml) | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | Multimodal MoE with EP (image, audio, text) | +| **[WanVideo](examples/wanvideo)** | 
[run.sh](examples/wanvideo/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | T2V/I2V/V2V generation (1.3B/14B) | +| **[FLA models](examples/dgn)** | [run.sh](examples/dgn/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | Efficient architecture, FineWeb-Edu pretraining | +| **[dLLM (Qwen3)](examples/diffusion_language_model)** | [run.sh](examples/diffusion_language_model/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Masked diffusion language model | +| **[RAE-SigLip](examples/representation_autoencoder)** | [run.sh](examples/representation_autoencoder/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Representation AutoEncoder, LPIPS, EMA | +| **[SiT](examples/scalable_interpolant_transformer)** | [run.sh](examples/scalable_interpolant_transformer/run.sh) | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | Interpolant Transformer, CFG, ImageNet-1K | **Optimization Legend:** - **FSDP2**: Fully Sharded Data Parallel v2 for distributed training +- **TP**: Tensor Parallelism for sharding model compute across GPUs - **USP**: Ulysses Sequence Parallel for long contexts - **Muon**: Advanced optimizer with Newton-Schulz orthogonalization - **Liger**: Triton fused kernels (CrossEntropy, RMSNorm, RoPE, SwiGLU) for 30% memory reduction @@ -148,7 +149,7 @@ Production-grade efficiency from distributed training to kernel fusion. - **Ulysses Sequence Parallel** - Splits sequence dimension across GPUs for ultra-long contexts. Critical for vision-language models like Qwen3-VL with 10K+ visual tokens. -- **Multi-dimensional Parallelism** - Compose TP x PP × DP meshes for cluster-scale training. +- **Multi-dimensional Parallelism** - Compose TP × Ulysses SP/CP × DP meshes for cluster-scale training. 
### Memory & Compute Optimizations diff --git a/docs/models/qwenvl.md b/docs/models/qwenvl.md index 240fcb66..b81bf558 100644 --- a/docs/models/qwenvl.md +++ b/docs/models/qwenvl.md @@ -17,7 +17,7 @@ Qwen-VL models are state-of-the-art multimodal models that support image and vid - **Unique Feature**: DeepStack - multi-layer visual embeddings fused into early language model layers - **Modalities**: Image and Video understanding (optimized for long videos) - **Context Length**: 256K tokens (native), extendable to 1M tokens -- **Key Features**: DeepStack fusion, Interleaved 3D M-RoPE, Long video support (>1 hour), Flash Attention 2, Sequence Parallelism +- **Key Features**: DeepStack fusion, Interleaved 3D M-RoPE, Long video support (>1 hour), Flash Attention 2, Sequence Parallelism, text-decoder Tensor Parallelism ## Prerequisites @@ -217,6 +217,7 @@ Create a YAML configuration file for your model. # Optional: Sequence parallelism sp_ulysses_degree: 1 + tp_degree: 1 # Set to 2, 4, 8 for Qwen3-VL text decoder TP ``` ## Key Configuration Parameters @@ -274,6 +275,40 @@ trainer_args: - `use_rmpad: true` recommended - Number of attention heads must be divisible by `sp_ulysses_degree` +### Tensor Parallelism + +Plain Qwen3-VL supports Tensor Parallelism (TP) for the text decoder. TP shards +the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and MLP +projections (`gate_proj`, `up_proj`, `down_proj`) across the TP mesh. The vision +tower remains FSDP2-sharded and can still use ViT frame parallelism when enabled. 
+ +```yaml +trainer_args: + tp_degree: 2 # Text decoder tensor parallel degree + sp_ulysses_degree: 1 # Can be >1 when head divisibility requirements are met +``` + +Launch with a Hydra override: + +```bash +torchrun --nproc_per_node=8 -m lmms_engine.launch.cli \ + --config-path examples/qwen3_vl \ + --config-name example_config \ + trainer_args.tp_degree=2 +``` + +Requirements and notes: + +- `world_size` must be divisible by `tp_degree * sp_ulysses_degree`. +- `hidden_size`, `intermediate_size`, `num_attention_heads`, and + `num_key_value_heads` must be divisible by `tp_degree`. +- When `sp_ulysses_degree > 1`, `num_attention_heads / tp_degree` must also be + divisible by `sp_ulysses_degree`. +- `ep_degree > 1` is not used for plain Qwen3-VL; use Qwen3-VL MoE configs for + expert parallelism. +- `lm_head`, vocabulary loss parallelism, and PyTorch DTensor SequenceParallel + are not enabled by this TP path. + ### Liger Kernel [Liger Kernel](https://github.com/linkedin/Liger-Kernel) provides fused kernels for efficient training: @@ -602,4 +637,3 @@ trainer_args: ### Community Resources - [LMMS Engine GitHub](https://github.com/EvolvingLMMs-Lab/lmms-engine) - [Qwen GitHub](https://github.com/QwenLM/Qwen2-VL) - diff --git a/examples/qwen3_vl/example_config.yaml b/examples/qwen3_vl/example_config.yaml index 7c31dc3d..5a49269f 100644 --- a/examples/qwen3_vl/example_config.yaml +++ b/examples/qwen3_vl/example_config.yaml @@ -174,6 +174,7 @@ trainer_args: use_rmpad: true fsdp2: true sp_ulysses_degree: 1 + tp_degree: 1 # Set to 2, 4, 8 for Qwen3-VL text decoder Tensor Parallelism reduce_dtype: bfloat16 output_dtype: bfloat16 print_batch_input_steps: 5 diff --git a/examples/qwen3_vl/vit_frame_parallel_sp.yaml b/examples/qwen3_vl/vit_frame_parallel_sp.yaml index 86f506dd..088b2986 100644 --- a/examples/qwen3_vl/vit_frame_parallel_sp.yaml +++ b/examples/qwen3_vl/vit_frame_parallel_sp.yaml @@ -63,6 +63,7 @@ trainer_args: reshard_after_forward: false min_num_params: 0 
sp_ulysses_degree: 2 + tp_degree: 1 # Can be combined with SP; world_size must be divisible by sp_ulysses_degree * tp_degree reduce_dtype: bfloat16 output_dtype: bfloat16 optim: adamw_torch_fused diff --git a/src/lmms_engine/launch/cli.py b/src/lmms_engine/launch/cli.py index f4427610..b195c30d 100644 --- a/src/lmms_engine/launch/cli.py +++ b/src/lmms_engine/launch/cli.py @@ -55,11 +55,19 @@ def create_train_task(config): trainer_args = config.get("trainer_args") sp_degree = trainer_args.get("sp_ulysses_degree", 1) + tp_degree = trainer_args.get("tp_degree", 1) ep_degree = trainer_args.get("ep_degree", 1) + if tp_degree < 1: + raise ValueError(f"tp_degree must be >= 1, got {tp_degree}") + if world_size % (sp_degree * tp_degree) != 0: + raise ValueError( + f"World size ({world_size}) must be divisible by " + f"sp_ulysses_degree ({sp_degree}) * tp_degree ({tp_degree})" + ) # DP size actually will not be affected by ep_degree, but kept for initialization here - dp_size = world_size // sp_degree + dp_size = world_size // (sp_degree * tp_degree) - # For now, we haven't implement the tp and pp + # For now, we haven't implemented pp. 
use_cpu = trainer_args.get("use_cpu", False) backend = "gloo" if use_cpu else "nccl" # If the process group is already initialized, don't initialize it again @@ -80,7 +88,13 @@ def create_train_task(config): init_method=f"env://", timeout=datetime.timedelta(seconds=ddp_timeout), ) - setup_process_group_manager(tp_size=1, cp_size=sp_degree, pp_size=1, dp_size=dp_size, ep_size=ep_degree) + setup_process_group_manager( + tp_size=tp_degree, + cp_size=sp_degree, + pp_size=1, + dp_size=dp_size, + ep_size=ep_degree, + ) trainer_args = config.pop("trainer_args") diff --git a/src/lmms_engine/launch/config/default_config.yaml b/src/lmms_engine/launch/config/default_config.yaml index 2f88306e..58e210ae 100644 --- a/src/lmms_engine/launch/config/default_config.yaml +++ b/src/lmms_engine/launch/config/default_config.yaml @@ -184,6 +184,7 @@ trainer_args: profiler_config: null ep_degree: 1 sp_ulysses_degree: 1 + tp_degree: 1 ema_enabled: false ema_decay: 0.9999 ema_update_every: 1 diff --git a/src/lmms_engine/parallel/parallelize.py b/src/lmms_engine/parallel/parallelize.py index a3136bb1..2affb444 100644 --- a/src/lmms_engine/parallel/parallelize.py +++ b/src/lmms_engine/parallel/parallelize.py @@ -6,6 +6,7 @@ from .qwen3_5_moe.parallelize import apply_qwen3_5_moe_parallelize_fn from .qwen3_moe.parallelize import apply_qwen3_moe_parallelize_fn from .qwen3_omni_moe.parallelize import apply_qwen3_omni_moe_parallelize_fn +from .qwen3_vl.parallelize import apply_qwen3_vl_parallelize_fn from .qwen3_vl_moe.parallelize import apply_qwen3_vl_moe_parallelize_fn MODEL_TO_PARALLEL_METHOD = { @@ -13,6 +14,7 @@ "qwen3_5_moe": apply_qwen3_5_moe_parallelize_fn, "qwen3_omni_moe": apply_qwen3_omni_moe_parallelize_fn, "qwen3_omni_moe_thinker": apply_qwen3_omni_moe_parallelize_fn, + "qwen3_vl": apply_qwen3_vl_parallelize_fn, "qwen3_vl_moe": apply_qwen3_vl_moe_parallelize_fn, } diff --git a/src/lmms_engine/parallel/qwen3_vl/__init__.py b/src/lmms_engine/parallel/qwen3_vl/__init__.py new file mode 
100644 index 00000000..52e7d1e3 --- /dev/null +++ b/src/lmms_engine/parallel/qwen3_vl/__init__.py @@ -0,0 +1,11 @@ +from .parallelize import ( + apply_qwen3_vl_fsdp2, + apply_qwen3_vl_parallel, + apply_qwen3_vl_parallelize_fn, +) + +__all__ = [ + "apply_qwen3_vl_fsdp2", + "apply_qwen3_vl_parallel", + "apply_qwen3_vl_parallelize_fn", +] diff --git a/src/lmms_engine/parallel/qwen3_vl/parallelize.py b/src/lmms_engine/parallel/qwen3_vl/parallelize.py new file mode 100644 index 00000000..aeaa685a --- /dev/null +++ b/src/lmms_engine/parallel/qwen3_vl/parallelize.py @@ -0,0 +1,125 @@ +from typing import TYPE_CHECKING + +import torch +from loguru import logger +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) +from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, +) + +import lmms_engine.parallel.process_group_manager as pgm +from lmms_engine.utils.fsdp2_utils import fsdp2_load_full_state_dict + +if TYPE_CHECKING: + from lmms_engine.train.config import TrainingArguments + + +def _check_divisible(name: str, value: int, degree: int) -> None: + if value % degree != 0: + raise ValueError(f"{name} ({value}) must be divisible by tp_degree ({degree})") + + +def _validate_qwen3_vl_tp_config(model: Qwen3VLForConditionalGeneration, train_args: "TrainingArguments") -> None: + tp_degree = pgm.process_group_manager.tp_world_size + sp_degree = pgm.process_group_manager.cp_world_size + + if tp_degree < 1: + raise ValueError(f"tp_degree must be >= 1, got {tp_degree}") + if train_args.ep_degree > 1: + raise ValueError("ep_degree > 1 is not supported for plain qwen3_vl") + if tp_degree == 1: + return + + text_config = model.config.text_config + _check_divisible("hidden_size", text_config.hidden_size, tp_degree) + _check_divisible("intermediate_size", 
text_config.intermediate_size, tp_degree) + _check_divisible("num_attention_heads", text_config.num_attention_heads, tp_degree) + _check_divisible("num_key_value_heads", text_config.num_key_value_heads, tp_degree) + + local_attention_heads = text_config.num_attention_heads // tp_degree + if sp_degree > 1 and local_attention_heads % sp_degree != 0: + raise ValueError( + f"num_attention_heads / tp_degree ({local_attention_heads}) must be divisible by " + f"sp_ulysses_degree ({sp_degree})" + ) + + +def apply_qwen3_vl_parallel( + model: Qwen3VLForConditionalGeneration, + tp_mesh: DeviceMesh, + **kwargs, +) -> None: + tp_plan = { + "self_attn.q_proj": ColwiseParallel(use_local_output=True), + "self_attn.k_proj": ColwiseParallel(use_local_output=True), + "self_attn.v_proj": ColwiseParallel(use_local_output=True), + "self_attn.o_proj": RowwiseParallel(use_local_output=True), + "mlp.gate_proj": ColwiseParallel(use_local_output=True), + "mlp.up_proj": ColwiseParallel(use_local_output=True), + "mlp.down_proj": RowwiseParallel(use_local_output=True), + } + + for decoder_layer in model.model.language_model.layers: + parallelize_module(decoder_layer, device_mesh=tp_mesh, parallelize_plan=tp_plan) + + logger.info(f"Applied Qwen3-VL text TP to {len(model.model.language_model.layers)} decoder layers") + + +def apply_qwen3_vl_fsdp2( + model: Qwen3VLForConditionalGeneration, + train_args: "TrainingArguments", + **kwargs, +) -> None: + if train_args.fsdp_config.get("transformer_layer_cls_to_wrap", None): + logger.warning("transformer_layer_cls_to_wrap ignored; qwen3_vl wraps modules explicitly.") + + param_dtype = torch.bfloat16 if train_args.bf16 else torch.float16 + + if train_args.gradient_checkpointing: + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + + reduce_dtype = getattr(torch, train_args.reduce_dtype) + output_dtype = getattr(torch, train_args.output_dtype) + mp_policy = MixedPrecisionPolicy( + param_dtype=param_dtype,
reduce_dtype=reduce_dtype, + output_dtype=output_dtype, + ) + + fsdp_kwargs = { + "reshard_after_forward": getattr(train_args, "fsdp_config", {}).get("reshard_after_forward", True), + "mp_policy": mp_policy, + "mesh": pgm.process_group_manager.device_mesh["fsdp"], + } + + if hasattr(model.model, "visual") and model.model.visual is not None: + fully_shard(model.model.visual, **fsdp_kwargs) + + for decoder_layer in model.model.language_model.layers: + fully_shard(decoder_layer.self_attn, **fsdp_kwargs) + fully_shard(decoder_layer.mlp, **fsdp_kwargs) + + fully_shard(model.model.language_model.embed_tokens, **fsdp_kwargs) + fully_shard(model, **fsdp_kwargs) + + +def apply_qwen3_vl_parallelize_fn( + model: Qwen3VLForConditionalGeneration, + train_args: "TrainingArguments", + **kwargs, +) -> None: + _validate_qwen3_vl_tp_config(model, train_args) + + full_state_dict = model.state_dict() + if pgm.process_group_manager.tp_world_size > 1: + tp_mesh = pgm.process_group_manager.device_mesh["tp"] + apply_qwen3_vl_parallel(model, tp_mesh=tp_mesh, **kwargs) + + apply_qwen3_vl_fsdp2(model, train_args, **kwargs) + fsdp2_load_full_state_dict(model, full_state_dict) diff --git a/src/lmms_engine/train/config.py b/src/lmms_engine/train/config.py index d9245aec..c4448fed 100644 --- a/src/lmms_engine/train/config.py +++ b/src/lmms_engine/train/config.py @@ -27,6 +27,7 @@ class TrainingArguments(transformers.TrainingArguments): # Parallelism ep_degree: Optional[int] = 1 sp_ulysses_degree: Optional[int] = 1 + tp_degree: Optional[int] = 1 # --- EMA (Exponential Moving Average) --- ema_enabled: Optional[bool] = False diff --git a/src/lmms_engine/train/fsdp2/fsdp2_trainer.py b/src/lmms_engine/train/fsdp2/fsdp2_trainer.py index be709fc6..31d7b19e 100644 --- a/src/lmms_engine/train/fsdp2/fsdp2_trainer.py +++ b/src/lmms_engine/train/fsdp2/fsdp2_trainer.py @@ -392,11 +392,13 @@ def train(self, resume_from_checkpoint: bool = False): device = self.fsdp2_model.device flops_tensor = 
torch.tensor(flops, device=device) sp_size = pgm.process_group_manager.cp_world_size + tp_size = pgm.process_group_manager.tp_world_size + parallel_size = sp_size * tp_size # Calculate training metrics (MFU, token stats, throughput) perf_metrics, self.total_tokens = self.calculate_training_metrics( flops_tensor=flops_tensor, - sp_size=sp_size, + parallel_size=parallel_size, promised_flops=promised_flops, device=device, seq_len=seq_len, @@ -633,7 +635,7 @@ def print_batch_input(self, batch): @staticmethod def calculate_training_metrics( flops_tensor: torch.Tensor, - sp_size: int, + parallel_size: int, promised_flops: float, device: torch.device, seq_len: list, @@ -646,7 +648,7 @@ def calculate_training_metrics( Args: flops_tensor: Tensor containing FLOPs count - sp_size: Sequence parallel size + parallel_size: Product of sequence and tensor parallel sizes promised_flops: Promised FLOPs capacity device: Device to perform computations on seq_len: List of sequence lengths per batch @@ -660,16 +662,16 @@ def calculate_training_metrics( metrics = {} # Calculate mfu per rank - # Divide by sp size because attention mask we use to calculate are unsplitted - mfu = flops_tensor.item() / sp_size / promised_flops + # Divide by parallel size because seq_len/flops are estimated before SP/TP sharding. + mfu = flops_tensor.item() / parallel_size / promised_flops mfu = torch.tensor(mfu, device=device) torch.distributed.all_reduce(mfu, op=torch.distributed.ReduceOp.AVG) mfu = mfu.item() # Calculating token stats seq_len = torch.tensor(seq_len, device=device, dtype=torch.float32) - # Divide total seq len by sp size if sp is enabled since we split the seq len - total_seq_len = seq_len.sum() / sp_size + # Divide by parallel size to avoid counting replicated SP/TP batches multiple times. 
+ total_seq_len = seq_len.sum() / parallel_size torch.distributed.all_reduce(total_seq_len, op=torch.distributed.ReduceOp.SUM) # Avg seq len won't be affected by sp since we perform all reduce # across world size
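
Reviewer note: the degree checks introduced in `cli.py` and `_validate_qwen3_vl_tp_config` can be exercised without a process group. The sketch below restates them as plain functions so the arithmetic is easy to test in isolation. `TextConfig`, `resolve_dp_size`, and `check_tp_shapes` are hypothetical names used only for illustration and are not part of this change, and the config values in the usage lines are made up rather than taken from a real Qwen3-VL checkpoint.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TextConfig:
    """Hypothetical stand-in for the fields read from `model.config.text_config`."""

    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int


def resolve_dp_size(world_size: int, tp_degree: int = 1, sp_degree: int = 1) -> int:
    """Return the data-parallel size left over after TP and Ulysses SP take their share."""
    if tp_degree < 1:
        raise ValueError(f"tp_degree must be >= 1, got {tp_degree}")
    group = sp_degree * tp_degree
    if world_size % group != 0:
        raise ValueError(
            f"World size ({world_size}) must be divisible by "
            f"sp_ulysses_degree ({sp_degree}) * tp_degree ({tp_degree})"
        )
    return world_size // group


def check_tp_shapes(cfg: TextConfig, tp_degree: int, sp_degree: int = 1) -> None:
    """Raise ValueError if the text decoder cannot be sharded evenly across the TP mesh."""
    if tp_degree == 1:
        return  # Nothing is sharded, so no divisibility constraints apply.
    for name in (
        "hidden_size",
        "intermediate_size",
        "num_attention_heads",
        "num_key_value_heads",
    ):
        value = getattr(cfg, name)
        if value % tp_degree != 0:
            raise ValueError(f"{name} ({value}) must be divisible by tp_degree ({tp_degree})")
    # Ulysses SP splits the heads that remain on each TP rank.
    local_heads = cfg.num_attention_heads // tp_degree
    if sp_degree > 1 and local_heads % sp_degree != 0:
        raise ValueError(
            f"num_attention_heads / tp_degree ({local_heads}) must be divisible by "
            f"sp_ulysses_degree ({sp_degree})"
        )


# Illustrative values only.
cfg = TextConfig(
    hidden_size=4096,
    intermediate_size=12288,
    num_attention_heads=32,
    num_key_value_heads=8,
)
dp_size = resolve_dp_size(world_size=8, tp_degree=2, sp_degree=2)
check_tp_shapes(cfg, tp_degree=2, sp_degree=2)
print(f"dp_size={dp_size}")
```

With these illustrative values, `tp_degree=16` would be rejected because `num_key_value_heads` is 8, which is why the key/value head count is usually the tightest bound on the TP degree for grouped-query attention models.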