diff --git a/README.md b/README.md
index 31e514606c..81cfc03ce7 100644
--- a/README.md
+++ b/README.md
@@ -375,7 +375,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [libricss]: egs/libricss/SURT
 [libriheavy]: egs/libriheavy/ASR
 [mgb2]: egs/mgb2/ASR
-[peoplespeech]: egs/peoplespeech/ASR
+[peoplespeech]: egs/peoples_speech/ASR
 [spgispeech]: egs/spgispeech/ASR
 [voxpopuli]: egs/voxpopuli/ASR
 [xbmu-amdo31]: egs/xbmu-amdo31/ASR
diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md
index 0128b70184..36613db031 100644
--- a/egs/audioset/AT/RESULTS.md
+++ b/egs/audioset/AT/RESULTS.md
@@ -35,16 +35,40 @@ python zipformer/train.py \
   --master-port 13455
 ```
 
+We recommend training the model with the weighted sampler, as it converges
+faster and reaches better performance:
+
+| Model | mAP |
+| ------ | ------- |
+| Zipformer-AT, trained with weighted sampler | 46.6 |
+
-The evaluation command is:
+The training command with the weighted sampler is:
 
 ```bash
-python zipformer/evaluate.py \
-    --epoch 32 \
-    --avg 8 \
-    --exp-dir zipformer/exp_at_as_full \
-    --max-duration 500
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+subset=full
+weighted_sampler=1
+bucket_sampler=0
+lr_epochs=15
+
+python zipformer/train.py \
+    --world-size 4 \
+    --audioset-subset $subset \
+    --num-epochs 120 \
+    --start-epoch 1 \
+    --use-fp16 1 \
+    --num-events 527 \
+    --lr-epochs $lr_epochs \
+    --exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
+    --weighted-sampler $weighted_sampler \
+    --bucketing-sampler $bucket_sampler \
+    --max-duration 1000 \
+    --enable-musan True \
+    --master-port 13452
 ```
 
+The evaluation command is the same as before. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler
+
 
 #### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M
 
@@ -92,4 +116,4 @@ python zipformer/evaluate.py \
     --encoder-unmasked-dim 192,192,192,192,192,192 \
     --exp-dir zipformer/exp_small_at_as_full \
     --max-duration 500
-```
\ No newline at end of file
+```
+""" + +import argparse + +import lhotse +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" + ) + + parser.add_argument( + "--output", + type=str, + required=True, + ) + return parser + + +def main(): + # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py + parser = get_parser() + args = parser.parse_args() + + cuts = load_manifest(args.input_manifest) + + print(f"A total of {len(cuts)} cuts.") + + label_count = [0] * 527 # a total of 527 classes + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + for label in labels: + label_count[label] += 1 + + with open(args.output, "w") as f: + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + weight = 0 + for label in labels: + weight += 1000 / (label_count[label] + 0.01) + f.write(f"{c.id} {weight}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh index f7f73a008c..8beaf2d86a 100755 --- a/egs/audioset/AT/prepare.sh +++ b/egs/audioset/AT/prepare.sh @@ -10,6 +10,7 @@ stage=-1 stop_stage=4 dl_dir=$PWD/download +fbank_dir=data/fbank # we assume that you have your downloaded the AudioSet and placed # it under $dl_dir/audioset, the folder structure should look like @@ -49,7 +50,6 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" - fbank_dir=data/fbank if [! -e $fbank_dir/.balanced.done]; then python local/generate_audioset_manifest.py \ --dataset-dir $dl_dir/audioset \ @@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then touch data/fbank/.musan.done fi fi + +# The following stages are required to do weighted-sampling training +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare for weighted-sampling training" + if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then + lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz + fi + python ./local/compute_weight.py \ + --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \ + --output $fbank_dir/sampling_weights_full.txt +fi diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index ac8671fa61..b7df015390 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -31,6 +31,7 @@ PrecomputedFeatures, SimpleCutSampler, SpecAugment, + WeightedSimpleCutSampler, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -99,6 +100,20 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) + group.add_argument( + "--weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "It cannot be used together with bucketing sampler", + ) + group.add_argument( + "--num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. 
Only be used" + "for weighed sampler", + ) group.add_argument( "--bucketing-sampler", type=str2bool, @@ -295,6 +310,9 @@ def train_dataloaders( ) if self.args.bucketing_sampler: + assert ( + not self.args.weighted_sampler + ), "weighted sampling is not supported in bucket sampler" logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, @@ -304,13 +322,26 @@ def train_dataloaders( drop_last=self.args.drop_last, ) else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - drop_last=self.args.drop_last, - ) + if self.args.weighted_sampler: + # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset" + logging.info("Using weighted SimpleCutSampler") + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) logging.info("About to create train dataloader") if sampler_state_dict is not None: @@ -373,11 +404,9 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = AudioTaggingDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( @@ -397,21 +426,30 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def audioset_train_cuts(self) -> CutSet: logging.info("About to get the audioset training cuts.") - balanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" - ) - if self.args.audioset_subset == "full": - unbalanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" - ) - cuts = CutSet.mux( - balanced_cuts, - unbalanced_cuts, - weights=[20000, 2000000], - stop_early=True, + if not self.args.weighted_sampler: + balanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" ) + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" + ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts else: - cuts = balanced_cuts + # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet" + cuts = load_manifest( + self.args.manifest_dir + / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz" + ) + logging.info(f"Get {len(cuts)} cuts in total.") + return cuts @lru_cache() @@ -420,3 +458,22 @@ def audioset_eval_cuts(self) -> CutSet: return load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in 
AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 2d193030a8..67c7033642 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -789,12 +789,14 @@ def save_bad_model(suffix: str = ""): rank=0, ) + num_samples = 0 for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) params.batch_idx_train += 1 batch_size = batch["inputs"].size(0) + num_samples += batch_size try: with torch.cuda.amp.autocast(enabled=params.use_fp16): @@ -919,6 +921,12 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) + if num_samples > params.num_samples: + logging.info( + f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch" + ) + break + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1032,7 +1040,8 @@ def remove_short_and_long_utt(c: Cut): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) + if not params.weighted_sampler: + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py index 42f392cae4..d7e184d863 100755 --- a/egs/libriheavy/ASR/local/prepare_manifest.py +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str: # Assign text of the supervisions and remove unnecessary entries. def main(): - assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + assert ( + len(sys.argv) == 4 + ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" fname = Path(sys.argv[1]).name oname = Path(sys.argv[2]) / fname + keep_custom_fields = bool(sys.argv[3]) with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: for line in fin: cut = json.loads(line) cut["supervisions"][0]["text"] = simple_cleanup( cut["supervisions"][0]["custom"]["texts"][0] ) - del cut["supervisions"][0]["custom"] - del cut["custom"] + if not keep_custom_fields: + del cut["supervisions"][0]["custom"] + del cut["custom"] fout.write((json.dumps(cut) + "\n").encode()) diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index b0736c98ba..366a1459f4 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES="" # - speech dl_dir=$PWD/download +# If you want to do PromptASR experiments, please set it to True +# as this will keep the texts and pre_text information required for +# the training of PromptASR. +keep_custom_fields=False + . shared/parse_options.sh || exit 1 # vocab size for sentence piece models. @@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then for subset in small medium large dev test_clean test_other; do if [ ! 
diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh
index b0736c98ba..366a1459f4 100755
--- a/egs/libriheavy/ASR/prepare.sh
+++ b/egs/libriheavy/ASR/prepare.sh
@@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES=""
 #  - speech
 dl_dir=$PWD/download
 
+# If you want to run PromptASR experiments, set this to True: it keeps
+# the texts and pre_text information required for training PromptASR.
+keep_custom_fields=False
+
 . shared/parse_options.sh || exit 1
 
 # vocab size for sentence piece models.
@@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   for subset in small medium large dev test_clean test_other; do
     if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
       log "Prepare manifest for subset : ${subset}"
-      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir
+      ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields
     fi
   done
 fi
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 66b147764b..bc7d8a5efb 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -307,6 +307,23 @@ done
 
 To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
 
+We also support training Zipformer with AMP + bf16 (this requires a GPU with bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same commands as before can be used for decoding and exporting the model.**
+
+The AMP + bf16 training command is:
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 50 \
+  --start-epoch 1 \
+  --use-fp16 0 \
+  --use-bf16 1 \
+  --exp-dir zipformer/exp_amp_bf16 \
+  --causal 0 \
+  --full-libri 1 \
+  --max-duration 1000
+```
+
 ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M
 
 The tensorboard log can be found at
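Before launching an AMP + bf16 run it is worth verifying that the GPU actually supports bf16; the assertion added to `zipformer/train.py` further down in this diff performs the same check at startup. A minimal pre-flight check:

```python
# Minimal pre-flight check for --use-bf16 1 (mirrors the startup assertion
# added to zipformer/train.py in this PR).
import torch

assert torch.cuda.is_available(), "no CUDA device visible"
# bf16 is generally available on Ampere (e.g. A100, RTX 30xx) and newer GPUs.
print(torch.cuda.is_bf16_supported())
```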
diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 435a79e7fc..9db4299592 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -120,6 +120,7 @@
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from lhotse import set_caching_enabled
 from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
@@ -296,6 +297,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--skip-scoring",
+        type=str2bool,
+        default=False,
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -455,7 +463,7 @@ def decode_one_batch(
         # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
         hyps = [s.split() for s in hyps]
         key = "ctc-decoding"
-        return {key: hyps}
+        return {key: hyps}  # note: returns words
 
     if params.decoding_method == "attention-decoder-rescoring-no-ngram":
         best_path_dict = rescore_with_attention_decoder_no_ngram(
@@ -492,7 +500,7 @@
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}"  # noqa
         return {key: hyps}
 
     if params.decoding_method in ["1best", "nbest"]:
@@ -500,7 +508,7 @@
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-        key = "no_rescore"
+        key = "no-rescore"
     else:
         best_path = nbest_decoding(
             lattice=lattice,
@@ -508,11 +516,11 @@
             use_double_scores=params.use_double_scores,
             nbest_scale=params.nbest_scale,
         )
-        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+        key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
 
     hyps = get_texts(best_path)
     hyps = [[word_table[i] for i in ids] for ids in hyps]
-    return {key: hyps}
+    return {key: hyps}  # note: returns words, not BPE tokens
 
     assert params.decoding_method in [
         "nbest-rescoring",
@@ -646,7 +654,27 @@ def decode_dataset(
     return results
 
 
-def save_results(
+def save_asr_output(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    """
+    Save text produced by ASR.
+    """
+    for key, results in results_dict.items():
+        recogs_filename = (
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recogs_filename, texts=results)
+        logging.info(f"The transcripts are stored in {recogs_filename}")
+
+
+def save_wer_results(
     params: AttributeDict,
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
@@ -661,32 +689,30 @@ def save_results(
 
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        results = sorted(results)
-        store_transcripts(filename=recog_path, texts=results)
-        if enable_log:
-            logging.info(f"The transcripts are stored in {recog_path}")
-
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -705,6 +731,9 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "ctc-greedy-search", "ctc-decoding", @@ -719,9 +748,9 @@ def main(): params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -730,11 +759,11 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -940,12 +969,19 @@ def main(): G=G, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index df2d555a09..cbfb3728e6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -121,6 +121,7 @@ modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -369,6 +370,14 @@ def get_parser(): modified_beam_search_LODR. 
""", ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + add_model_arguments(parser) return parser @@ -590,21 +599,23 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - return {key: hyps} + return {prefix: hyps} elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" + prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search_lm_rescore", "modified_beam_search_lm_rescore_LODR", @@ -617,10 +628,11 @@ def decode_one_batch( return ans else: if params.has_contexts: - prefix += f"-context-score-{params.context_score}" + prefix += f"_context-score-{params.context_score}" return {prefix: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} def decode_dataset( @@ -707,46 +719,58 @@ def decode_dataset( return results -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -762,6 +786,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "greedy_search", "beam_search", @@ -783,9 +810,9 @@ def main(): params.has_contexts = False if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -794,20 +821,20 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search", "modified_beam_search_LODR", @@ -815,19 +842,19 @@ def main(): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" if "LODR" in params.decoding_method: params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1038,12 +1065,19 @@ def main(): ngram_lm_scale=ngram_lm_scale, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index ed8a0ef0fd..ca3cbf0d59 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -218,7 +218,7 @@ def forward( - encoder_out_lens, A 1-D tensor of shape (N,) """ x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) x = x.permute(1, 0, 2) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 164cc7bfd0..2a40b8d643 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -297,7 +297,7 @@ def forward(ctx, x: Tensor, dim: int): # (presumably) that op does not support float16, and autocast # is enabled. 
if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) + ans = ans.to(torch.get_autocast_gpu_dtype()) ctx.save_for_backward(ans) ctx.x_dtype = x.dtype ctx.dim = dim @@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) s = torch.sigmoid(x - 1.0) @@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1379,7 +1379,7 @@ def forward(ctx, x: Tensor) -> Tensor: d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod @@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function): def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1455,7 +1455,7 @@ def forward(ctx, x: Tensor) -> Tensor: d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 360523b8eb..ebcafbf873 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -43,7 +43,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet +from lhotse import CutSet, set_caching_enabled from streaming_beam_search import ( fast_beam_search_one_best, greedy_search, @@ -76,6 +76,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + parser.add_argument( "--epoch", type=int, @@ -188,6 +195,14 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + + add_model_arguments(parser) return parser @@ -640,46 +655,60 @@ def decode_dataset( return {key: decode_results} -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. 
+ """ for key, results in results_dict.items(): - recog_path = ( + recogs_filename = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_filename, "w") as f: + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( + + wer_filename = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -694,6 +723,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: @@ -706,18 +738,21 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
- params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.label: + params.suffix += f"-{params.label}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -845,12 +880,21 @@ def main(): decoding_graph=decoding_graph, ) - save_results( + + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9b6f4a93aa..9c1c7f5a78 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -521,6 +521,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + add_model_arguments(parser) return parser @@ -1027,7 +1034,9 @@ def save_bad_model(suffix: str = ""): batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,9 +1056,7 @@ def save_bad_model(suffix: str = ""): scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." - ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -1090,7 +1097,7 @@ def save_bad_model(suffix: str = ""): rank=rank, ) - if batch_idx % 100 == 0 and params.use_fp16: + if batch_idx % 100 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
@@ -1109,14 +1116,14 @@ def save_bad_model(suffix: str = ""): if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) if tb_writer is not None: @@ -1128,7 +1135,7 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: + if params.use_autocast: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) @@ -1204,9 +1211,25 @@ def run(rank, world_size, args): params.ctc_loss_scale = 1.0 else: assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, params.attention_decoder_loss_scale + params.ctc_loss_scale, + params.attention_decoder_loss_scale, ) + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + logging.info(params) logging.info("About to create model") @@ -1339,7 +1362,7 @@ def remove_short_and_long_utt(c: Cut): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, _ = compute_loss( params=params, model=model,
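To summarize the precision plumbing this PR adds to `train.py`: `--use-bf16` selects `torch.bfloat16`, `--use-fp16` selects `torch.float16`, and otherwise autocast is disabled and training runs in plain fp32. A standalone sketch of the same flag-to-dtype mapping (`select_precision` is an illustrative helper, not part of the PR):

```python
# Standalone sketch of the flag-to-dtype mapping used in run() above;
# select_precision is an illustrative helper, not part of the PR.
import torch

def select_precision(use_fp16: bool, use_bf16: bool):
    if use_bf16:  # amp + bf16
        assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
        assert not use_fp16, "You can only use either fp16 or bf16"
        return torch.bfloat16, True
    if use_fp16:  # amp + fp16
        return torch.float16, True
    return torch.float32, False  # plain fp32, autocast off

dtype, use_autocast = select_precision(use_fp16=True, use_bf16=False)
with torch.cuda.amp.autocast(enabled=use_autocast, dtype=dtype):
    # Inside this block torch.get_autocast_gpu_dtype() == dtype, which is
    # why scaling.py now casts to it instead of hard-coding float16.
    pass
```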