From 8da9acd7e1cc46a0f62c6198495b2e217de7e195 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 21 Oct 2024 17:10:40 +0800 Subject: [PATCH] libritts/TTS/vits: condition on x-vector speaker embeddings instead of speaker IDs --- .../TTS/local/prepare_tokens_libritts.py | 10 ++ egs/libritts/TTS/vits/train.py | 71 +++++---- egs/libritts/TTS/vits/tts_datamodule.py | 144 ++++++++++++++---- 3 files changed, 170 insertions(+), 55 deletions(-) diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py index 6ac42755e1..e2f160b371 100755 --- a/egs/libritts/TTS/local/prepare_tokens_libritts.py +++ b/egs/libritts/TTS/local/prepare_tokens_libritts.py @@ -31,6 +31,14 @@ from tqdm.auto import tqdm +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + def prepare_tokens_libritts(): output_dir = Path("data/spectrogram") prefix = "libritts" @@ -60,6 +68,8 @@ def prepare_tokens_libritts(): for t in tokens_list: tokens.extend(t) cut.tokens = tokens + cut.supervisions[0].normalized_text = remove_punc_to_upper(text) + new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py index 7058f3d769..1d2870ed84 100755 --- a/egs/libritts/TTS/vits/train.py +++ b/egs/libritts/TTS/vits/train.py @@ -21,7 +21,7 @@ import logging from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import k2 import numpy as np @@ -29,6 +29,7 @@ import torch.multiprocessing as mp import torch.nn as nn from lhotse.cut import Cut +from lhotse.features.io import KaldiReader from lhotse.utils import fix_random_seed from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, 
autocast @@ -331,16 +332,22 @@ def prepare_input( batch: dict, tokenizer: Tokenizer, device: torch.device, - speaker_map: Dict[str, int], + speaker_map: KaldiReader, ): """Parse batch data""" + + def parse_sids(batch: dict) -> List[str]: + return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]] + audio = batch["audio"].to(device) features = batch["features"].to(device) audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) tokens = batch["tokens"] speakers = ( - torch.Tensor([speaker_map[sid] for sid in batch["speakers"]]).int().to(device) + torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)])) + .squeeze(1) + .to(device) ) tokens = tokenizer.tokens_to_token_ids( @@ -366,8 +373,9 @@ def train_one_epoch( scheduler_g: LRSchedulerType, scheduler_d: LRSchedulerType, train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - speaker_map: Dict[str, int], + dev_dl: torch.utils.data.DataLoader, + train_speaker_map: KaldiReader, + dev_speaker_map: KaldiReader, scaler: GradScaler, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -442,7 +450,7 @@ def save_bad_model(suffix: str = ""): tokens, tokens_lens, speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) + ) = prepare_input(batch, tokenizer, device, train_speaker_map) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -457,7 +465,7 @@ def save_bad_model(suffix: str = ""): feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=False, ) for k, v in stats_d.items(): @@ -476,7 +484,7 @@ def save_bad_model(suffix: str = ""): feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=True, return_sample=params.batch_idx_train % params.log_interval == 0, ) @@ -583,8 +591,8 @@ def save_bad_model(suffix: str = ""): params=params, model=model, 
tokenizer=tokenizer, - valid_dl=valid_dl, - speaker_map=speaker_map, + dev_dl=dev_dl, + dev_speaker_map=dev_speaker_map, world_size=world_size, ) model.train() @@ -620,8 +628,8 @@ def compute_validation_loss( params: AttributeDict, model: Union[nn.Module, DDP], tokenizer: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - speaker_map: Dict[str, int], + dev_dl: torch.utils.data.DataLoader, + dev_speaker_map: KaldiReader, world_size: int = 1, rank: int = 0, ) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: @@ -634,7 +642,7 @@ def compute_validation_loss( returned_sample = None with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): + for batch_idx, batch in enumerate(dev_dl): batch_size = len(batch["tokens"]) ( audio, @@ -644,7 +652,7 @@ def compute_validation_loss( tokens, tokens_lens, speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) + ) = prepare_input(batch, tokenizer, device, dev_speaker_map) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -657,7 +665,7 @@ def compute_validation_loss( feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=False, ) assert loss_d.requires_grad is False @@ -672,7 +680,7 @@ def compute_validation_loss( feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=True, ) assert loss_g.requires_grad is False @@ -687,7 +695,7 @@ def compute_validation_loss( inner_model = model.module if isinstance(model, DDP) else model audio_pred, _, duration = inner_model.inference( text=tokens[0, : tokens_lens[0].item()], - sids=speakers[0], + spembs=speakers[0], ) audio_pred = audio_pred.data.cpu().numpy() audio_len_pred = ( @@ -717,7 +725,7 @@ def scan_pessimistic_batches_for_oom( tokenizer: Tokenizer, optimizer_g: torch.optim.Optimizer, optimizer_d: torch.optim.Optimizer, - speaker_map: Dict[str, int], + train_speaker_map: KaldiReader, params: 
AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -737,7 +745,7 @@ def scan_pessimistic_batches_for_oom( tokens, tokens_lens, speakers, - ) = prepare_input(batch, tokenizer, device, speaker_map) + ) = prepare_input(batch, tokenizer, device, train_speaker_map) try: # for discriminator with autocast(enabled=params.use_fp16): @@ -748,7 +756,7 @@ def scan_pessimistic_batches_for_oom( feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=False, ) optimizer_d.zero_grad() @@ -762,7 +770,7 @@ def scan_pessimistic_batches_for_oom( feats_lengths=features_lens, speech=audio, speech_lengths=audio_lens, - sids=speakers, + spembs=speakers, forward_generator=True, ) optimizer_g.zero_grad() @@ -820,9 +828,12 @@ def run(rank, world_size, args): libritts = LibrittsTtsDataModule(args) - train_cuts = libritts.train_cuts() - speaker_map = libritts.speakers() - params.num_spks = len(speaker_map) + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + train_speaker_map = libritts.train_all_shuf_xvector() + else: + train_cuts = libritts.train_clean_460_cuts() + train_speaker_map = libritts.train_clean_460_xvector() logging.info(params) @@ -896,8 +907,9 @@ def remove_short_and_long_utt(c: Cut): train_cuts = train_cuts.filter(remove_short_and_long_utt) train_dl = libritts.train_dataloaders(train_cuts) - valid_cuts = libritts.valid_cuts() - valid_dl = libritts.valid_dataloaders(valid_cuts) + dev_clean_cuts = libritts.dev_clean_cuts() + dev_speaker_map = libritts.dev_clean_xvector() + dev_dl = libritts.dev_dataloaders(dev_clean_cuts) if not params.print_diagnostics: scan_pessimistic_batches_for_oom( @@ -906,7 +918,7 @@ def remove_short_and_long_utt(c: Cut): tokenizer=tokenizer, optimizer_g=optimizer_g, optimizer_d=optimizer_d, - speaker_map=speaker_map, + train_speaker_map=train_speaker_map, params=params, ) @@ -935,8 +947,9 @@ def remove_short_and_long_utt(c: Cut): 
scheduler_g=scheduler_g, scheduler_d=scheduler_d, train_dl=train_dl, - valid_dl=valid_dl, - speaker_map=speaker_map, + dev_dl=dev_dl, + train_speaker_map=train_speaker_map, + dev_speaker_map=dev_speaker_map, scaler=scaler, tb_writer=tb_writer, world_size=world_size, diff --git a/egs/libritts/TTS/vits/tts_datamodule.py b/egs/libritts/TTS/vits/tts_datamodule.py index 05350603b9..cd1379f38c 100644 --- a/egs/libritts/TTS/vits/tts_datamodule.py +++ b/egs/libritts/TTS/vits/tts_datamodule.py @@ -38,6 +38,7 @@ AudioSamples, OnTheFlyFeatures, ) +from lhotse.features.io import KaldiReader from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader @@ -51,8 +52,10 @@ def __init__(self, seed: int): def __call__(self, worker_id: int): fix_random_seed(self.seed + worker_id) + LIBRITTS_SAMPLING_RATE = 24000 + class LibrittsTtsDataModule: """ DataModule for tts experiments. @@ -82,7 +85,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "effective batch sizes, sampling strategies, applied data " "augmentations, etc.", ) - + group.add_argument( + "--full-libri", + type=str2bool, + default=False, + help="""When enabled, use the entire LibriTTS training set. 
+ Otherwise, use the 460h clean subset.""", + ) group.add_argument( "--manifest-dir", type=Path, @@ -90,10 +99,10 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="Path to directory with train/valid/test cuts.", ) group.add_argument( - "--speakers", + "--speaker-embeds", type=Path, - default=Path("data/speakers.txt"), - help="Path to speakers.txt file.", + default=Path("exp/xvector_nnet_1a/"), + help="Path to directory with speaker embeddings.", ) group.add_argument( "--max-duration", @@ -141,7 +150,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--return-cuts", type=str2bool, - default=False, + default=True, help="When enabled, each batch will have the " "field: batch['cut'] with the cuts that " "were used to construct it.", @@ -175,7 +184,7 @@ def train_dataloaders( """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=eval(self.args.input_strategy)(), @@ -191,7 +200,7 @@ def train_dataloaders( use_fft_mag=True, ) train = SpeechSynthesisDataset( - return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -238,7 +247,7 @@ def train_dataloaders( return train_dl - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = LIBRITTS_SAMPLING_RATE @@ -249,7 +258,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: use_fft_mag=True, ) validate = SpeechSynthesisDataset( - return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -257,7 +266,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: ) else: validate = SpeechSynthesisDataset( 
- return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=eval(self.args.input_strategy)(), @@ -290,7 +299,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: use_fft_mag=True, ) test = SpeechSynthesisDataset( - return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -298,7 +307,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: ) else: test = SpeechSynthesisDataset( - return_text=False, + return_text=True, return_tokens=True, return_spk_ids=True, feature_input_strategy=eval(self.args.input_strategy)(), @@ -319,23 +328,106 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: return test_dl @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def train_clean_460_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100 and train-clean-360 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir + / "libritts_cuts_with_tokens_train-clean-460.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return 
load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_clean_460_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for train-clean-460") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp") + ) + + @lru_cache() + def train_clean_100_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_clean_360_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_other_500_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def dev_clean_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for dev-clean") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp") + ) @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + def dev_other_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." 
+ ) @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + def test_clean_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for test-clean") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp") + ) @lru_cache() - def speakers(self) -> Dict[str, int]: - logging.info("About to get speakers") - with open(self.args.speakers) as f: - speakers = {line.strip(): i for i, line in enumerate(f)} - return speakers + def test_other_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + )