Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models


Introducing "Number Token Loss" (NTL) for language models to improve numerical reasoning by using regression-based loss functions on number tokens. Achieves better performance on math tasks without computational overhead ๐Ÿš€


📖 Overview

The Number Token Loss (NTL) is a novel approach to enhancing language models' numerical reasoning. Unlike traditional cross-entropy loss, which treats all incorrect predictions equally, NTL incorporates the numerical proximity of tokens, providing regression-like behavior at the token level.

[Figure: NTL concept]

🎯 Why do we need the Number Token Loss (NTL)?

Cross entropy is a nominal-scale loss and thus assigns equal loss to all incorrect predictions. This makes sense for ordinary tokens but not for number tokens:

With a ground-truth token 4, predicting 3 or 9 should not incur the same loss 🤔😱
NTL fixes this! 🚀💪

For number tokens, NTL increases with the distance from the ground truth, just like a regression loss. Yet it does not need an extra head: it computes a regression-like loss directly on the token head. We propose two schemes:

  • NTL-WAS – Wasserstein-1 distance between the predicted and the one-hot number distribution (see the plot above).
  • NTL-MSE – squared error between the ground-truth value and the expected numeric value under the predicted distribution (most intuitive, but has some undesired local minima).

[Figure: Loss comparison]
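
To make the two schemes concrete, here is a rough PyTorch sketch (not the repository's implementation; the digit-only vocabulary, token ids, and tensor shapes below are simplifying assumptions for illustration):

import torch
import torch.nn.functional as F

# Assumed setup: the tokenizer has ten number tokens "0".."9" whose ids are
# collected in number_token_ids, and number_values holds their numeric values.
number_token_ids = torch.arange(10)                   # hypothetical token ids
number_values = torch.arange(10, dtype=torch.float)   # numeric value of each number token

def ntl_was(logits, target_digits):
    # NTL-WAS: Wasserstein-1 distance between the predicted distribution over
    # number tokens and the one-hot distribution of the ground-truth number token.
    # target_digits holds the index (0..9) of the ground-truth digit.
    probs = F.softmax(logits[:, number_token_ids], dim=-1)        # (batch, 10)
    one_hot = F.one_hot(target_digits, num_classes=10).float()
    # For 1-D distributions on an ordered support with unit spacing,
    # W1 equals the summed absolute difference of the two CDFs.
    cdf_diff = torch.cumsum(probs, dim=-1) - torch.cumsum(one_hot, dim=-1)
    return cdf_diff.abs().sum(dim=-1).mean()

def ntl_mse(logits, target_digits):
    # NTL-MSE: squared error between the expected numeric value under the
    # predicted distribution and the ground-truth numeric value.
    probs = F.softmax(logits[:, number_token_ids], dim=-1)
    expected = probs @ number_values
    return F.mse_loss(expected, number_values[target_digits])

During training, the chosen NTL term is added to the cross-entropy loss with a weight λ (the number_token_loss_weight that appears in the configurations below) and only contributes at positions whose ground-truth token is a number.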

🔑 Key Features

  • 🎯 Model-Agnostic: NTL is just a loss → applicable to any LM (e.g., Transformer, Mamba) in any architecture (encoder-decoder, decoder-only).
  • 🔌 Plug-and-Play: NTL requires only a mapping from tokens to numeric values and works with digit-level and multi-digit tokenizations (see the sketch after this list).
  • ⚡ No computational overhead: NTL adds only ~1% compute time to the loss calculation, which is negligible over a full training step.
  • 📈 Consistently improves performance: NTL outperforms plain cross entropy across multiple architectures and math benchmarks.
  • 🔢 Performs true regression: on regression tasks, an LM head with NTL matches a dedicated regression head.
  • 🚀 Scales to large models: even Granite 3.2 2B and T5-3B benefit substantially from NTL on math tasks like GSM8K.
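
A minimal sketch of such a token-to-value mapping, as referenced in the Plug-and-Play point above (the tokenizer choice and prefix handling are illustrative assumptions, not the repository's API):

from transformers import AutoTokenizer

def build_token_value_map(tokenizer):
    # Map every vocabulary token that parses as a number to its numeric value.
    # All other tokens are left out and keep their ordinary cross-entropy loss.
    token_to_value = {}
    for token, token_id in tokenizer.get_vocab().items():
        text = token.lstrip("▁Ġ ")  # strip common subword prefix markers
        try:
            token_to_value[token_id] = float(text)
        except ValueError:
            continue
    return token_to_value

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # hypothetical model choice
value_map = build_token_value_map(tokenizer)

The same mapping works whether the vocabulary contains single digits or multi-digit number tokens.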


🛠️ Installation

Prerequisites

  • Python 3.9 or higher
  • CUDA-compatible GPU (recommended)

Setup Instructions

  1. Clone the repository

    git clone https://github.com/tum-ai/number-token-loss.git
    cd number-token-loss
  2. Create and activate environment

    conda create -n ntl python=3.10
    conda activate ntl
  3. Install dependencies

    pip install -r requirements.txt
    pip install -e .
  4. Configure Weights & Biases

    wandb login
    export WANDB_ENTITY='<your_entity>'
    export WANDB_PROJECT='<your_project_name>'

๐Ÿƒโ€โ™‚๏ธ Quick Start

Easy Integration into Your Model

For a minimal working example of how to integrate NTL into your existing Hugging Face model, check out our lightweight integration notebook. It demonstrates:

  • How to add NTL to any decoder-only language model (e.g., LLaMA, GPT)
  • Custom trainer implementation with a combined CE + NTL loss (a minimal version is sketched after this list)
  • Complete working example with TinyLLaMA
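
A minimal sketch of that trainer pattern, assuming a standard Hugging Face Trainer and a precomputed list of number-token ids and values as in the mapping sketch above (this is not the notebook's exact code):

import torch
import torch.nn.functional as F
from transformers import Trainer

class NTLTrainer(Trainer):
    # Adds an NTL-MSE term on top of the model's own cross-entropy loss.
    # number_token_ids / number_values come from a token-to-value mapping,
    # and ntl_weight plays the role of number_token_loss_weight.
    def __init__(self, *args, number_token_ids, number_values, ntl_weight=0.3, **kwargs):
        super().__init__(*args, **kwargs)
        self.number_token_ids = number_token_ids  # LongTensor of number-token ids
        self.number_values = number_values        # FloatTensor of their numeric values
        self.ntl_weight = ntl_weight

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        ce_loss = outputs.loss                    # standard (shifted) cross entropy

        # Align next-token logits with their labels (causal-LM convention).
        logits = outputs.logits[..., :-1, :]
        targets = inputs["labels"][..., 1:]
        device = logits.device
        ids = self.number_token_ids.to(device)
        values = self.number_values.to(device)

        # Apply NTL only where the ground-truth token is a number token.
        is_number = torch.isin(targets, ids)
        if is_number.any():
            num_logits = logits[is_number][:, ids]             # (n, #number tokens)
            expected = F.softmax(num_logits, dim=-1) @ values  # expected numeric value
            # Look up the numeric value of each ground-truth number token.
            idx = (targets[is_number].unsqueeze(-1) == ids).long().argmax(dim=-1)
            ntl_loss = F.mse_loss(expected, values[idx])
        else:
            ntl_loss = torch.zeros((), device=device)

        loss = ce_loss + self.ntl_weight * ntl_loss
        return (loss, outputs) if return_outputs else loss

Training then proceeds exactly as with a plain Trainer; only the loss computation changes.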

Full Training Pipeline

The main training script uses Hydra for configuration management:

python src/ntl/run_language_modeling.py \
    dataset_args=mathematics_dataset \
    model_args=vanilla_t5_ntl \
    training_args=train

Configuration Options

  • Datasets: gsm8k, mathematics_dataset, arithmetic, rjokes, multirc
  • Models: vanilla_t5, vanilla_t5_ntl, rt, rt_ntl, xval
  • Training: train, eval

Custom Configuration

Override default parameters via command line:

python src/ntl/run_language_modeling.py \
    model_args=vanilla_t5_ntl \
    training_args=train \
    training_args.per_device_train_batch_size=8 \
    model_args.number_token_loss_weight=0.3

📊 Experimental Results

Download the datasets used in the experiments

Mathematics Dataset

# T5 Baseline (standard cross-entropy)
python src/ntl/run_language_modeling.py run_specific_config@_global_=mathematics_dataset_run model_args=vanilla_t5 dataset_args=mathematics_dataset

# T5 + NTL-MSE (MSE-based NTL)
python src/ntl/run_language_modeling.py run_specific_config@_global_=mathematics_dataset_run model_args=vanilla_t5_ntl dataset_args=mathematics_dataset

# T5 + NTL-WAS (Wasserstein-based NTL)
python src/ntl/run_language_modeling.py run_specific_config@_global_=mathematics_dataset_run model_args=vanilla_t5_ntl model_args.number_token_loss_with_wasserstein=true dataset_args=mathematics_dataset

Ablation Studies

Comprehensive ablation studies on arithmetic subsets:

View Ablation Commands

NTL-MSE with Different Weights:

# λ = 0.3
python src/ntl/run_language_modeling.py dataset_args=arithmetic model_args=vanilla_t5_ntl model_args.number_token_loss_with_wasserstein=false model_args.number_token_loss_weight=0.3 training_args.special_name=NTL-MSE_Lambda0.3

# λ = 0.8
python src/ntl/run_language_modeling.py dataset_args=arithmetic model_args=vanilla_t5_ntl model_args.number_token_loss_with_wasserstein=false model_args.number_token_loss_weight=0.8 training_args.special_name=NTL-MSE_Lambda0.8

# λ = 2.0
python src/ntl/run_language_modeling.py dataset_args=arithmetic model_args=vanilla_t5_ntl model_args.number_token_loss_with_wasserstein=false model_args.number_token_loss_weight=2.0 training_args.special_name=NTL-MSE_Lambda2.0

Alternative Loss Functions:

# NTL-MAE
python src/ntl/run_language_modeling.py dataset_args=arithmetic model_args=vanilla_t5_ntl +model_args.number_token_loss_function=mae training_args.special_name=NTL-MAE_Lambda0.3

# NTL-Huber
python src/ntl/run_language_modeling.py dataset_args=arithmetic model_args=vanilla_t5_ntl +model_args.number_token_loss_function=huber training_args.special_name=NTL-Huber_Lambda0.3

Large-Scale Evaluation (GSM8K)

Scaling NTL to 3B-parameter models:

# T5-3B Baseline
python src/ntl/run_language_modeling.py run_specific_config@_global_=gsm8k_runs model_args=vanilla_t5 dataset_args=gsm8k

# T5-3B + NTL-WAS
python src/ntl/run_language_modeling.py run_specific_config@_global_=gsm8k_runs model_args=vanilla_t5_ntl dataset_args=gsm8k model_args.number_token_loss_weight=0.3

🧪 Advanced Usage

Debugging Mode

python src/ntl/run_language_modeling.py \
    model_args=vanilla_t5 \
    training_args=train \
    run_specific_config@_global_=debug_config

Model Evaluation

python src/ntl/run_language_modeling.py \
    model_args=vanilla_t5_ntl \
    training_args=eval \
    model_args.model_name_or_path=<path_to_checkpoint>

๐Ÿ“ Citation

If you find this work useful, please cite our paper:

@inproceedings{zausinger2025regress,
  title   = {Regress, Don't Guess – A Regression-like Loss on Number Tokens for Language Models},
  author  = {Jonas Zausinger and Lars Pennig and Anamarija Kozina and Sean Sdahl
             and Julian Sikora and Adrian Dendorfer and Timofey Kuznetsov
             and Mohamad Hagog and Nina Wiedemann and Kacper Chlodny
             and Vincent Limbach and Anna Ketteler and Thorben Prein
             and Vishwa Mohan Singh and Michael Danziger and Jannis Born},
  booktitle = {Proc. of the 42nd International Conference on Machine Learning (ICML)},
  year    = {2025},
  url     = {https://tum-ai.github.io/number-token-loss/}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

