diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py
index 2cb476e208..8dd6725365 100755
--- a/egs/aishell/ASR/conformer_ctc/decode.py
+++ b/egs/aishell/ASR/conformer_ctc/decode.py
@@ -366,13 +366,14 @@ def decode_dataset(
     num_cuts = 0
 
-    try:
-        num_batches = len(dl)
-    except TypeError:
-        num_batches = "?"
+    # try:
+    #     num_batches = len(dl)
+    # except TypeError:
+    #     num_batches = "?"
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        batch = batch[0]  # the custom collate_fn returns a list of batch dicts
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
 
@@ -399,9 +400,8 @@
         num_cuts += len(batch["supervisions"]["text"])
 
         if batch_idx % 100 == 0:
-            batch_str = f"{batch_idx}/{num_batches}"
-
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            # batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")
 
     return results
 
@@ -547,20 +547,19 @@ def main():
     test_sets = ["test"]
     test_dls = [test_dl]
 
+    # for test_set, test_dl in zip(test_sets, test_dls):
+    results_dict = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        H=H,
+        lexicon=lexicon,
+        sos_id=sos_id,
+        eos_id=eos_id,
+    )
-    for test_set, test_dl in zip(test_sets, test_dls):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            H=H,
-            lexicon=lexicon,
-            sos_id=sos_id,
-            eos_id=eos_id,
-        )
-
-        save_results(params=params, test_set_name=test_set, results_dict=results_dict)
+    save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)
 
     logging.info("Done!")
diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py
index c2cbe6e3b0..df52cffca3 100755
--- a/egs/aishell/ASR/conformer_ctc/train.py
+++ b/egs/aishell/ASR/conformer_ctc/train.py
@@ -22,9 +22,9 @@
 from shutil import copyfile
 from typing import Optional, Tuple
 
+import os
 import k2
 import torch
-import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import AishellAsrDataModule
 from conformer import Conformer
@@ -543,13 +543,9 @@ def train_one_epoch(
     params.best_train_loss = params.train_loss
 
 
-def run(rank, world_size, args):
+def run(world_size, args):
     """
     Args:
-      rank:
-        It is a value between 0 and `world_size-1`, which is
-        passed automatically by `mp.spawn()` in :func:`main`.
-        The node with rank 0 is responsible for saving checkpoint.
      world_size:
        Number of GPUs for DDP training.
      args:
@@ -560,13 +556,14 @@
     fix_random_seed(params.seed)
 
     if world_size > 1:
-        setup_dist(rank, world_size, params.master_port)
-
+        setup_dist(use_ddp_launch=True, master_port=params.master_port)
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    if local_rank == 0:
+        logging.info(params)
 
-    if args.tensorboard and rank == 0:
+    if args.tensorboard and local_rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None
@@ -577,7 +574,7 @@
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
+        device = torch.device("cuda", local_rank)
 
     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
@@ -603,7 +600,8 @@
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        torch.distributed.barrier()  # sync processes; DDP construction broadcasts rank 0's parameters
+        model = DDP(model, device_ids=[local_rank])
 
     optimizer = Noam(
         model.parameters(),
@@ -629,7 +627,7 @@
             tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
 
-        if rank == 0:
+        if local_rank == 0:
             logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
 
         params.cur_epoch = epoch
@@ -644,12 +642,14 @@
             tb_writer=tb_writer,
             world_size=world_size,
         )
+        if world_size > 1:
+            torch.distributed.barrier()
 
         save_checkpoint(
             params=params,
             model=model,
             optimizer=optimizer,
-            rank=rank,
+            rank=local_rank,
         )
 
     logging.info("Done!")
@@ -668,10 +668,7 @@ def main():
     world_size = args.world_size
     assert world_size >= 1
 
-    if world_size > 1:
-        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
-    else:
-        run(rank=0, world_size=1, args=args)
+    run(world_size=world_size, args=args)
 
 
 torch.set_num_threads(1)
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index aacbd153de..0571dddb78 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -23,6 +23,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
+from lhotse.cut import MonoCut
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
     CutConcatenate,
@@ -180,7 +181,34 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
             help="When enabled, select noise from MUSAN and mix it"
             "with training dataset. ",
         )
 
+    def to_dict(self, obj):
+        """
+        Recursively convert an object and its nested objects to dictionaries.
+        """
+        if isinstance(obj, (str, int, float, bool, type(None))):
+            return obj
+        elif isinstance(obj, list):
+            return [self.to_dict(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {key: self.to_dict(value) for key, value in obj.items()}
+        elif hasattr(obj, "__dict__"):
+            return {key: self.to_dict(value) for key, value in obj.__dict__.items()}
+        else:
+            raise TypeError(f"Unsupported type: {type(obj)}")
+
+    def my_collate_fn(self, batch):
+        """
+        Convert each MonoCut in the batch to a dict; pass dicts through.
+        """
+        return_batch = []
+        for item in batch:
+            if isinstance(item, MonoCut):
+                processed_item = self.to_dict(item)
+                return_batch.append(processed_item)
+            elif isinstance(item, dict):
+                return_batch.append(item)
+        return return_batch
+
     def train_dataloaders(
         self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
     ) -> DataLoader:
@@ -354,9 +382,10 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         )
         test_dl = DataLoader(
             test,
-            batch_size=None,
+            batch_size=100,  # must be a concrete value when using the custom collate_fn
             sampler=sampler,
-            num_workers=self.args.num_workers,
+            num_workers=4,  # larger values slow decoding and may make it hang
+            collate_fn=self.my_collate_fn,
         )
         return test_dl