# Prepending or Cross-Attention for Speech-to-Text? An Empirical Comparison (NAACL 2025)

This README contains the instructions to replicate the training and evaluation of the models in the paper
[Prepending or Cross-Attention for Speech-to-Text? An Empirical Comparison](https://arxiv.org/abs/2501.02370)
published at NAACL 2025.

## Training

Below we list the scripts used in our experiments. The scripts were executed on
2 A100 GPUs with 64GB of VRAM. In case of a different environment (e.g., GPUs with less VRAM),
you need to adapt `--max-tokens` (which controls the mini-batch size on a single GPU)
and `--update-freq`, so that `number of GPUs * max_tokens * update_freq = 320,000`.
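
For instance, on a single GPU that fits at most 20,000 tokens per mini-batch (illustrative values, not the configuration used in the paper), setting `--max-tokens 20000 --update-freq 16` preserves the effective batch size:

```shell
# effective batch size check: number of GPUs * max_tokens * update_freq
NUM_GPUS=1; MAX_TOKENS=20000; UPDATE_FREQ=16
echo $(( NUM_GPUS * MAX_TOKENS * UPDATE_FREQ ))  # prints 320000
```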

### Cross-attention encoder-decoder

The Transformer encoder-decoder with cross-attention (line 1 of Table 1 in the paper)
has been trained using:

```shell
python fbk-fairseq/train.py $data_root \
        --train-subset $train_tsv --valid-subset $dev_tsv --config-yaml $config \
        --save-dir $save_dir --user-dir fbk-fairseq/examples/speech_to_text \
        --task speech_to_text_ctc --criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 \
        --arch s2t_transformer_fbk \
        --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
        --warmup-updates 25000 \
        --clip-norm 10.0 --adam-betas '(0.9, 0.98)' \
        --seed 1 --skip-invalid-size-inputs-valid-test \
        --update-freq 4 --max-tokens 40000 --num-workers 4 \
        --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --log-format simple >> $save_dir/train.log 2> $save_dir/train.err
```

Similarly, the Conformer version with CTC auxiliary loss (line 4 of Table 1)
was trained with:

```shell
python fbk-fairseq/train.py $data_root \
        --train-subset $train_tsv --valid-subset $dev_tsv --config-yaml $config \
        --save-dir $save_dir --user-dir fbk-fairseq/examples/speech_to_text \
        --task speech_to_text_ctc --criterion ctc_multi_loss --underlying-criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 --ctc-encoder-layer 8 --ctc-weight 0.5 \
        --arch conformer \
        --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
        --warmup-updates 25000 \
        --clip-norm 10.0 --adam-betas '(0.9, 0.98)' \
        --seed 1 --skip-invalid-size-inputs-valid-test \
        --update-freq 4 --max-tokens 40000 --num-workers 4 \
        --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --log-format simple >> $save_dir/train.log 2> $save_dir/train.err
```

To enable CTC compression (line 4.1 of Table 1), add `--ctc-compress-strategy avg` to this command.

### Decoder-prepending

The decoder-prepending models (line 2 of Table 1) have been trained with:

```shell
python fbk-fairseq-dev/train.py $data_root \
        --train-subset $train_tsv --valid-subset $dev_tsv --config-yaml $config \
        --save-dir $save_dir --user-dir fbk-fairseq/examples/speech_to_text \
        --task speech_to_text_ctc --criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 \
        --arch s2tlm_transformer_fbk --encoder-layers 12 --decoder-layers 6 --causal-prompt-mask \
        --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
        --warmup-updates 25000 \
        --clip-norm 10.0 --adam-betas '(0.9, 0.98)' \
        --seed 1 --skip-invalid-size-inputs-valid-test \
        --update-freq 4 --max-tokens 40000 --num-workers 4 \
        --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --log-format simple >> $save_dir/train.log 2> $save_dir/train.err
```

To train the version without causal masking in the speech features, remove `--causal-prompt-mask`.

The Conformer version with CTC auxiliary loss (line 5 of Table 1) was trained with:

```shell
python fbk-fairseq-dev/train.py $data_root \
        --train-subset $train_tsv --valid-subset $dev_tsv --config-yaml $config \
        --save-dir $save_dir --user-dir fbk-fairseq/examples/speech_to_text \
        --task speech_to_text_ctc --criterion ctc_multi_loss --underlying-criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 --ctc-encoder-layer 8 --ctc-weight 0.5 \
        --arch s2tlm_conformer --encoder-layers 12 --decoder-layers 6 --causal-prompt-mask \
        --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
        --warmup-updates 25000 \
        --clip-norm 10.0 --adam-betas '(0.9, 0.98)' \
        --seed 1 --skip-invalid-size-inputs-valid-test \
        --update-freq 4 --max-tokens 40000 --num-workers 4 \
        --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --log-format simple >> $save_dir/train.log 2> $save_dir/train.err
```

As in the previous case, CTC compression (line 5.1) is enabled by adding `--ctc-compress-strategy avg`.

### Decoder-only

The decoder-only models were obtained with the same script as the decoder-prepending ones,
but setting the number of encoder layers to 0 and increasing the number of decoder layers.
This means that line 3 of Table 1 was obtained with:

```shell
python fbk-fairseq-dev/train.py $data_root \
        --train-subset $train_tsv --valid-subset $dev_tsv --config-yaml $config \
        --save-dir $save_dir --user-dir fbk-fairseq/examples/speech_to_text \
        --task speech_to_text_ctc --criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 \
        --arch s2tlm_transformer_fbk --encoder-layers 0 --decoder-layers 18 \
        --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt \
        --warmup-updates 25000 \
        --clip-norm 10.0 --adam-betas '(0.9, 0.98)' \
        --seed 1 --skip-invalid-size-inputs-valid-test \
        --update-freq 4 --max-tokens 40000 --num-workers 4 \
        --max-update 100000 --patience 10 --keep-last-epochs 12 \
        --log-format simple >> $save_dir/train.log 2> $save_dir/train.err
```

Causal masking can be enforced by adding `--causal-prompt-mask`.

## Evaluation

We generate the hypotheses of our models with the following command:

```shell
python fbk-fairseq/fairseq_cli/generate.py $DATA_ROOT \
        --user-dir fbk-fairseq/examples/speech_to_text/ --config-yaml $CONFIG_YAML \
        --gen-subset $SPLIT \
        --max-tokens 80000 --unkpen 10000 --beam 5 \
        --max-source-positions 12000 --max-target-positions 4000 \
        --model-overrides "{'max_source_positions':12000,'max_target_positions':4000}" \
        --task speech_to_text_ctc --criterion label_smoothed_cross_entropy --no-repeat-ngram-size 5 \
        --path $MODEL
```
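
The command prints the generated outputs to standard output. To obtain the plain-text files used by the WER and BLEU evaluation below, the hypotheses and references can be extracted as follows (a sketch, assuming the output of the command above was redirected to a file named `generate.out` and relying on fairseq's standard `T-`/`D-` output lines):

```shell
# detokenized hypotheses are on "D-<id>" lines: id <tab> score <tab> text
grep ^D generate.out | LC_ALL=C sort -V | cut -f3 > hypotheses.txt
# references are on "T-<id>" lines: id <tab> text
grep ^T generate.out | LC_ALL=C sort -V | cut -f2 > references.txt
```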

For models trained with the auxiliary CTC loss, change the `--criterion`
to `ctc_multi_loss` and add `--underlying-criterion label_smoothed_cross_entropy`.
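
Applying these two changes, the generation command for the CTC-trained models becomes (identical to the one above, except for the criterion-related flags):

```shell
python fbk-fairseq/fairseq_cli/generate.py $DATA_ROOT \
        --user-dir fbk-fairseq/examples/speech_to_text/ --config-yaml $CONFIG_YAML \
        --gen-subset $SPLIT \
        --max-tokens 80000 --unkpen 10000 --beam 5 \
        --max-source-positions 12000 --max-target-positions 4000 \
        --model-overrides "{'max_source_positions':12000,'max_target_positions':4000}" \
        --task speech_to_text_ctc --criterion ctc_multi_loss \
        --underlying-criterion label_smoothed_cross_entropy --no-repeat-ngram-size 5 \
        --path $MODEL
```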

### WER

WER was computed using jiwer after removing punctuation. This was done with the following script:

```shell
ref=$1  # path to the plain-text reference file
out=$2  # path to the plain-text hypothesis file
tmp_dir=$(mktemp -d -t ci-XXXXXXXXXX)
# strip punctuation and squeeze the repeated spaces this may leave behind
tr -d '[:punct:]' < $ref | sed 's/  */ /g' > $tmp_dir/ref.txt
tr -d '[:punct:]' < $out | sed 's/  */ /g' > $tmp_dir/out.txt

jiwer -h $tmp_dir/out.txt -r $tmp_dir/ref.txt
rm -rf $tmp_dir
```
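
Assuming the snippet above is saved as, e.g., `compute_wer.sh` (the file name is only illustrative), it can be invoked with the reference and hypothesis files as positional arguments:

```shell
bash compute_wer.sh references.txt hypotheses.txt
```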

The statistical significance was computed using the script
[WER bootstrap resampling](../examples/speech_to_text/scripts/wer_bootstrap_resampling.py).

### BLEU

All the scores and statistical significance were computed with the `sacreBLEU` command.
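
As an illustration (the exact invocation is not reported in this README), corpus-level BLEU and the paired bootstrap significance test can be computed along these lines, assuming plain-text reference and hypothesis files:

```shell
# corpus-level BLEU of a single system
sacrebleu references.txt -i hypotheses.txt -m bleu

# significance test (paired bootstrap resampling) between two systems
sacrebleu references.txt -i baseline_hyp.txt contrastive_hyp.txt --paired-bs
```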

## Citation

```bibtex
@inproceedings{lam-et-al-2025-prepending,
    title = {{Prepending or Cross-Attention for Speech-to-Text? An Empirical Comparison}},
    author = {Tsz Kin Lam and Marco Gaido and Sara Papi and Luisa Bentivogli and Barry Haddow},
    booktitle = {Proceedings of the 2025 Annual Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics},
    address = {Albuquerque, New Mexico},
    year = {2025}
}
```