


@KLGR123 KLGR123 commented Oct 15, 2025

📝 Description

This PR adds Multi-Token Prediction (MTP) support to the TRL library, based on Meta AI's paper Better & Faster Large Language Models via Multi-token Prediction (Gloeckle et al., 2024) and related follow-up work.

MTP extends standard next-token prediction by training additional heads to forecast several future tokens in parallel, providing a denser training signal that improves training efficiency and can also speed up inference (e.g., via self-speculative decoding).
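To make the training signal concrete, here is a minimal sketch of how an MTP-style loss can be computed on top of a causal LM's hidden states. It assumes one extra head per additional prediction step; the names (extra_heads, mtp_loss_weight) are illustrative and not necessarily the identifiers used in this implementation.

import torch
import torch.nn.functional as F

def combined_mtp_loss(hidden_states, labels, lm_head, extra_heads, mtp_loss_weight=0.5):
    # Standard next-token loss from the original LM head.
    ntp_logits = lm_head(hidden_states)                        # (batch, seq, vocab)
    ntp_loss = F.cross_entropy(
        ntp_logits[:, :-1].flatten(0, 1), labels[:, 1:].flatten()
    )
    # Each extra head i predicts the token (i + 2) positions ahead of the input.
    mtp_losses = []
    for i, head in enumerate(extra_heads):
        offset = i + 2
        logits = head(hidden_states)                           # (batch, seq, vocab)
        mtp_losses.append(F.cross_entropy(
            logits[:, :-offset].flatten(0, 1), labels[:, offset:].flatten()
        ))
    # Auxiliary MTP losses are down-weighted relative to the main objective.
    return ntp_loss + mtp_loss_weight * torch.stack(mtp_losses).mean()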

🔬 Connection to Reinforcement Learning

MTP has an interesting duality with world models in reinforcement learning. The additional prediction heads essentially learn to forecast future observations, similar to how world models predict future states. This conceptual similarity suggests that MTP could potentially serve as a key component in imagination-based RL algorithms, such as Dreamer, which learns behaviors by imagining trajectories in a learned latent space.

Just as world models enable agents to "dream" about possible futures before taking actions, MTP heads learn to predict multiple future tokens, effectively building an implicit model of the sequential dependencies in language. This parallel opens up exciting possibilities:

  • Implicit World Modeling: MTP heads may capture temporal dynamics similar to explicit world models in RL
  • Planning via Prediction: Multi-step predictions could inform better token generation strategies
  • Latent Imagination: The learned representations from MTP could support trajectory optimization in language space

🎯 Key Features

  • 5 Head Architectures: Support for linear, ffn, mha_ffn, cnn, and identical head types (a rough sketch of the simpler variants follows this list)
  • Universal Model Compatibility: Works seamlessly with Llama, Qwen, GPTNeoX, Mistral, and other architectures
  • Flexible Configuration: User-controllable number of predicted tokens, head type, head depth, and loss weight
  • Seamless Integration: Plugs into the existing SFTTrainer workflow through SFTConfig options
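As a rough sketch of what these head variants could look like (assumptions for illustration, not the exact modules shipped in this PR), a linear head is a single projection to the vocabulary, while an ffn head stacks a few hidden blocks before that projection:

import torch.nn as nn

def build_mtp_head(head_type, hidden_size, vocab_size, num_layers=1):
    # Illustrative only; the mha_ffn, cnn, and identical variants are omitted here.
    if head_type == "linear":
        # Single projection straight to the vocabulary.
        return nn.Linear(hidden_size, vocab_size, bias=False)
    if head_type == "ffn":
        # num_layers hidden blocks followed by the output projection.
        blocks = []
        for _ in range(num_layers):
            blocks += [nn.Linear(hidden_size, hidden_size), nn.GELU(), nn.LayerNorm(hidden_size)]
        return nn.Sequential(*blocks, nn.Linear(hidden_size, vocab_size, bias=False))
    raise ValueError(f"unsupported head_type for this sketch: {head_type}")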

📦 Main Components

  • trl/models/modeling_mtp_extension.py - Core MTP implementation with 5 head types
  • trl/trainer/mtp_data_collator.py - Data collator for multi-token targets (see the sketch after this list)
  • trl/trainer/sft_config.py - MTP configuration parameters
  • trl/trainer/sft_trainer.py - SFTTrainer integration
  • docs/source/mtp_trainer.md - Complete documentation with examples
  • examples/scripts/sft_with_mtp.py - Comprehensive example script
  • tests/test_mtp_functionality.py - Full test suite
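For reference, the extra targets a multi-token collator needs are simply the labels shifted by each prediction offset; a minimal sketch follows (function and variable names are placeholders, not necessarily those used in mtp_data_collator.py):

import torch

def build_mtp_labels(labels, num_predictions, pad_id=-100):
    # labels: (batch, seq_len) next-token targets; entry k of the result holds
    # the token (k + 2) positions ahead, with the tail padded out of the loss.
    shifted_labels = []
    for k in range(num_predictions):
        offset = k + 2
        shifted = torch.full_like(labels, pad_id)
        shifted[:, :-offset] = labels[:, offset:]
        shifted_labels.append(shifted)
    return torch.stack(shifted_labels)   # (num_predictions, batch, seq_len)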

💡 Usage Example

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

# Any dataset in a format SFTTrainer accepts works here; trl-lib/Capybara is a placeholder.
dataset = load_dataset("trl-lib/Capybara", split="train")

# Simple usage with smart defaults
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(
        output_dir="model-mtp",
        mtp_enabled=True,
        mtp_num_predictions=2,
        mtp_loss_weight=0.5,
    ),
    train_dataset=dataset,
)
trainer.train()

# Advanced usage with full control
trainer = SFTTrainer(
    model="meta-llama/Llama-3.2-1B",
    args=SFTConfig(
        output_dir="model-mtp-advanced",
        mtp_enabled=True,
        mtp_num_predictions=3,
        mtp_head_type="identical",
        mtp_num_layers=2,
        mtp_loss_weight=0.3
    ),
    train_dataset=dataset,
)
trainer.train()

KLGR123 and others added 12 commits September 2, 2025 19:02
- mtp forward head
- multiple choices with cnn, ffn, mha heads, etc.
- mtp hyper params k, decay, etc.
- mtp loss with ntp loss and backwards
- some test files
- Add support for various LM head attribute names (lm_head, embed_out, head, etc.)
- Fix tests to properly initialize MTP before checking model attributes
- Ensure compatibility with GPTNeoX and other model architectures
- Add helper method to get LM head from different model types
- Add num_attention_heads parameter for user control
- Implement smart default calculation following Transformer best practices
- Use head_dim of 64 or 128 (industry standard)
- Provide clear error messages with suggested values
- Fix issue where num_heads could be 0 for small hidden_size
- Balance flexibility for experts with good defaults for regular users
- Skip 1D parameter initialization for LayerNorm (kaiming/xavier require 2D+)
- Fix dtype detection to support all head types (MHAFFNHead, CNNHead, etc.)
- Add safe fallback for unknown LM head layer types with warning
- Prevent AttributeError when accessing weight tensors
- Ensure all edge cases are handled gracefully
@qgallouedec (Member) commented:

thanks for the PR. In my opinion, MTP usage is still too low to include it in the main codebase. I therefore suggest either putting it in a separate repository that can be linked to in the TRL documentation, or including it in trl.experimental.
