This project adapts supervised contrastive learning to the punctuation restoration task. It is the implementation of the paper "Token-Level Supervised Contrastive Learning for Punctuation Restoration", accepted at Interspeech 2021.
The data has been converted with the corresponding BERT tokenizer, paired with punctuation labels, and saved as pickle files under the dataset/ directory. The original text files come from the International Workshop on Spoken Language Translation (IWSLT) 2012.
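As a quick sanity check, the pickle files can be inspected along these lines (the file name and record layout below are assumptions, not a documented schema; check the actual contents of dataset/):

```python
import pickle

# Hypothetical example: open one of the preprocessed pickle files.
# The file name and the structure of its records are assumptions.
with open("dataset/train.pkl", "rb") as f:
    data = pickle.load(f)

# The records are assumed to pair token ids with aligned punctuation labels.
print(type(data), len(data))
```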
We fine-tuned a Transformer-based language model with supervised contrastive learning for the punctuation restoration task.
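For intuition, below is a minimal PyTorch sketch of a token-level supervised contrastive loss, following the general SupCon formulation: tokens that share a punctuation label act as positives for each other. This is an illustration under those assumptions, not a line-for-line copy of this repository's implementation:

```python
import torch
import torch.nn.functional as F

def token_scl_loss(features, labels, temperature=0.6):
    """Token-level supervised contrastive loss (illustrative sketch).

    features: (N, d) token embeddings for the non-padding tokens in a batch.
    labels:   (N,) integer punctuation label per token.
    Tokens sharing a label are treated as positives for each other.
    """
    features = F.normalize(features, dim=-1)       # compare in cosine space
    sim = features @ features.T / temperature      # (N, N) similarity logits

    n = features.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=features.device)
    pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask

    # Log-softmax over all *other* tokens (self-similarity excluded).
    sim = sim.masked_fill(self_mask, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    # Average log-probability of the positives, for anchors that have any.
    pos_counts = pos_mask.sum(dim=1)
    has_pos = pos_counts > 0
    sum_pos = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(sum_pos[has_pos] / pos_counts[has_pos]).mean()
```

Pulling same-label token embeddings together and pushing different-label ones apart is what the contrastive term adds on top of the usual cross-entropy objective.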
To install the environment for this project, we recommend Anaconda. Once Anaconda is installed, the environment can be created with:
conda env create -f environment.yml
There are some example scripts under the example_scripts/ directory.
First, activate the conda environment:
conda activate punc_interspeech
Then, run train.py, for example:
python train.py --config=config/roberta-large-scl.yml -l 0.1 -t 0.6
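Here, -l and -t presumably set the weight of the contrastive term and the softmax temperature. A rough sketch of how such a combined objective might be formed, reusing token_scl_loss from the sketch above (all names and dummy values are illustrative):

```python
import torch

# Illustrative only: how the -l and -t flags presumably enter the objective
# (lam from -l, temperature from -t); all names and values are placeholders.
lam, temperature = 0.1, 0.6
token_features = torch.randn(32, 768)       # dummy (N, d) token embeddings
token_labels = torch.randint(0, 4, (32,))   # dummy punctuation labels
ce_loss = torch.tensor(0.7)                 # placeholder cross-entropy term

scl = token_scl_loss(token_features, token_labels, temperature=temperature)
loss = ce_loss + lam * scl                  # combined training objective
```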
Here, we provide several config files and example scripts in
- config/
- example_scripts/
During training, TensorBoard logs are written to the runs/ directory, which is created automatically when the program starts. Meanwhile, the model for each epoch is saved under the saved_model/ directory.
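The training curves can then be inspected with, for example:
tensorboard --logdir runs/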
Evaluation can be done by running:
python evaluate.py --config=[config path] --checkpoint=[saved model file path]
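For example (the checkpoint file name below is hypothetical; substitute a file actually written to saved_model/):
python evaluate.py --config=config/roberta-large-scl.yml --checkpoint=saved_model/epoch_3.pt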
If you find this work useful, please cite our paper:

@inproceedings{huang21g_interspeech,
author={Qiushi Huang and Tom Ko and H. Lilian Tang and Xubo Liu and Bo Wu},
title={{Token-Level Supervised Contrastive Learning for Punctuation Restoration}},
year=2021,
booktitle={Proc. Interspeech 2021},
pages={2012--2016},
doi={10.21437/Interspeech.2021-661}
}