
Official implementation of "Align-to-Distill: Trainable Attention Alignment for Knowledge Distillation in Neural Machine Translation" (LREC-COLING 2024)

HeegonJin/Align-to-Distill

Align-to-Distill: Trainable Attention Alignment for Knowledge Distillation in Neural Machine Translation

This is the PyTorch implementation of the paper "Align-to-Distill: Trainable Attention Alignment for Knowledge Distillation in Neural Machine Translation" (LREC-COLING 2024).

We run our experiments on the standard Transformer architecture using the fairseq toolkit. If you use any source code from this repository in your work, please cite the following paper:

@misc{jin2024aligntodistill,
      title={Align-to-Distill: Trainable Attention Alignment for Knowledge Distillation in Neural Machine Translation}, 
      author={Heegon Jin and Seonil Son and Jemin Park and Youngseok Kim and Hyungjong Noh and Yeonsoo Lee},
      year={2024},
      eprint={2403.01479},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Requirements and Installation

  • PyTorch version >= 1.10.0
  • Python version >= 3.8
  • For training new models, you'll also need an NVIDIA GPU and NCCL
  • To install fairseq and develop locally:
git clone this_repository
cd fairseq
pip install --editable ./

We require a few additional Python dependencies:

pip install sacremoses einops
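Before preparing data, it can help to confirm that the environment meets the requirements listed above. The script below is a minimal sketch, not part of the repository: the helper names (`parse_version`, `at_least`, `check_environment`) are hypothetical, and it only checks that `torch`, `fairseq`, `sacremoses` and `einops` can be found, that Python is at least 3.8, that PyTorch is at least 1.10.0, and that a CUDA device is visible. It does not check NCCL.

```python
import importlib.util
import sys


def parse_version(text):
    """Turn a version string such as '1.10.0+cu113' into a tuple like (1, 10, 0).

    Local build tags after '+' are dropped, and each component keeps only its
    leading digits, so '1.10.0a0' becomes (1, 10, 0).
    """
    parts = []
    for piece in str(text).split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def at_least(found, required):
    """True if version `found` is >= version `required` (missing parts count as 0)."""
    f, r = parse_version(found), parse_version(required)
    width = max(len(f), len(r))
    return f + (0,) * (width - len(f)) >= r + (0,) * (width - len(r))


def check_environment():
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    python_version = ".".join(str(p) for p in sys.version_info[:3])
    if not at_least(python_version, "3.8"):
        problems.append(f"Python {python_version} is older than 3.8")
    # find_spec only locates top-level packages; it does not import them.
    for module in ("torch", "fairseq", "sacremoses", "einops"):
        if importlib.util.find_spec(module) is None:
            problems.append(f"Python package not found: {module}")
    if importlib.util.find_spec("torch") is not None:
        import torch

        if not at_least(torch.__version__, "1.10.0"):
            problems.append(f"PyTorch {torch.__version__} is older than 1.10.0")
        if not torch.cuda.is_available():
            problems.append("no CUDA device visible (an NVIDIA GPU is needed for training)")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```

An empty output means the basic requirements appear to be satisfied; a missing GPU only matters if you intend to train.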

Prepare dataset

IWSLT'14 German to English

The following instructions can be used to train a Transformer model on the IWSLT'14 German to English dataset.

First, download and preprocess the data:

# Download and prepare the data
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..

# Preprocess/binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en \
    --workers 20
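After `fairseq-preprocess` finishes, the destination directory should contain one dictionary per language plus a `.bin`/`.idx` pair per split and language, named `{split}.{src}-{tgt}.{lang}`. The sketch below lists those expected files and reports any that are missing. It assumes the default memory-mapped dataset format (other `--dataset-impl` choices write different files), and the helper names are hypothetical, not part of fairseq or this repository.

```python
from pathlib import Path

SPLITS = ("train", "valid", "test")


def expected_binarized_files(src="de", tgt="en", splits=SPLITS):
    """File names fairseq-preprocess is expected to write for a src-tgt pair."""
    names = [f"dict.{src}.txt", f"dict.{tgt}.txt"]
    for split in splits:
        for lang in (src, tgt):
            for ext in ("bin", "idx"):
                names.append(f"{split}.{src}-{tgt}.{lang}.{ext}")
    return names


def missing_binarized_files(dest_dir, src="de", tgt="en", splits=SPLITS):
    """Return the expected files that are not present in dest_dir."""
    dest = Path(dest_dir)
    return [
        name
        for name in expected_binarized_files(src, tgt, splits)
        if not (dest / name).is_file()
    ]


if __name__ == "__main__":
    missing = missing_binarized_files("data-bin/iwslt14.tokenized.de-en")
    print("all binarized files present" if not missing else f"missing: {missing}")
```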

Training

First, train a teacher model; the teacher is trained with the standard fairseq training script. Second, use the trained teacher model to train an A2D student model. The '--teacher-ckpt-path' argument specifies the path to the teacher checkpoint produced in the first step.
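In the second stage the frozen teacher supplies target distributions for the student. This README does not spell out the exact form of the response-based distillation loss, so the sketch below shows one common choice, word-level knowledge distillation: the KL divergence from the teacher's to the student's next-token distribution, averaged over target positions. It is written in plain Python for readability; the function names and the temperature parameter are illustrative assumptions, and the repository's criterion may differ (for example in KL direction, temperature, or reduction).

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    # q is floored to avoid log(0) if a probability underflows.
    return sum(pi * math.log(pi / max(qi, 1e-12)) for pi, qi in zip(p, q) if pi > 0.0)


def word_level_kd_loss(teacher_logits, student_logits, temperature=1.0):
    """Mean over target positions of KL(teacher || student).

    Each argument holds one list of vocabulary logits per target token.
    The teacher is frozen, so its logits are treated as fixed inputs.
    """
    per_token = [
        kl_divergence(softmax(t, temperature), softmax(s, temperature))
        for t, s in zip(teacher_logits, student_logits)
    ]
    return sum(per_token) / len(per_token)
```

In an actual PyTorch training loop the same quantity would be computed on tensors, for example with `torch.nn.functional.kl_div` applied to the student's log-probabilities and the teacher's probabilities.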

Adjustable arguments for experiments:

  • '--alpha' (default: 0.5): weight between the cross-entropy loss and the response-based distillation loss.
  • '--beta' (default: 1): weight between the response-based distillation loss and the attention distillation loss.
  • '--decay' (default: 0.9): decay rate for the attention distillation loss.
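To make the roles of these three arguments concrete, here is one plausible way they could combine the loss terms. This is an assumption, not the repository's formula: alpha interpolates cross-entropy and the response-based distillation loss, beta scales the attention distillation loss, and decay shrinks the attention term geometrically per epoch. The function name is hypothetical; check the repository's criterion code for the actual combination and for whether the decay is applied per epoch, per update, or otherwise.

```python
def a2d_style_total_loss(ce_loss, kd_loss, attn_loss,
                         alpha=0.5, beta=1.0, decay=0.9, epoch=0):
    """Hypothetical combination of the terms controlled by --alpha, --beta, --decay.

    ASSUMPTION: alpha interpolates cross-entropy and response-based KD,
    beta scales the attention distillation term, and decay multiplies that
    term by decay**epoch (epoch counted from 0). The real criterion may differ.
    """
    response_part = (1.0 - alpha) * ce_loss + alpha * kd_loss
    attention_weight = beta * (decay ** epoch)
    return response_part + attention_weight * attn_loss


if __name__ == "__main__":
    # Same per-term losses, observed at successive epochs.
    for epoch in range(3):
        print(epoch, a2d_style_total_loss(2.0, 1.0, 0.5, epoch=epoch))
```

Under this reading, setting alpha to 0 and beta to 0 reduces training to plain cross-entropy, which is a convenient sanity check when comparing against a non-distilled baseline.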

Two scripts are provided for running the training processes:

  • train_teacher.sh: This script is used to train the teacher model.
  • train_student.sh: This script is used to train the A2D student model using the trained teacher model.

Train a teacher model

bash train_teacher.sh

Train a student model (with the A2D method)

bash train_student.sh

Test a student model (with the A2D method)

bash test.sh
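The three scripts must run in order, and the student stage depends on the teacher checkpoint from the first stage. The sketch below chains them and refuses to start the student stage if the teacher checkpoint is missing. The checkpoint path is a hypothetical placeholder: fairseq writes `checkpoint_best.pt` and `checkpoint_last.pt` into its `--save-dir`, so point `TEACHER_CKPT` at whatever directory `train_teacher.sh` uses and `train_student.sh` passes via `--teacher-ckpt-path`.

```python
import subprocess
from pathlib import Path

# Hypothetical path: adjust to the --save-dir used by train_teacher.sh.
TEACHER_CKPT = Path("checkpoints/teacher/checkpoint_best.pt")

STAGES = [
    ("teacher", ["bash", "train_teacher.sh"]),
    ("student", ["bash", "train_student.sh"]),
    ("test", ["bash", "test.sh"]),
]


def run_pipeline(stages=STAGES, teacher_ckpt=TEACHER_CKPT, dry_run=False):
    """Run the stages in order, stopping at the first failing script.

    With dry_run=True nothing is executed; the commands are only collected.
    """
    executed = []
    for name, cmd in stages:
        if name == "student" and not dry_run and not Path(teacher_ckpt).is_file():
            raise FileNotFoundError(f"teacher checkpoint not found: {teacher_ckpt}")
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError on failure
        executed.append(" ".join(cmd))
    return executed


if __name__ == "__main__":
    run_pipeline()
```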
