Reproducing Kotoba-whisper models requires completing the following seven stages in order:
- Setup
- Download Dataset
- Generate Labels
- Filter Dataset
- Initialize Distil-Whisper
- Train Model
- Evaluate Model
To reproduce the kotoba-whisper models, please refer to the following scripts:
- kotoba-tech/kotoba-whisper-v2.1: kotoba_whisper_v2.1.sh
- kotoba-tech/kotoba-whisper-v2.0: kotoba_whisper_v2.0.sh
- kotoba-tech/kotoba-whisper-v1.1: kotoba_whisper_v1.1.sh
- kotoba-tech/kotoba-whisper-v1.0: kotoba_whisper_v1.0.sh
Clone the repository and configure your Hugging Face environment.
- Install dependencies
git clone [email protected]:kotoba-tech/kotoba-whisper.git
cd kotoba-whisper
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
- huggingface configuration
accelerate config
huggingface-cli login
Although ReazonSpeech is available on the Hugging Face Hub, the hosted repository has stability issues (raising TimeoutError for the larger subsets such as large or all), so we instead download the source files locally first and use our manual data loader to load the dataset in the subsequent label-generation step.
The python script reazonspeech_manual_downloader.py downloads the source files of ReazonSpeech locally.
python reazonspeech_manual_downloader.py [--target TARGET] [-p POOL] [-s START_QUE] [-e END_QUE]
The argument --target should be one of tiny/small/medium/large/all. The full dataset (all) is very large, so you may want to download it in small chunks, which can be done by specifying the index range of the raw files with -s and -e, as below.
python reazonspeech_manual_downloader.py --target all -p 100 -s 0 -e 50
ReazonSpeech all has 4096 files in total, and we use the last file to create our held-out test set, so we ran the above command up to -e 4095 with a reasonable chunk size (we used 50); see the sketch after the example above.
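For reference, the full range can be covered with a simple loop; below is a minimal sketch, assuming -e is an exclusive end index so that the held-out last file (index 4095) is never downloaded.
# Minimal sketch: download the `all` subset in chunks of 50 files.
# Assumption: `-e` is an exclusive end index, so stopping at 4095 leaves
# the last file (reserved for the held-out test set) untouched.
import subprocess

CHUNK = 50
LAST = 4095  # indices 0..4094 are downloaded; 4095 is held out

for start in range(0, LAST, CHUNK):
    end = min(start + CHUNK, LAST)
    subprocess.run(
        ["python", "reazonspeech_manual_downloader.py",
         "--target", "all", "-p", "100", "-s", str(start), "-e", str(end)],
        check=True,
    )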
The python script run_pseudo_labelling.py is a flexible inference script that can be used to generate pseudo-labels under a range of settings, including both greedy and beam-search decoding. To generate labels from the teacher model on the locally downloaded ReazonSpeech dataset, run the following command, which generates labels for all the audio and uploads them to the Hugging Face Hub in the audio dataset format at the repository specified by --hub_model_id.
accelerate launch run_pseudo_labelling.py \
--model_name_or_path "openai/whisper-large-v3" \
--dataset_name "${PWD}/reazonspeech_manual_dataloader.py" \
--dataset_config_name "tiny" \
--dataset_split_name "train" \
--text_column_name "transcription" \
--id_column_name "name" \
--per_device_eval_batch_size 4 \
--dataloader_num_workers 32 \
--preprocessing_num_workers 32 \
--logging_steps 100 \
--max_label_length 128 \
--language "ja" \
--return_timestamps \
--attn_type "flash_attn" \
--generation_num_beams 1 \
--decode_token_ids False \
--overwrite_output_dir \
--output_dir "output" \
--wandb_project "wandb" \
--hub_model_id "{your-hf-org}/{your-dataset-name}"
Note that we use our custom data loader here, but any Hugging Face audio dataset can be used in the above script.
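For example, assuming the script resolves --dataset_name via datasets.load_dataset (an assumption about its internals), swapping in a Hub dataset amounts to the following; the Common Voice name is purely illustrative.
# Minimal sketch of how --dataset_name / --dataset_config_name map to
# datasets.load_dataset (an assumption about the script's internals).
from datasets import load_dataset

# Our local manual loader script:
ds = load_dataset("./reazonspeech_manual_dataloader.py", "tiny", split="train")

# ...or any Hugging Face audio dataset, e.g. (illustrative only):
# ds = load_dataset("mozilla-foundation/common_voice_8_0", "ja", split="train")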
The original distil-whisper paper proposed filtering the dataset based on the word error rate (WER) between the reference and the predicted transcription, in order to retain high-quality samples for distillation. We follow the same procedure and drop samples whose WER exceeds 10%. The following script takes the dataset with the Whisper labels generated in the previous step, drops samples with WER above 10%, transforms the wave signal into a Mel spectrogram, and uploads the dataset to the Hugging Face Hub in the audio dataset format.
python run_data_filtering.py \
-d "your-hf-org/dataset_name" \
--dataset_config_name "tiny" \
--wer_threshold 10 \
--text_column_name "transcription" \
--preprocessing_num_workers 64 \
--max_label_length 128
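The filtering criterion itself is a simple threshold check; here is a minimal sketch, assuming jiwer for the WER computation (run_data_filtering.py may additionally normalize the text before scoring).
# Minimal sketch of the WER-based filter (illustrative only; the actual
# script may normalize the text before computing the WER).
from jiwer import wer

def keep_sample(reference: str, pseudo_label: str, threshold: float = 0.10) -> bool:
    """Keep a sample only if the WER between the human transcription and
    the teacher's pseudo-label is at or below the threshold (10%)."""
    return wer(reference, pseudo_label) <= threshold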
The script create_student_model.py can be used to initialise a small student model from a large teacher model. When initialising a student model with fewer layers than the teacher, the student is initialised by copying maximally spaced layers from the teacher, as per the DistilBart recommendations. First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the files required to reproduce the training run, alongside model weights, training logs, and a README.md card. You can create the model repository directly on the Hugging Face Hub at https://huggingface.co/new, or via the CLI, as we'll show here.
huggingface-cli repo create {your-hf-org}/{your-model-name}
Let's clone the repository so that we can place our training script and model weights inside:
git lfs install
git clone "https://huggingface.co/{your-hf-org}/{your-model-name}"
We can now copy the relevant training scripts to the repository:
cp create_student_model.py {your-hf-org}/{your-model-name}
cp run_distillation.py {your-hf-org}/{your-model-name}
cd {your-hf-org}/{your-model-name} || exit
The following command demonstrates how to initialise a student model from the Whisper large-v3 checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers 1 and 32 respectively, as these are the maximally spaced layers:
python create_student_model.py \
--teacher_checkpoint "openai/whisper-large-v3" \
--encoder_layers 32 \
--decoder_layers 2 \
--save_dir "{your-hf-org}/{your-model-name}-init"
The initialised model will be saved to the sub-directory specified by --save_dir in our model repository.
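For intuition, the maximally spaced copying described above can be sketched as follows (illustrative only; create_student_model.py is the authoritative implementation).
# Minimal sketch of "maximally spaced" layer selection.
import numpy as np

def spaced_layer_indices(teacher_layers: int, student_layers: int) -> list[int]:
    # Evenly spaced teacher-layer indices, always including the first and last.
    return np.linspace(0, teacher_layers - 1, student_layers, dtype=int).tolist()

print(spaced_layer_indices(32, 2))  # [0, 31] -> teacher layers 1 and 32 (1-indexed)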
The script run_distillation.py is an end-to-end script for loading multiple datasets, a student model, and a teacher model, and performing teacher-student distillation. It uses the loss formulation from the Distil-Whisper paper: a weighted sum of the cross-entropy and KL-divergence loss terms (a sketch of this objective follows the command below). The following command takes the ReazonSpeech dataset that was pseudo-labelled in the first stage and trains the 2-layer-decoder model initialised in the previous step.
accelerate launch run_distillation.py \
--model_name_or_path "{your-hf-org}/{your-model-name}-init" \
--teacher_model_name_or_path "openai/whisper-large-v3" \
--train_dataset_name "{your-hf-org}/{your-dataset-name}.wer_10.0.vectorized" \
--train_dataset_config_name "tiny" \
--language "ja" \
--max_label_length 128 \
--train_split_name "train" \
--save_steps 2500 \
--warmup_steps "50" \
--learning_rate 0.0001 \
--lr_scheduler_type "constant_with_warmup" \
--logging_steps 50 \
--save_total_limit 1 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 2 \
--preprocessing_num_workers 64 \
--dataloader_num_workers 1 \
--dtype "bfloat16" \
--output_dir "./" \
--wandb_project "wandb" \
--gradient_checkpointing \
--freeze_encoder \
--push_to_hub \
--do_train \
--overwrite_output_dir \
--num_train_epochs 8
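For reference, the distillation objective mentioned above can be sketched in PyTorch as follows; the weights and temperature here are assumptions, and run_distillation.py remains the authoritative implementation.
# Minimal sketch of the Distil-Whisper objective: a weighted sum of
# cross-entropy (against pseudo-labels) and KL divergence (student vs.
# teacher). Coefficients and temperature below are assumptions.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,  # (batch, seq, vocab)
                      teacher_logits: torch.Tensor,  # (batch, seq, vocab)
                      labels: torch.Tensor,          # (batch, seq)
                      ce_weight: float = 0.8,
                      kl_weight: float = 1.0,
                      temperature: float = 2.0) -> torch.Tensor:
    # Cross-entropy against the teacher-generated pseudo-labels.
    ce = F.cross_entropy(student_logits.transpose(1, 2), labels, ignore_index=-100)
    # KL divergence between temperature-scaled distributions, rescaled by T^2.
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    return ce_weight * ce + kl_weight * kl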
We evaluate our models with a short-form evaluation on audio samples less than 30s in duration. The script run_short_form_eval.py can be used to run the evaluation on an audio-transcription paired dataset. The following example runs the evaluation on japanese-asr/ja_asr.reazonspeech_test, the held-out test split from ReazonSpeech.
python run_eval_pipeline.py -m "{your-hf-org}/{your-model-name}" -d "japanese-asr/ja_asr.reazonspeech_test"
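The core of this evaluation can be sketched as follows, assuming the test set exposes audio and transcription columns under a test split (the eval script may additionally normalize the text and batch the inference).
# Minimal sketch of short-form CER evaluation (illustrative only).
import evaluate
from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="{your-hf-org}/{your-model-name}")
ds = load_dataset("japanese-asr/ja_asr.reazonspeech_test", split="test")
cer_metric = evaluate.load("cer")

predictions = [
    pipe(sample["audio"], generate_kwargs={"language": "ja", "task": "transcribe"})["text"]
    for sample in ds
]
print(cer_metric.compute(references=ds["transcription"], predictions=predictions))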
While developing the kotoba-whisper models, we experimented with different splits of ReazonSpeech for distillation, and all the models and datasets from this ablation study can be found at https://huggingface.co/japanese-asr. The following tables summarize the WER and CER of the distil-whisper models trained on different sizes of ReazonSpeech against the OpenAI Whisper models (the model names follow distil-whisper-large-v3-ja-reazonspeech-{size of reazonspeech}). Note that kotoba-tech/kotoba-whisper-v1.0 is an alias of japanese-asr/distil-whisper-large-v3-ja-reazonspeech-large, and kotoba-tech/kotoba-whisper-v2.0 is an alias of japanese-asr/distil-whisper-large-v3-ja-reazonspeech-all.
- CER
model | japanese-asr/ja_asr.common_voice_8_0 | japanese-asr/ja_asr.jsut_basic5000 | japanese-asr/ja_asr.reazonspeech_test |
---|---|---|---|
kotoba-tech/kotoba-whisper-v2.1 (punctuator + stable-ts) | 9.3 | 8.4 | 11.3 |
kotoba-tech/kotoba-whisper-v2.1 (punctuator) | 9.3 | 8.4 | 11.3 |
kotoba-tech/kotoba-whisper-v2.1 (stable-ts) | 9.3 | 8.4 | 11.3 |
kotoba-tech/kotoba-whisper-v2.0 | 9.2 | 8.4 | 11.6 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator + stable-ts) | 9.5 | 8.5 | 12.2 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator) | 9.5 | 8.5 | 12.2 |
kotoba-tech/kotoba-whisper-v1.1 (stable-ts) | 9.5 | 8.5 | 12.2 |
kotoba-tech/kotoba-whisper-v1.0 | 9.4 | 8.5 | 12.2 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-medium | 10.9 | 11.3 | 14.7 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-small | 30.3 | 39.1 | 40.8 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-tiny | 94.3 | 96.2 | 96.7 |
openai/whisper-large-v3 | 8.5 | 7.1 | 15.1 |
openai/whisper-large-v2 | 9.7 | 8.2 | 28.5 |
openai/whisper-large | 10 | 8.9 | 34.4 |
openai/whisper-medium | 11.4 | 10 | 33.3 |
openai/whisper-base | 28.2 | 25 | 69.4 |
openai/whisper-small | 15.7 | 14.2 | 40.8 |
openai/whisper-tiny | 58 | 37.6 | 142.2 |
reazon-research/reazonspeech-nemo-v2 | 9.1 | 7.4 | 11.2 |
- WER
model | japanese-asr/ja_asr.common_voice_8_0 | japanese-asr/ja_asr.jsut_basic5000 | japanese-asr/ja_asr.reazonspeech_test |
---|---|---|---|
kotoba-tech/kotoba-whisper-v2.1 (punctuator + stable-ts) | 59.3 | 63.7 | 54.7 |
kotoba-tech/kotoba-whisper-v2.1 (punctuator) | 59.3 | 63.7 | 54.7 |
kotoba-tech/kotoba-whisper-v2.1 (stable-ts) | 59.3 | 63.7 | 54.7 |
kotoba-tech/kotoba-whisper-v2.0 | 58.8 | 63.7 | 55.6 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator + stable-ts) | 59.6 | 64.3 | 55.6 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator) | 59.6 | 64.3 | 55.6 |
kotoba-tech/kotoba-whisper-v1.1 (stable-ts) | 59.6 | 64.3 | 55.6 |
kotoba-tech/kotoba-whisper-v1.0 | 59.3 | 64.4 | 56.5 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-medium | 64.6 | 72.1 | 62.9 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-small | 85 | 94.2 | 82.1 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-tiny | 100 | 100 | 99 |
openai/whisper-large-v3 | 55.3 | 59.2 | 60.3 |
openai/whisper-large-v2 | 59.5 | 63.2 | 74.2 |
openai/whisper-large | 60.9 | 66.5 | 75 |
openai/whisper-medium | 63.4 | 69.6 | 76 |
openai/whisper-base | 87.1 | 92.9 | 91.7 |
openai/whisper-small | 74.2 | 82 | 83.1 |
openai/whisper-tiny | 93.8 | 97.6 | 94.9 |
reazon-research/reazonspeech-nemo-v2 | 57.5 | 60.6 | 47.5 |
- Raw CER
model | japanese-asr/ja_asr.common_voice_8_0 | japanese-asr/ja_asr.jsut_basic5000 | japanese-asr/ja_asr.reazonspeech_test |
---|---|---|---|
kotoba-tech/kotoba-whisper-v2.1 (punctuator + stable-ts) | 13.7 | 11.4 | 17 |
kotoba-tech/kotoba-whisper-v2.1 (punctuator) | 13.8 | 11.6 | 17.3 |
kotoba-tech/kotoba-whisper-v2.1 (stable-ts) | 15.5 | 15.4 | 17 |
kotoba-tech/kotoba-whisper-v2.0 | 15.4 | 15.4 | 17.4 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator + stable-ts) | 13.7 | 11.2 | 17.4 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator) | 13.9 | 11.4 | 18 |
kotoba-tech/kotoba-whisper-v1.1 (stable-ts) | 15.7 | 15 | 17.7 |
kotoba-tech/kotoba-whisper-v1.0 | 15.6 | 15.2 | 17.8 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-medium | 17 | 18.4 | 20.2 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-small | 34.4 | 43.2 | 44.2 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-tiny | 93.7 | 95.1 | 95.6 |
openai/whisper-large-v3 | 12.9 | 13.4 | 20.6 |
openai/whisper-large-v2 | 13.5 | 10.6 | 34.4 |
openai/whisper-large | 14 | 11.2 | 40.4 |
openai/whisper-medium | 15.4 | 13 | 39 |
openai/whisper-base | 31.6 | 26.4 | 74.5 |
openai/whisper-small | 19.8 | 18.8 | 47 |
openai/whisper-tiny | 61.3 | 39.4 | 156.5 |
reazon-research/reazonspeech-nemo-v2 | 12.6 | 10.6 | 15.4 |
- Raw WER
model | japanese-asr/ja_asr.common_voice_8_0 | japanese-asr/ja_asr.jsut_basic5000 | japanese-asr/ja_asr.reazonspeech_test |
---|---|---|---|
kotoba-tech/kotoba-whisper-v2.1 (punctuator + stable-ts) | 87.2 | 85.9 | 81.1 |
kotoba-tech/kotoba-whisper-v2.1 (punctuator) | 87.2 | 86.3 | 81.2 |
kotoba-tech/kotoba-whisper-v2.1 (stable-ts) | 99.8 | 99.5 | 91.4 |
kotoba-tech/kotoba-whisper-v2.0 | 99.6 | 99.4 | 93.4 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator + stable-ts) | 87.2 | 86.2 | 81.1 |
kotoba-tech/kotoba-whisper-v1.1 (punctuator) | 87.3 | 86.4 | 81.3 |
kotoba-tech/kotoba-whisper-v1.1 (stable-ts) | 99.6 | 99.4 | 91.8 |
kotoba-tech/kotoba-whisper-v1.0 | 99.7 | 99.4 | 93.2 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-medium | 99.9 | 99.9 | 94.2 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-small | 100 | 100 | 97.1 |
japanese-asr/distil-whisper-large-v3-ja-reazonspeech-tiny | 100 | 100 | 100 |
openai/whisper-large-v3 | 91.1 | 98.5 | 92.6 |
openai/whisper-large-v2 | 89.2 | 87.3 | 97.5 |
openai/whisper-large | 91.4 | 88.3 | 98.5 |
openai/whisper-medium | 91.7 | 93.3 | 98 |
openai/whisper-base | 97.5 | 97.8 | 99 |
openai/whisper-small | 95.3 | 98.5 | 98.9 |
openai/whisper-tiny | 99.2 | 99.6 | 99.6 |
reazon-research/reazonspeech-nemo-v2 | 85.7 | 91.5 | 73.4 |