Modified aishell/ASR/conformer_ctc/train.py to implement multi-machine DDP #1845

Open
wants to merge 3 commits into base: master
39 changes: 19 additions & 20 deletions egs/aishell/ASR/conformer_ctc/decode.py
@@ -366,13 +366,14 @@ def decode_dataset(

num_cuts = 0

try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
# try:
# num_batches = len(dl)
# except TypeError:
# num_batches = "?"

results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
batch = batch[0]
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

@@ -399,9 +400,8 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])

if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
# batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_idx}, cuts processed until now is {num_cuts}")
return results


@@ -547,20 +547,19 @@ def main():

test_sets = ["test"]
test_dls = [test_dl]
# for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
lexicon=lexicon,
sos_id=sos_id,
eos_id=eos_id,
)

for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
lexicon=lexicon,
sos_id=sos_id,
eos_id=eos_id,
)

save_results(params=params, test_set_name=test_set, results_dict=results_dict)
save_results(params=params, test_set_name=test_sets[0], results_dict=results_dict)

logging.info("Done!")

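A note on the new batch = batch[0] line above: the test dataloader defined later in this PR (see asr_datamodule.py below) uses a custom my_collate_fn that returns a list, so each item yielded by dl appears to be a list wrapping the usual batch dict. A minimal sketch of the unwrapping pattern, using a hypothetical hand-built batch in place of real lhotse output:

# Illustrative sketch only, not part of the diff.  The nested dict is a
# hypothetical stand-in for what the dataset/collate_fn pair would produce.
batches = [
    [{"supervisions": {"text": ["<transcript>"], "cut": []}}],  # list-wrapped batch dict
]

for batch_idx, batch in enumerate(batches):
    batch = batch[0]  # unwrap the list produced by the custom collate_fn
    texts = batch["supervisions"]["text"]
    print(batch_idx, texts)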
33 changes: 15 additions & 18 deletions egs/aishell/ASR/conformer_ctc/train.py
@@ -22,9 +22,9 @@
from shutil import copyfile
from typing import Optional, Tuple

import os
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
@@ -543,13 +543,9 @@ def train_one_epoch(
params.best_train_loss = params.train_loss


def run(rank, world_size, args):
def run(world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
@@ -560,13 +556,14 @@ def run(rank, world_size, args):

fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)

setup_dist(use_ddp_launch=True, master_port=params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if local_rank == 0:
logging.info(params)

if args.tensorboard and rank == 0:
if args.tensorboard and local_rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
@@ -577,7 +574,7 @@ def run(rank, world_size, args):

device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
device = torch.device("cuda", local_rank)

graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
@@ -603,7 +600,8 @@

model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
torch.distributed.barrier()  # synchronize all processes before wrapping the model in DDP
model = DDP(model, device_ids=[local_rank])

optimizer = Noam(
model.parameters(),
@@ -629,7 +627,7 @@
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

if rank == 0:
if local_rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))

params.cur_epoch = epoch
@@ -644,12 +642,14 @@
tb_writer=tb_writer,
world_size=world_size,
)
if world_size > 1:
torch.distributed.barrier()

save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
rank=local_rank,
)

logging.info("Done!")
@@ -668,10 +668,7 @@ def main():

world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
run(world_size=world_size, args=args)


torch.set_num_threads(1)
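Since the modified run() no longer receives a rank from mp.spawn() and instead reads LOCAL_RANK from the environment, the script is presumably meant to be started by a launcher such as torchrun, which exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for every process. A minimal, self-contained sketch of that initialization pattern (illustrative only; init_ddp_from_env is not a function in this repository or in this PR):

# Sketch of launcher-driven DDP initialization, assuming the process was
# started by torchrun (which exports RANK, LOCAL_RANK, WORLD_SIZE,
# MASTER_ADDR and MASTER_PORT).  Not part of the PR.
import os

import torch
import torch.distributed as dist


def init_ddp_from_env() -> torch.device:
    """Join the default process group via the env:// rendezvous."""
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    return device


if __name__ == "__main__":
    device = init_ddp_from_env()
    # model = DDP(model.to(device), device_ids=[device.index]) as in train.py
    dist.destroy_process_group()

With this style of launch, each machine would typically run something like torchrun --nnodes=<N> --nproc_per_node=<GPUs per node> --node_rank=<i> --master_addr=<addr> --master_port=<port> conformer_ctc/train.py ... (the exact flags depend on the cluster setup).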
33 changes: 31 additions & 2 deletions egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -23,6 +23,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from lhotse.cut import MonoCut
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
@@ -180,7 +181,34 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def to_dict(self, obj):
"""
Recursively convert an object and its nested objects to dictionaries.
"""
if isinstance(obj, (str, int, float, bool, type(None))):
return obj
elif isinstance(obj, list):
return [self.to_dict(item) for item in obj]
elif isinstance(obj, dict):
return {key: self.to_dict(value) for key, value in obj.items()}
elif hasattr(obj, '__dict__'):
return {key: self.to_dict(value) for key, value in obj.__dict__.items()}
else:
raise TypeError(f"Unsupported type: {type(obj)}")

def my_collate_fn(self, batch):
"""
Convert MonoCut to dict.
"""
return_batch = []
for item in batch:
if isinstance(item, MonoCut):
processed_item = self.to_dict(item)
return_batch.append(processed_item)
elif isinstance(item, dict):
return_batch.append(item)
return return_batch

def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
@@ -354,9 +382,10 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
)
test_dl = DataLoader(
test,
batch_size=None,
batch_size=100,  # set to a fixed value
sampler=sampler,
num_workers=self.args.num_workers,
num_workers=4,  # larger values slow down decoding and may cause it to hang
collate_fn=self.my_collate_fn
)
return test_dl
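For context, a standalone sketch of what the to_dict-style recursive conversion above produces. MiniCut and MiniSupervision are hypothetical stand-ins for lhotse's MonoCut and its supervisions, used only so the example runs without lhotse:

# Illustrative only: how a to_dict-style recursion turns a nested object into
# plain dicts and lists.  MiniCut/MiniSupervision are not real lhotse classes.
from dataclasses import dataclass
from typing import Any, List


@dataclass
class MiniSupervision:
    text: str


@dataclass
class MiniCut:
    id: str
    supervisions: List[MiniSupervision]


def to_dict(obj: Any) -> Any:
    """Recursively convert an object and its nested objects to dictionaries."""
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, list):
        return [to_dict(item) for item in obj]
    if isinstance(obj, dict):
        return {key: to_dict(value) for key, value in obj.items()}
    if hasattr(obj, "__dict__"):
        return {key: to_dict(value) for key, value in obj.__dict__.items()}
    raise TypeError(f"Unsupported type: {type(obj)}")


cut = MiniCut(id="cut-0001", supervisions=[MiniSupervision(text="example transcript")])
print(to_dict(cut))
# -> {'id': 'cut-0001', 'supervisions': [{'text': 'example transcript'}]}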

Expand Down
Loading