diff --git a/egs/libritts/TTS/vocos/discriminators.py b/egs/libritts/TTS/vocos/discriminators.py new file mode 100644 index 0000000000..238b974239 --- /dev/null +++ b/egs/libritts/TTS/vocos/discriminators.py @@ -0,0 +1,296 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import Conv2d +from torch.nn.utils import weight_norm +from torchaudio.transforms import Spectrogram + + +class MultiPeriodDiscriminator(nn.Module): + """ + Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + periods (tuple[int]): Tuple of periods for each discriminator. + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + def __init__( + self, + periods: Tuple[int, ...] = (2, 3, 5, 7, 11), + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.discriminators = nn.ModuleList( + [DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods] + ) + + def forward( + self, + y: torch.Tensor, + y_hat: torch.Tensor, + bandwidth_id: Optional[torch.Tensor] = None, + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period: int, + in_channels: int = 1, + kernel_size: int = 5, + stride: int = 3, + lrelu_slope: float = 0.1, + num_embeddings: Optional[int] = None, + ): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + weight_norm( + Conv2d( + in_channels, + 32, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(kernel_size // 2, 0), + ) + ), + weight_norm( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + (1, 1), + padding=(kernel_size // 2, 0), + ) + ), + ] + ) + if num_embeddings is not None: + self.emb = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=1024 + ) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + self.lrelu_slope = lrelu_slope + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + x = x.unsqueeze(1) + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = torch.nn.functional.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for i, l in enumerate(self.convs): + x = l(x) + x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) + if i > 0: + fmap.append(x) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = 
self.conv_post(x) + fmap.append(x) + x += h + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiResolutionDiscriminator(nn.Module): + def __init__( + self, + fft_sizes: Tuple[int, ...] = (2048, 1024, 512), + num_embeddings: Optional[int] = None, + ): + """ + Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. + Additionally, it allows incorporating conditional information with a learned embeddings table. + + Args: + fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). + num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. + Defaults to None. + """ + + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorR(window_length=w, num_embeddings=num_embeddings) + for w in fft_sizes + ] + ) + + def forward( + self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None + ) -> Tuple[ + List[torch.Tensor], + List[torch.Tensor], + List[List[torch.Tensor]], + List[List[torch.Tensor]], + ]: + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for d in self.discriminators: + y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) + y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorR(nn.Module): + def __init__( + self, + window_length: int, + num_embeddings: Optional[int] = None, + channels: int = 32, + hop_factor: float = 0.25, + bands: Tuple[Tuple[float, float], ...] = ( + (0.0, 0.1), + (0.1, 0.25), + (0.25, 0.5), + (0.5, 0.75), + (0.75, 1.0), + ), + ): + super().__init__() + self.window_length = window_length + self.hop_factor = hop_factor + self.spec_fn = Spectrogram( + n_fft=window_length, + hop_length=int(window_length * hop_factor), + win_length=window_length, + power=None, + ) + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + convs = lambda: nn.ModuleList( + [ + weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) + ), + weight_norm( + nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1)) + ), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + + if num_embeddings is not None: + self.emb = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=channels + ) + torch.nn.init.zeros_(self.emb.weight) + + self.conv_post = weight_norm( + nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)) + ) + + def spectrogram(self, x): + # Remove DC offset + x = x - x.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + x = self.spec_fn(x) + x = torch.view_as_real(x) + # x = rearrange(x, "b f t c -> b c t f") + x = x.permute(0, 3, 2, 1) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): + x_bands = self.spectrogram(x) + fmap = [] + x = [] + for band, stack in zip(x_bands, self.band_convs): + for i, layer in enumerate(stack): + band = layer(band) + band = 
torch.nn.functional.leaky_relu(band, 0.1) + if i > 0: + fmap.append(band) + x.append(band) + x = torch.cat(x, dim=-1) + if cond_embedding_id is not None: + emb = self.emb(cond_embedding_id) + h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) + else: + h = 0 + x = self.conv_post(x) + fmap.append(x) + x += h + + return x, fmap diff --git a/egs/libritts/TTS/vocos/generator.py b/egs/libritts/TTS/vocos/generator.py new file mode 100644 index 0000000000..6e1dcdc4c7 --- /dev/null +++ b/egs/libritts/TTS/vocos/generator.py @@ -0,0 +1,257 @@ +import torch +from torch import nn + +from typing import Optional + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + self.shift = nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__( + self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
+ """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + ) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, :] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze() + + # Normalize + norm_indexes = window_envelope > 1e-11 + y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes] + + return y + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class Generator(torch.nn.Module): + def __init__( + self, + feature_dim: int = 80, + dim: int = 512, + n_fft: int = 1024, + hop_length: int = 256, + intermediate_dim: int = 1536, + num_layers: int = 8, + padding: str = "same", + layer_scale_init_value: Optional[float] = None, + adanorm_num_embeddings: Optional[int] = None, + ): + super(Generator, self).__init__() + self.feature_dim = feature_dim + self.embed = nn.Conv1d(feature_dim, dim, kernel_size=7, padding=3) + + self.adanorm = adanorm_num_embeddings is 
not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + adanorm_num_embeddings=adanorm_num_embeddings, + ) + for _ in range(num_layers) + ] + ) + + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + self.out_proj = torch.nn.Linear(dim, n_fft + 2) + self.istft = ISTFT( + n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding + ) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + bandwidth_id = kwargs.get("bandwidth_id", None) + x = self.embed(x) + if self.adanorm: + assert bandwidth_id is not None + x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) + else: + x = self.norm(x.transpose(1, 2)) + + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, cond_embedding_id=bandwidth_id) + + x = self.final_layer_norm(x.transpose(1, 2)) + + x = self.out_proj(x).transpose(1, 2) + mag, p = x.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip( + mag, max=1e2 + ) # safeguard to prevent excessively large magnitudes + x = torch.cos(p) + y = torch.sin(p) + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio diff --git a/egs/libritts/TTS/vocos/infer.py b/egs/libritts/TTS/vocos/infer.py new file mode 100644 index 0000000000..70e0aa6f00 --- /dev/null +++ b/egs/libritts/TTS/vocos/infer.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Wei Kang +# Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import math +import os +from functools import partial +from pathlib import Path + +import torch +import torch.nn as nn +from lhotse.utils import fix_random_seed +from scipy.io.wavfile import write +from train import add_model_arguments, get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=100, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. 
+ You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="flow_match/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--generate-dir", + type=str, + default="generated_wavs", + help="Path name of the generated wavs", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +): + """ + Args: + params: + It's the return value of :func:`get_params`. + model: + The text-to-feature neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + + cut_ids = [cut.id for cut in batch["cut"]] + + features = batch["features"] # (B, T, F) + utt_durations = batch["features_lens"] + + x = features.permute(0, 2, 1) # (B, F, T) + + audios = model(x.to(device)) # (B, T) + + wav_dir = f"{params.res_dir}/{params.suffix}" + os.makedirs(wav_dir, exist_ok=True) + + for i in range(audios.shape[0]): + audio = audios[i][ + : int(utt_durations[i] * params.frame_shift_ms / 1000 * 22050) + ] + audio = audio.cpu().squeeze().numpy() + write(f"{wav_dir}/{cut_ids[i]}.wav", 22050, audio) + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + test_set: str, +): + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The text-to-feature neural model. + test_set: + The name of the test_set + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
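+        # A descriptive note (not in the original patch): besides the wavs written by
+        # decode_one_batch(), this loop dumps a <test_set>.scp file that maps every
+        # cut id to its reference transcript.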
+ + with open(f"{params.res_dir}/{test_set}.scp", "w", encoding="utf8") as f: + for batch_idx, batch in enumerate(dl): + texts = batch["text"] + cut_ids = [cut.id for cut in batch["cut"]] + + decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + assert len(texts) == len(cut_ids), (len(texts), len(cut_ids)) + + for i in range(len(texts)): + f.write(f"{cut_ids[i]}\t{texts[i]}\n") + + num_cuts += len(texts) + + if batch_idx % 50 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / params.generate_dir + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + params.device = device + + logging.info(f"Device: {device}") + + logging.info(params) + fix_random_seed(666) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} 
(excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model = model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + + test_dl = ljspeech.test_dataloaders(test_cuts) + + test_sets = ["test"] + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + decode_dataset( + dl=test_dl, + params=params, + model=model, + test_set=test_set, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vocos/loss.py b/egs/libritts/TTS/vocos/loss.py new file mode 100644 index 0000000000..1092c6f3db --- /dev/null +++ b/egs/libritts/TTS/vocos/loss.py @@ -0,0 +1,133 @@ +from typing import List, Tuple + +import torch +import torchaudio +from torch import nn + +from utils import safe_log + + +class MelSpecReconstructionLoss(nn.Module): + """ + L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample + """ + + def __init__( + self, + sample_rate: int = 24000, + n_fft: int = 1024, + hop_length: int = 256, + n_mels: int = 100, + ): + super().__init__() + self.mel_spec = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + n_mels=n_mels, + center=True, + power=1, + ) + + def forward(self, y_hat, y) -> torch.Tensor: + """ + Args: + y_hat (Tensor): Predicted audio waveform. + y (Tensor): Ground truth audio waveform. + + Returns: + Tensor: L1 loss between the mel-scaled magnitude spectrograms. + """ + mel_hat = safe_log(self.mel_spec(y_hat)) + mel = safe_log(self.mel_spec(y)) + + loss = torch.nn.functional.l1_loss(mel, mel_hat) + + return loss + + +class GeneratorLoss(nn.Module): + """ + Generator Loss module. Calculates the loss for the generator based on discriminator outputs. + """ + + def forward( + self, disc_outputs: List[torch.Tensor] + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + disc_outputs (List[Tensor]): List of discriminator outputs. + + Returns: + Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from + the sub-discriminators + """ + loss = torch.zeros( + 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype + ) + gen_losses = [] + for dg in disc_outputs: + l = torch.mean(torch.clamp(1 - dg, min=0)) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +class DiscriminatorLoss(nn.Module): + """ + Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. + """ + + def forward( + self, + disc_real_outputs: List[torch.Tensor], + disc_generated_outputs: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. + disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. + + Returns: + Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from + the sub-discriminators for real outputs, and a list of + loss values for generated outputs. 
+ """ + loss = torch.zeros( + 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype + ) + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(torch.clamp(1 - dr, min=0)) + g_loss = torch.mean(torch.clamp(1 + dg, min=0)) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +class FeatureMatchingLoss(nn.Module): + """ + Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. + """ + + def forward( + self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]] + ) -> torch.Tensor: + """ + Args: + fmap_r (List[List[Tensor]]): List of feature maps from real samples. + fmap_g (List[List[Tensor]]): List of feature maps from generated samples. + + Returns: + Tensor: The calculated feature matching loss. + """ + loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss diff --git a/egs/ljspeech/TTS/vocos/models.py b/egs/libritts/TTS/vocos/model.py similarity index 76% rename from egs/ljspeech/TTS/vocos/models.py rename to egs/libritts/TTS/vocos/model.py index 5dbadbad85..30c906ef95 100644 --- a/egs/ljspeech/TTS/vocos/models.py +++ b/egs/libritts/TTS/vocos/model.py @@ -1,8 +1,7 @@ import logging import torch -from backbone import VocosBackbone -from heads import ISTFTHead from discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator +from generator import Generator from loss import ( DiscriminatorLoss, GeneratorLoss, @@ -14,26 +13,23 @@ class Vocos(torch.nn.Module): def __init__( self, + feature_dim: int = 80, dim: int = 512, n_fft: int = 1024, hop_length: int = 256, - feature_dim: int = 80, intermediate_dim: int = 1536, num_layers: int = 8, padding: str = "same", - sample_rate: int = 22050, + sample_rate: int = 24000, ): super(Vocos, self).__init__() - self.backbone = VocosBackbone( - input_channels=feature_dim, - dim=dim, - intermediate_dim=intermediate_dim, - num_layers=num_layers, - ) - self.head = ISTFTHead( + self.generator = Generator( + feature_dim=feature_dim, dim=dim, n_fft=n_fft, hop_length=hop_length, + num_layers=num_layers, + intermediate_dim=intermediate_dim, padding=padding, ) @@ -46,6 +42,5 @@ def __init__( self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate) def forward(self, features: torch.Tensor): - x = self.backbone(features) - audio_output = self.head(x) - return audio_output + audio = self.generator(features) + return audio diff --git a/egs/libritts/TTS/vocos/train.py b/egs/libritts/TTS/vocos/train.py new file mode 100755 index 0000000000..c00afecdb7 --- /dev/null +++ b/egs/libritts/TTS/vocos/train.py @@ -0,0 +1,996 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union +import itertools +import json +import copy +import math +import os +import random +import time + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LibriTTSDataModule + +from torch.optim.lr_scheduler import ExponentialLR, LRScheduler +from torch.optim import Optimizer + +from utils import ( + load_checkpoint, + save_checkpoint, + plot_spectrogram, + get_cosine_schedule_with_warmup, +) + +from icefall import diagnostics +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + get_parameter_groups_with_lrs, +) +from model import Vocos +from lhotse import Fbank, FbankConfig + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-layers", + type=int, + default=8, + help="Number of ConvNeXt layers.", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=512, + help="Hidden dim of ConvNeXt module.", + ) + + parser.add_argument( + "--intermediate-dim", + type=int, + default=1536, + help="Intermediate dim of ConvNeXt module.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=100, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vocos/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--learning-rate", type=float, default=0.0005, help="The learning rate." 
+ ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=500, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--mrd-loss-scale", + type=float, + default=0.1, + help="The scale of MultiResolutionDiscriminator loss.", + ) + + parser.add_argument( + "--mel-loss-scale", + type=float, + default=45, + help="The scale of melspectrogram loss.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 500, + "feature_dim": 80, + "segment_size": 16384, + "adam_b1": 0.8, + "adam_b2": 0.9, + "warmup_steps": 0, + "max_steps": 2000000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_model(params: AttributeDict) -> nn.Module: + device = params.device + model = Vocos( + feature_dim=params.feature_dim, + dim=params.hidden_dim, + n_fft=params.frame_length, + hop_length=params.frame_shift, + intermediate_dim=params.intermediate_dim, + num_layers=params.num_layers, + sample_rate=params.sampling_rate, + ).to(device) + + num_param_gen = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of Generator parameters : {num_param_gen}") + num_param_mpd = sum([p.numel() for p in model.mpd.parameters()]) + logging.info(f"Number of MultiPeriodDiscriminator parameters : {num_param_mpd}") + num_param_mrd = sum([p.numel() for p in model.mrd.parameters()]) + logging.info(f"Number of MultiResolutionDiscriminator parameters : {num_param_mrd}") + logging.info( + f"Number of model parameters : {num_param_gen + num_param_mpd + num_param_mrd}" + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def compute_generator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + audios_hat = model(features) # (B, T) + + mel_loss = model.melspec_loss(audios_hat, audios) + + _, gen_score_mpd, fmap_rs_mpd, fmap_gs_mpd = model.mpd(y=audios, y_hat=audios_hat) + _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = model.mrd(y=audios, y_hat=audios_hat) + + loss_gen_mpd, list_loss_gen_mpd = model.gen_loss(disc_outputs=gen_score_mpd) + loss_gen_mrd, list_loss_gen_mrd = model.gen_loss(disc_outputs=gen_score_mrd) + + loss_gen_mpd = loss_gen_mpd / len(list_loss_gen_mpd) + loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd) + + loss_fm_mpd = model.feat_matching_loss( + fmap_r=fmap_rs_mpd, fmap_g=fmap_gs_mpd + ) / len(fmap_rs_mpd) + loss_fm_mrd = model.feat_matching_loss( + fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd + ) / len(fmap_rs_mrd) + + loss_gen_all = ( + loss_gen_mpd + + params.mrd_loss_scale * loss_gen_mrd + + loss_fm_mpd + + params.mrd_loss_scale * loss_fm_mrd + + params.mel_loss_scale * mel_loss + ) + + assert loss_gen_all.requires_grad == True + + info = MetricsTracker() + info["frames"] = 1 + info["loss_gen"] = loss_gen_all.detach().cpu().item() + info["loss_mel"] = mel_loss.detach().cpu().item() + info["loss_feature_mpd"] = loss_fm_mpd.detach().cpu().item() + info["loss_feature_mrd"] = loss_fm_mrd.detach().cpu().item() + info["loss_gen_mrd"] = loss_gen_mrd.detach().cpu().item() + info["loss_gen_mpd"] = loss_gen_mpd.detach().cpu().item() + + return loss_gen_all, info + + +def compute_discriminator_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + features: Tensor, + audios: Tensor, +) -> Tuple[Tensor, MetricsTracker]: + device = params.device + model = model.module if isinstance(model, DDP) else model + + with torch.no_grad(): + audios_hat = model(features) # (B, 1, T) + + real_score_mpd, gen_score_mpd, _, _ = model.mpd(y=audios, y_hat=audios_hat) + real_score_mrd, gen_score_mrd, _, _ = model.mrd(y=audios, y_hat=audios_hat) + loss_mpd, loss_mpd_real, loss_mpd_gen = model.disc_loss( + disc_real_outputs=real_score_mpd, disc_generated_outputs=gen_score_mpd + ) + loss_mrd, loss_mrd_real, loss_mrd_gen = model.disc_loss( + disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd + ) + loss_mpd /= len(loss_mpd_real) + loss_mrd /= len(loss_mrd_real) + + loss_disc_all = loss_mpd + params.mrd_loss_scale * loss_mrd + + info = MetricsTracker() + # MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. 
+ info["frames"] = 1 + info["loss_disc"] = loss_disc_all.detach().cpu().item() + info["loss_disc_mrd"] = loss_mrd.detach().cpu().item() + info["loss_disc_mpd"] = loss_mpd.detach().cpu().item() + + for i in range(len(loss_mpd_real)): + info[f"loss_disc_mpd_period_{i+1}"] = loss_mpd_real[i] + loss_mpd_gen[i] + for i in range(len(loss_mrd_real)): + info[f"loss_disc_mrd_resolution_{i+1}"] = loss_mrd_real[i] + loss_mrd_gen[i] + + return loss_disc_all, info + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: ExponentialLR, + scheduler_d: ExponentialLR, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + scheduler: + The learning rate scheduler, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = batch["features_lens"].size(0) + + features = batch["features"].to(device) # (B, T, F) + features_lens = batch["features_lens"].to(device) + audios = batch["audio"].to(device) + + segment_frames = ( + params.segment_size - params.frame_length + ) // params.frame_shift + 1 + + start_p = random.randint(0, features_lens.min() - (segment_frames + 1)) + + features = features[:, start_p : start_p + segment_frames, :].permute( + 0, 2, 1 + ) # (B, F, T) + + audios = audios[ + :, + start_p * params.frame_shift : start_p * params.frame_shift + + params.segment_size, + ] # (B, T) + + try: + optimizer_d.zero_grad() + + loss_disc, loss_disc_info = compute_discriminator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_disc.backward() + optimizer_d.step() + + optimizer_g.zero_grad() + loss_gen, loss_gen_info = compute_generator_loss( + params=params, + model=model, + features=features, + audios=audios, + ) + + loss_gen.backward() + optimizer_g.step() + + loss_info = loss_gen_info + loss_disc_info + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_gen_info + + except Exception as e: + logging.info(f"Caught exception : {e}.") + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if 
params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, " + f"cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_gen", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_disc", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + tb_writer=tb_writer, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + scheduler_g.step() + scheduler_d.step() + loss_value = tot_loss["loss_gen"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, + tb_writer: Optional[SummaryWriter] = None, +) -> MetricsTracker: + """Run the validation process.""" + + model.eval() + torch.cuda.empty_cache() + model = model.module if isinstance(model, DDP) else model + device = next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + infer_time = 0 + audio_time = 0 + for batch_idx, batch in enumerate(valid_dl): + features = batch["features"] # (B, T, F) + features_lens = batch["features_lens"] + + audio_time += torch.sum(features_lens) + + x = features.permute(0, 2, 1) # (B, F, T) + y = batch["audio"].to(device) # (B, T) + + start = time.time() + y_g_hat = model(x.to(device)) # (B, T) + infer_time += 
time.time() - start + + if y_g_hat.size(1) > y.size(1): + y = torch.cat( + [ + y, + torch.zeros( + (y.size(0), y_g_hat.size(1) - y.size(1)), device=device + ), + ], + dim=1, + ) + else: + y = y[:, 0 : y_g_hat.size(1)] + + loss_mel_error = model.melspec_loss(y_g_hat, y) + + loss_info = MetricsTracker() + # MetricsTracker will norm the loss value with "frames", set it to 1 here to + # make tot_loss look normal. + loss_info["frames"] = 1 + loss_info["loss_mel_error"] = loss_mel_error.item() + + tot_loss = tot_loss + loss_info + + if batch_idx <= 5 and rank == 0 and tb_writer is not None: + if params.batch_idx_train == params.valid_interval: + tb_writer.add_audio( + "gt/y_{}".format(batch_idx), + y[0], + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "generated/y_hat_{}".format(batch_idx), + y_g_hat[0], + params.batch_idx_train, + params.sampling_rate, + ) + + logging.info(f"RTF : {infer_time / (audio_time * 10 / 1000)}") + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["loss_mel_error"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + torch.autograd.set_detect_anomaly(True) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + params.device = device + logging.info(params) + logging.info("About to create model") + + model = get_model(params) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model = model.to(device) + generator = model.generator + mrd = model.mrd + mpd = model.mpd + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + optimizer_d = torch.optim.AdamW( + itertools.chain(mrd.parameters(), mpd.parameters()), + params.learning_rate, + betas=[params.adam_b1, params.adam_b2], + ) + + scheduler_g = get_cosine_schedule_with_warmup( + optimizer_g, + num_warmup_steps=params.warmup_steps, + num_training_steps=params.max_steps, + ) + scheduler_d = get_cosine_schedule_with_warmup( + optimizer_d, + num_warmup_steps=params.warmup_steps, + num_training_steps=params.max_steps, + ) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading generator optimizer state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading discriminator optimizer state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for 
schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading generator scheduler state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading discriminator scheduler state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libritts = LibriTTSDataModule(args) + + train_cuts = libritts.train_clean_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = libritts.train_dataloaders(train_cuts) + + valid_cuts = libritts.dev_clean_cuts() + valid_dl = libritts.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + if params.batch_idx_train % params.save_every_n == 0: + filename = params.exp_dir / f"checkpoint-{params.batch_idx_train}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriTTSDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + 
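+# Added comment for clarity: limit CPU threading in every process; with DDP processes and
+# dataloader workers running in parallel, larger per-process thread pools tend to
+# oversubscribe the CPU.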
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vocos/tts_datamodule.py b/egs/libritts/TTS/vocos/tts_datamodule.py new file mode 100644 index 0000000000..a65fd28f8b --- /dev/null +++ b/egs/libritts/TTS/vocos/tts_datamodule.py @@ -0,0 +1,419 @@ +# Copyright 2021 Piotr ลปelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriTTSDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--return-text", + type=str2bool, + default=True, + help="Whether to return the text of the audio.", + ) + group.add_argument( + "--return-tokens", + type=str2bool, + default=False, + help="Whether the return the tokens of the text of the audio.", + ) + group.add_argument( + "--num-workers", + type=int, + default=4, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="The sampleing rate of libritts dataset", + ) + + group.add_argument( + "--frame-shift", + type=int, + default=256, + help="Frame shift.", + ) + + group.add_argument( + "--frame-length", + type=int, + default=1024, + help="Frame shift.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--use-fft-mag", + type=str2bool, + default=True, + help="Whether to use magnitude of fbank, false to use power energy.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. 
+ """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + train = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) + validate = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = self.args.sampling_rate + config = FbankConfig( + sampling_rate=sampling_rate, + frame_length=self.args.frame_length / sampling_rate, # (in second), + frame_shift=self.args.frame_shift / sampling_rate, # (in second) + use_fft_mag=self.args.use_fft_mag, + ) 
+ test = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=OnTheFlyFeatures(Fbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=self.args.return_text, + return_tokens=self.args.return_tokens, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def train_clean_cuts(self) -> CutSet: + logging.info("About to get train clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-460.jsonl.gz" + ) + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train clean 100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train clean 360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) + + @lru_cache() + def train_cuts_finetune(self) -> CutSet: + logging.info("About to get train cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train_finetune.jsonl.gz" + ) + + @lru_cache() + def valid_cuts_finetune(self) -> CutSet: + logging.info("About to get validation cuts finetune") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_valid_finetune.jsonl.gz" + ) diff --git a/egs/libritts/TTS/vocos/utils.py b/egs/libritts/TTS/vocos/utils.py new file mode 100644 index 0000000000..c0fb107331 --- /dev/null +++ b/egs/libritts/TTS/vocos/utils.py @@ -0,0 +1,219 @@ +import glob +import os +import logging +import matplotlib +import math +import torch +import torch.nn as nn +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union +from torch.nn.utils import weight_norm +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler +from lhotse.dataset.sampling.base import CutSampler +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim 
import Optimizer +from torch.optim.lr_scheduler import LambdaLR + + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def load_checkpoint( + filename: Path, + model: nn.Module, + model_avg: Optional[nn.Module] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + strict: bool = False, +) -> Dict[str, Any]: + logging.info(f"Loading checkpoint from {filename}") + checkpoint = torch.load(filename, map_location="cpu") + + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + + checkpoint.pop("model") + + if model_avg is not None and "model_avg" in checkpoint: + logging.info("Loading averaged model") + model_avg.load_state_dict(checkpoint["model_avg"], strict=strict) + checkpoint.pop("model_avg") + + def load(name, obj): + s = checkpoint.get(name, None) + if obj and s: + obj.load_state_dict(s) + checkpoint.pop(name) + + load("optimizer_g", optimizer_g) + load("optimizer_d", optimizer_d) + load("scheduler_g", scheduler_g) + load("scheduler_d", scheduler_d) + load("grad_scaler", scaler) + load("sampler", sampler) + + return checkpoint + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRScheduler] = None, + scheduler_d: Optional[LRScheduler] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + params: + User defined parameters, e.g., epoch, loss. + optimizer: + The optimizer to be saved. We only save its `state_dict()`. + scheduler: + The scheduler to be saved. We only save its `state_dict()`. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + rank: + Used in DDP. We save checkpoint only for the node whose rank is 0. + Returns: + Return None. 
+ """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if model_avg is not None: + checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict() + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) + + +def _get_cosine_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float, + min_lr_rate: float = 0.0, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + factor = factor * (1 - min_lr_rate) + min_lr_rate + return max(0, factor) + + +def get_cosine_schedule_with_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. 
+ """ + return torch.log(torch.clip(x, min=clip_val)) diff --git a/egs/ljspeech/TTS/vocos/backbone.py b/egs/ljspeech/TTS/vocos/backbone.py deleted file mode 100644 index 168c8847b1..0000000000 --- a/egs/ljspeech/TTS/vocos/backbone.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional - -import torch -from torch import nn -from torch.nn.utils import weight_norm - -from modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm - - -class Backbone(nn.Module): - """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Args: - x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, - C denotes output features, and L is the sequence length. - - Returns: - Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, - and H denotes the model dimension. - """ - raise NotImplementedError("Subclasses must implement the forward method.") - - -class VocosBackbone(Backbone): - """ - Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization - - Args: - input_channels (int): Number of input features channels. - dim (int): Hidden dimension of the model. - intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. - num_layers (int): Number of ConvNeXtBlock layers. - layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. - adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. - None means non-conditional model. Defaults to None. - """ - - def __init__( - self, - input_channels: int, - dim: int, - intermediate_dim: int, - num_layers: int, - layer_scale_init_value: Optional[float] = None, - adanorm_num_embeddings: Optional[int] = None, - ): - super().__init__() - self.input_channels = input_channels - self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) - self.adanorm = adanorm_num_embeddings is not None - if adanorm_num_embeddings: - self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) - layer_scale_init_value = layer_scale_init_value or 1 / num_layers - self.convnext = nn.ModuleList( - [ - ConvNeXtBlock( - dim=dim, - intermediate_dim=intermediate_dim, - layer_scale_init_value=layer_scale_init_value, - adanorm_num_embeddings=adanorm_num_embeddings, - ) - for _ in range(num_layers) - ] - ) - self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv1d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - nn.init.constant_(m.bias, 0) - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - bandwidth_id = kwargs.get("bandwidth_id", None) - x = self.embed(x) - if self.adanorm: - assert bandwidth_id is not None - x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id) - else: - x = self.norm(x.transpose(1, 2)) - x = x.transpose(1, 2) - for conv_block in self.convnext: - x = conv_block(x, cond_embedding_id=bandwidth_id) - x = self.final_layer_norm(x.transpose(1, 2)) - return x - - -class VocosResNetBackbone(Backbone): - """ - Vocos backbone module built with ResBlocks. - - Args: - input_channels (int): Number of input features channels. - dim (int): Hidden dimension of the model. - num_blocks (int): Number of ResBlock1 blocks. - layer_scale_init_value (float, optional): Initial value for layer scaling. 
Defaults to None. - """ - - def __init__( - self, - input_channels, - dim, - num_blocks, - layer_scale_init_value=None, - ): - super().__init__() - self.input_channels = input_channels - self.embed = weight_norm( - nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) - ) - layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 - self.resnet = nn.Sequential( - *[ - ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) - for _ in range(num_blocks) - ] - ) - - def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - x = self.embed(x) - x = self.resnet(x) - x = x.transpose(1, 2) - return x diff --git a/egs/ljspeech/TTS/vocos/discriminators.py b/egs/ljspeech/TTS/vocos/discriminators.py deleted file mode 100644 index 6b013e392a..0000000000 --- a/egs/ljspeech/TTS/vocos/discriminators.py +++ /dev/null @@ -1,296 +0,0 @@ -from typing import List, Optional, Tuple - -import torch -from einops import rearrange -from torch import nn -from torch.nn import Conv2d -from torch.nn.utils import weight_norm -from torchaudio.transforms import Spectrogram - - -class MultiPeriodDiscriminator(nn.Module): - """ - Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan. - Additionally, it allows incorporating conditional information with a learned embeddings table. - - Args: - periods (tuple[int]): Tuple of periods for each discriminator. - num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. - Defaults to None. - """ - - def __init__( - self, - periods: Tuple[int, ...] = (2, 3, 5, 7, 11), - num_embeddings: Optional[int] = None, - ): - super().__init__() - self.discriminators = nn.ModuleList( - [DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods] - ) - - def forward( - self, - y: torch.Tensor, - y_hat: torch.Tensor, - bandwidth_id: Optional[torch.Tensor] = None, - ) -> Tuple[ - List[torch.Tensor], - List[torch.Tensor], - List[List[torch.Tensor]], - List[List[torch.Tensor]], - ]: - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for d in self.discriminators: - y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) - y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorP(nn.Module): - def __init__( - self, - period: int, - in_channels: int = 1, - kernel_size: int = 5, - stride: int = 3, - lrelu_slope: float = 0.1, - num_embeddings: Optional[int] = None, - ): - super().__init__() - self.period = period - self.convs = nn.ModuleList( - [ - weight_norm( - Conv2d( - in_channels, - 32, - (kernel_size, 1), - (stride, 1), - padding=(kernel_size // 2, 0), - ) - ), - weight_norm( - Conv2d( - 32, - 128, - (kernel_size, 1), - (stride, 1), - padding=(kernel_size // 2, 0), - ) - ), - weight_norm( - Conv2d( - 128, - 512, - (kernel_size, 1), - (stride, 1), - padding=(kernel_size // 2, 0), - ) - ), - weight_norm( - Conv2d( - 512, - 1024, - (kernel_size, 1), - (stride, 1), - padding=(kernel_size // 2, 0), - ) - ), - weight_norm( - Conv2d( - 1024, - 1024, - (kernel_size, 1), - (1, 1), - padding=(kernel_size // 2, 0), - ) - ), - ] - ) - if num_embeddings is not None: - self.emb = torch.nn.Embedding( - num_embeddings=num_embeddings, embedding_dim=1024 - ) - torch.nn.init.zeros_(self.emb.weight) - - self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - self.lrelu_slope = lrelu_slope - - def forward( - self, x: 
torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: - x = x.unsqueeze(1) - fmap = [] - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = torch.nn.functional.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for i, l in enumerate(self.convs): - x = l(x) - x = torch.nn.functional.leaky_relu(x, self.lrelu_slope) - if i > 0: - fmap.append(x) - if cond_embedding_id is not None: - emb = self.emb(cond_embedding_id) - h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) - else: - h = 0 - x = self.conv_post(x) - fmap.append(x) - x += h - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiResolutionDiscriminator(nn.Module): - def __init__( - self, - fft_sizes: Tuple[int, ...] = (2048, 1024, 512), - num_embeddings: Optional[int] = None, - ): - """ - Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. - Additionally, it allows incorporating conditional information with a learned embeddings table. - - Args: - fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). - num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. - Defaults to None. - """ - - super().__init__() - self.discriminators = nn.ModuleList( - [ - DiscriminatorR(window_length=w, num_embeddings=num_embeddings) - for w in fft_sizes - ] - ) - - def forward( - self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None - ) -> Tuple[ - List[torch.Tensor], - List[torch.Tensor], - List[List[torch.Tensor]], - List[List[torch.Tensor]], - ]: - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - - for d in self.discriminators: - y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) - y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) - y_d_rs.append(y_d_r) - fmap_rs.append(fmap_r) - y_d_gs.append(y_d_g) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -class DiscriminatorR(nn.Module): - def __init__( - self, - window_length: int, - num_embeddings: Optional[int] = None, - channels: int = 32, - hop_factor: float = 0.25, - bands: Tuple[Tuple[float, float], ...] 
= ( - (0.0, 0.1), - (0.1, 0.25), - (0.25, 0.5), - (0.5, 0.75), - (0.75, 1.0), - ), - ): - super().__init__() - self.window_length = window_length - self.hop_factor = hop_factor - self.spec_fn = Spectrogram( - n_fft=window_length, - hop_length=int(window_length * hop_factor), - win_length=window_length, - power=None, - ) - n_fft = window_length // 2 + 1 - bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] - self.bands = bands - convs = lambda: nn.ModuleList( - [ - weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), - weight_norm( - nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) - ), - weight_norm( - nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) - ), - weight_norm( - nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4)) - ), - weight_norm( - nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1)) - ), - ] - ) - self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) - - if num_embeddings is not None: - self.emb = torch.nn.Embedding( - num_embeddings=num_embeddings, embedding_dim=channels - ) - torch.nn.init.zeros_(self.emb.weight) - - self.conv_post = weight_norm( - nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)) - ) - - def spectrogram(self, x): - # Remove DC offset - x = x - x.mean(dim=-1, keepdims=True) - # Peak normalize the volume of input audio - x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) - x = self.spec_fn(x) - x = torch.view_as_real(x) - x = rearrange(x, "b f t c -> b c t f") - # Split into bands - x_bands = [x[..., b[0] : b[1]] for b in self.bands] - return x_bands - - def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): - x_bands = self.spectrogram(x) - fmap = [] - x = [] - for band, stack in zip(x_bands, self.band_convs): - for i, layer in enumerate(stack): - band = layer(band) - band = torch.nn.functional.leaky_relu(band, 0.1) - if i > 0: - fmap.append(band) - x.append(band) - x = torch.cat(x, dim=-1) - if cond_embedding_id is not None: - emb = self.emb(cond_embedding_id) - h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) - else: - h = 0 - x = self.conv_post(x) - fmap.append(x) - x += h - - return x, fmap diff --git a/egs/ljspeech/TTS/vocos/discriminators.py b/egs/ljspeech/TTS/vocos/discriminators.py new file mode 120000 index 0000000000..d35e97e6e9 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/discriminators.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/discriminators.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/generator.py b/egs/ljspeech/TTS/vocos/generator.py new file mode 120000 index 0000000000..c5f78e2b11 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/generator.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/generator.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/heads.py b/egs/ljspeech/TTS/vocos/heads.py deleted file mode 100644 index ed4d623a8d..0000000000 --- a/egs/ljspeech/TTS/vocos/heads.py +++ /dev/null @@ -1,178 +0,0 @@ -from typing import Optional - -import torch -from torch import nn -from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz - -from spectral_ops import IMDCT, ISTFT -from modules import symexp - - -class FourierHead(nn.Module): - """Base class for inverse fourier modules.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, - L is the sequence length, and H denotes the model dimension. 
- - Returns: - Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. - """ - raise NotImplementedError("Subclasses must implement the forward method.") - - -class ISTFTHead(FourierHead): - """ - ISTFT Head module for predicting STFT complex coefficients. - - Args: - dim (int): Hidden dimension of the model. - n_fft (int): Size of Fourier transform. - hop_length (int): The distance between neighboring sliding window frames, which should align with - the resolution of the input features. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - """ - - def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): - super().__init__() - out_dim = n_fft + 2 - self.out = torch.nn.Linear(dim, out_dim) - self.istft = ISTFT( - n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the ISTFTHead module. - - Args: - x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, - L is the sequence length, and H denotes the model dimension. - - Returns: - Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. - """ - x = self.out(x).transpose(1, 2) - mag, p = x.chunk(2, dim=1) - mag = torch.exp(mag) - mag = torch.clip( - mag, max=1e2 - ) # safeguard to prevent excessively large magnitudes - # wrapping happens here. These two lines produce real and imaginary value - x = torch.cos(p) - y = torch.sin(p) - # recalculating phase here does not produce anything new - # only costs time - # phase = torch.atan2(y, x) - # S = mag * torch.exp(phase * 1j) - # better directly produce the complex value - S = mag * (x + 1j * y) - audio = self.istft(S) - return audio - - -class IMDCTSymExpHead(FourierHead): - """ - IMDCT Head module for predicting MDCT coefficients with symmetric exponential function - - Args: - dim (int): Hidden dimension of the model. - mdct_frame_len (int): Length of the MDCT frame. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized - based on perceptual scaling. Defaults to None. - clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. - """ - - def __init__( - self, - dim: int, - mdct_frame_len: int, - padding: str = "same", - sample_rate: Optional[int] = None, - clip_audio: bool = False, - ): - super().__init__() - out_dim = mdct_frame_len // 2 - self.out = nn.Linear(dim, out_dim) - self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) - self.clip_audio = clip_audio - - if sample_rate is not None: - # optionally init the last layer following mel-scale - m_max = _hz_to_mel(sample_rate // 2) - m_pts = torch.linspace(0, m_max, out_dim) - f_pts = _mel_to_hz(m_pts) - scale = 1 - (f_pts / f_pts.max()) - - with torch.no_grad(): - self.out.weight.mul_(scale.view(-1, 1)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the IMDCTSymExpHead module. - - Args: - x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, - L is the sequence length, and H denotes the model dimension. - - Returns: - Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 
- """ - x = self.out(x) - x = symexp(x) - x = torch.clip( - x, min=-1e2, max=1e2 - ) # safeguard to prevent excessively large magnitudes - audio = self.imdct(x) - if self.clip_audio: - audio = torch.clip(x, min=-1.0, max=1.0) - - return audio - - -class IMDCTCosHead(FourierHead): - """ - IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) ยท cos(p) - - Args: - dim (int): Hidden dimension of the model. - mdct_frame_len (int): Length of the MDCT frame. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False. - """ - - def __init__( - self, - dim: int, - mdct_frame_len: int, - padding: str = "same", - clip_audio: bool = False, - ): - super().__init__() - self.clip_audio = clip_audio - self.out = nn.Linear(dim, mdct_frame_len) - self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the IMDCTCosHead module. - - Args: - x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, - L is the sequence length, and H denotes the model dimension. - - Returns: - Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. - """ - x = self.out(x) - m, p = x.chunk(2, dim=2) - m = torch.exp(m).clip( - max=1e2 - ) # safeguard to prevent excessively large magnitudes - audio = self.imdct(m * torch.cos(p)) - if self.clip_audio: - audio = torch.clip(x, min=-1.0, max=1.0) - return audio diff --git a/egs/ljspeech/TTS/vocos/loss.py b/egs/ljspeech/TTS/vocos/loss.py deleted file mode 100644 index c89d818349..0000000000 --- a/egs/ljspeech/TTS/vocos/loss.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import List, Tuple - -import torch -import torchaudio -from torch import nn - -from modules import safe_log - - -class MelSpecReconstructionLoss(nn.Module): - """ - L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample - """ - - def __init__( - self, - sample_rate: int = 24000, - n_fft: int = 1024, - hop_length: int = 256, - n_mels: int = 100, - ): - super().__init__() - self.mel_spec = torchaudio.transforms.MelSpectrogram( - sample_rate=sample_rate, - n_fft=n_fft, - hop_length=hop_length, - n_mels=n_mels, - center=True, - power=1, - ) - - def forward(self, y_hat, y) -> torch.Tensor: - """ - Args: - y_hat (Tensor): Predicted audio waveform. - y (Tensor): Ground truth audio waveform. - - Returns: - Tensor: L1 loss between the mel-scaled magnitude spectrograms. - """ - mel_hat = safe_log(self.mel_spec(y_hat)) - mel = safe_log(self.mel_spec(y)) - - loss = torch.nn.functional.l1_loss(mel, mel_hat) - - return loss - - -class GeneratorLoss(nn.Module): - """ - Generator Loss module. Calculates the loss for the generator based on discriminator outputs. - """ - - def forward( - self, disc_outputs: List[torch.Tensor] - ) -> Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Args: - disc_outputs (List[Tensor]): List of discriminator outputs. 
- - Returns: - Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from - the sub-discriminators - """ - loss = torch.zeros( - 1, device=disc_outputs[0].device, dtype=disc_outputs[0].dtype - ) - gen_losses = [] - for dg in disc_outputs: - l = torch.mean(torch.clamp(1 - dg, min=0)) - gen_losses.append(l) - loss += l - - return loss, gen_losses - - -class DiscriminatorLoss(nn.Module): - """ - Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. - """ - - def forward( - self, - disc_real_outputs: List[torch.Tensor], - disc_generated_outputs: List[torch.Tensor], - ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: - """ - Args: - disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. - disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. - - Returns: - Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from - the sub-discriminators for real outputs, and a list of - loss values for generated outputs. - """ - loss = torch.zeros( - 1, device=disc_real_outputs[0].device, dtype=disc_real_outputs[0].dtype - ) - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean(torch.clamp(1 - dr, min=0)) - g_loss = torch.mean(torch.clamp(1 + dg, min=0)) - loss += r_loss + g_loss - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - - return loss, r_losses, g_losses - - -class FeatureMatchingLoss(nn.Module): - """ - Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators. - """ - - def forward( - self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]] - ) -> torch.Tensor: - """ - Args: - fmap_r (List[List[Tensor]]): List of feature maps from real samples. - fmap_g (List[List[Tensor]]): List of feature maps from generated samples. - - Returns: - Tensor: The calculated feature matching loss. - """ - loss = torch.zeros(1, device=fmap_r[0][0].device, dtype=fmap_r[0][0].dtype) - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl - gl)) - - return loss diff --git a/egs/ljspeech/TTS/vocos/loss.py b/egs/ljspeech/TTS/vocos/loss.py new file mode 120000 index 0000000000..1fda940229 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/loss.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/loss.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/model.py b/egs/ljspeech/TTS/vocos/model.py new file mode 120000 index 0000000000..0d18482486 --- /dev/null +++ b/egs/ljspeech/TTS/vocos/model.py @@ -0,0 +1 @@ +../../../libritts/TTS/vocos/model.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vocos/modules.py b/egs/ljspeech/TTS/vocos/modules.py deleted file mode 100644 index af1d6db16e..0000000000 --- a/egs/ljspeech/TTS/vocos/modules.py +++ /dev/null @@ -1,213 +0,0 @@ -from typing import Optional, Tuple - -import torch -from torch import nn -from torch.nn.utils import weight_norm, remove_weight_norm - - -class ConvNeXtBlock(nn.Module): - """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. - - Args: - dim (int): Number of input channels. - intermediate_dim (int): Dimensionality of the intermediate layer. - layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. - Defaults to None. 
- adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. - None means non-conditional LayerNorm. Defaults to None. - """ - - def __init__( - self, - dim: int, - intermediate_dim: int, - layer_scale_init_value: float, - adanorm_num_embeddings: Optional[int] = None, - ): - super().__init__() - self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.adanorm = adanorm_num_embeddings is not None - if adanorm_num_embeddings: - self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) - else: - self.norm = nn.LayerNorm(dim, eps=1e-6) - self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers - self.act = nn.GELU() - self.pwconv2 = nn.Linear(intermediate_dim, dim) - self.gamma = ( - nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) - if layer_scale_init_value > 0 - else None - ) - - def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: - residual = x - x = self.dwconv(x) - x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) - if self.adanorm: - assert cond_embedding_id is not None - x = self.norm(x, cond_embedding_id) - else: - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) - if self.gamma is not None: - x = self.gamma * x - x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) - - x = residual + x - return x - - -class AdaLayerNorm(nn.Module): - """ - Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes - - Args: - num_embeddings (int): Number of embeddings. - embedding_dim (int): Dimension of the embeddings. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.dim = embedding_dim - self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - torch.nn.init.ones_(self.scale.weight) - torch.nn.init.zeros_(self.shift.weight) - - def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: - scale = self.scale(cond_embedding_id) - shift = self.shift(cond_embedding_id) - x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) - x = x * scale + shift - return x - - -class ResBlock1(nn.Module): - """ - ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, - but without upsampling layers. - - Args: - dim (int): Number of input channels. - kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. - dilation (tuple[int], optional): Dilation factors for the dilated convolutions. - Defaults to (1, 3, 5). - lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. - Defaults to 0.1. - layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. - Defaults to None. 
- """ - - def __init__( - self, - dim: int, - kernel_size: int = 3, - dilation: Tuple[int, int, int] = (1, 3, 5), - lrelu_slope: float = 0.1, - layer_scale_init_value: Optional[float] = None, - ): - super().__init__() - self.lrelu_slope = lrelu_slope - self.convs1 = nn.ModuleList( - [ - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[0], - padding=self.get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[1], - padding=self.get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - nn.Conv1d( - dim, - dim, - kernel_size, - 1, - dilation=dilation[2], - padding=self.get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - - self.convs2 = nn.ModuleList( - [ - weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), - weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), - weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), - ] - ) - - self.gamma = nn.ParameterList( - [ - nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) - if layer_scale_init_value is not None - else None, - nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) - if layer_scale_init_value is not None - else None, - nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) - if layer_scale_init_value is not None - else None, - ] - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): - xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) - xt = c1(xt) - xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) - xt = c2(xt) - if gamma is not None: - xt = gamma * xt - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_weight_norm(l) - for l in self.convs2: - remove_weight_norm(l) - - @staticmethod - def get_padding(kernel_size: int, dilation: int = 1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: - """ - Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. - - Args: - x (Tensor): Input tensor. - clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. - - Returns: - Tensor: Element-wise logarithm of the input tensor with clipping applied. - """ - return torch.log(torch.clip(x, min=clip_val)) - - -def symlog(x: torch.Tensor) -> torch.Tensor: - return torch.sign(x) * torch.log1p(x.abs()) - - -def symexp(x: torch.Tensor) -> torch.Tensor: - return torch.sign(x) * (torch.exp(x.abs()) - 1) diff --git a/egs/ljspeech/TTS/vocos/spectral_ops.py b/egs/ljspeech/TTS/vocos/spectral_ops.py deleted file mode 100644 index c0ad35ab31..0000000000 --- a/egs/ljspeech/TTS/vocos/spectral_ops.py +++ /dev/null @@ -1,230 +0,0 @@ -import numpy as np -import scipy -import torch -from torch import nn, view_as_real, view_as_complex - - -class ISTFT(nn.Module): - """ - Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with - windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. - See issue: https://github.com/pytorch/pytorch/issues/62323 - Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 
- The NOLA constraint is met as we trim padded samples anyway. - - Args: - n_fft (int): Size of Fourier transform. - hop_length (int): The distance between neighboring sliding window frames. - win_length (int): The size of window frame and STFT filter. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - """ - - def __init__( - self, n_fft: int, hop_length: int, win_length: int, padding: str = "same" - ): - super().__init__() - if padding not in ["center", "same"]: - raise ValueError("Padding must be 'center' or 'same'.") - self.padding = padding - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - window = torch.hann_window(win_length) - self.register_buffer("window", window) - - def forward(self, spec: torch.Tensor) -> torch.Tensor: - """ - Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. - - Args: - spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, - N is the number of frequency bins, and T is the number of time frames. - - Returns: - Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. - """ - if self.padding == "center": - # Fallback to pytorch native implementation - return torch.istft( - spec, - self.n_fft, - self.hop_length, - self.win_length, - self.window, - center=True, - ) - elif self.padding == "same": - # return torch.istft( - # spec, - # self.n_fft, - # self.hop_length, - # self.win_length, - # self.window, - # center=False, - # ) - pad = (self.win_length - self.hop_length) // 2 - else: - raise ValueError("Padding must be 'center' or 'same'.") - - assert spec.dim() == 3, "Expected a 3D tensor as input" - B, N, T = spec.shape - - # Inverse FFT - ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") - ifft = ifft * self.window[None, :, None] - - # Overlap and Add - output_size = (T - 1) * self.hop_length + self.win_length - y = torch.nn.functional.fold( - ifft, - output_size=(1, output_size), - kernel_size=(1, self.win_length), - stride=(1, self.hop_length), - )[:, 0, 0, :] - - # Window envelope - window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) - window_envelope = torch.nn.functional.fold( - window_sq, - output_size=(1, output_size), - kernel_size=(1, self.win_length), - stride=(1, self.hop_length), - ).squeeze() - - # Normalize - norm_indexes = window_envelope > 1e-11 - - y[:, norm_indexes] = y[:, norm_indexes] / window_envelope[norm_indexes] - # assert (window_envelope > 1e-11).all() - # y = y / window_envelope - - return y - - -class MDCT(nn.Module): - """ - Modified Discrete Cosine Transform (MDCT) module. - - Args: - frame_len (int): Length of the MDCT frame. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
- """ - - def __init__(self, frame_len: int, padding: str = "same"): - super().__init__() - if padding not in ["center", "same"]: - raise ValueError("Padding must be 'center' or 'same'.") - self.padding = padding - self.frame_len = frame_len - N = frame_len // 2 - n0 = (N + 1) / 2 - window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() - self.register_buffer("window", window) - - pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len) - post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N) - # view_as_real: NCCL Backend does not support ComplexFloat data type - # https://github.com/pytorch/pytorch/issues/71613 - self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) - self.register_buffer("post_twiddle", view_as_real(post_twiddle)) - - def forward(self, audio: torch.Tensor) -> torch.Tensor: - """ - Apply the Modified Discrete Cosine Transform (MDCT) to the input audio. - - Args: - audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size - and T is the length of the audio. - - Returns: - Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames - and N is the number of frequency bins. - """ - if self.padding == "center": - audio = torch.nn.functional.pad( - audio, (self.frame_len // 2, self.frame_len // 2) - ) - elif self.padding == "same": - # hop_length is 1/2 frame_len - audio = torch.nn.functional.pad( - audio, (self.frame_len // 4, self.frame_len // 4) - ) - else: - raise ValueError("Padding must be 'center' or 'same'.") - - x = audio.unfold(-1, self.frame_len, self.frame_len // 2) - N = self.frame_len // 2 - x = x * self.window.expand(x.shape) - X = torch.fft.fft( - x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1 - )[..., :N] - res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N) - return torch.real(res) * np.sqrt(2) - - -class IMDCT(nn.Module): - """ - Inverse Modified Discrete Cosine Transform (IMDCT) module. - - Args: - frame_len (int): Length of the MDCT frame. - padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". - """ - - def __init__(self, frame_len: int, padding: str = "same"): - super().__init__() - if padding not in ["center", "same"]: - raise ValueError("Padding must be 'center' or 'same'.") - self.padding = padding - self.frame_len = frame_len - N = frame_len // 2 - n0 = (N + 1) / 2 - window = torch.from_numpy(scipy.signal.cosine(frame_len)).float() - self.register_buffer("window", window) - - pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N) - post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2)) - self.register_buffer("pre_twiddle", view_as_real(pre_twiddle)) - self.register_buffer("post_twiddle", view_as_real(post_twiddle)) - - def forward(self, X: torch.Tensor) -> torch.Tensor: - """ - Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients. - - Args: - X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size, - L is the number of frames, and N is the number of frequency bins. - - Returns: - Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio. 
- """ - B, L, N = X.shape - Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device) - Y[..., :N] = X - Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,))) - y = torch.fft.ifft( - Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1 - ) - y = ( - torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) - * np.sqrt(N) - * np.sqrt(2) - ) - result = y * self.window.expand(y.shape) - output_size = (1, (L + 1) * N) - audio = torch.nn.functional.fold( - result.transpose(1, 2), - output_size=output_size, - kernel_size=(1, self.frame_len), - stride=(1, self.frame_len // 2), - )[:, 0, 0, :] - - if self.padding == "center": - pad = self.frame_len // 2 - elif self.padding == "same": - pad = self.frame_len // 4 - else: - raise ValueError("Padding must be 'center' or 'same'.") - - audio = audio[:, pad:-pad] - return audio diff --git a/egs/ljspeech/TTS/vocos/train.py b/egs/ljspeech/TTS/vocos/train.py index e2092096e4..51ec024efb 100755 --- a/egs/ljspeech/TTS/vocos/train.py +++ b/egs/ljspeech/TTS/vocos/train.py @@ -423,7 +423,7 @@ def compute_generator_loss( fmap_r=fmap_rs_mpd, fmap_g=fmap_gs_mpd ) / len(fmap_rs_mpd) loss_fm_mrd = model.feat_matching_loss( - fmap_r=fmap_gs_mrd, fmap_g=fmap_gs_mrd + fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd ) / len(fmap_rs_mrd) loss_gen_all = ( diff --git a/egs/ljspeech/TTS/vocos/utils.py b/egs/ljspeech/TTS/vocos/utils.py deleted file mode 100644 index c8132e208d..0000000000 --- a/egs/ljspeech/TTS/vocos/utils.py +++ /dev/null @@ -1,205 +0,0 @@ -import glob -import os -import logging -import matplotlib -import math -import torch -import torch.nn as nn -from functools import partial -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union -from torch.nn.utils import weight_norm -from torch.optim.lr_scheduler import LRScheduler -from torch.optim import Optimizer -from torch.cuda.amp import GradScaler -from lhotse.dataset.sampling.base import CutSampler -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LambdaLR - - -matplotlib.use("Agg") -import matplotlib.pylab as plt - - -def plot_spectrogram(spectrogram): - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - - fig.canvas.draw() - plt.close() - - return fig - - -def load_checkpoint( - filename: Path, - model: nn.Module, - model_avg: Optional[nn.Module] = None, - optimizer_g: Optional[Optimizer] = None, - optimizer_d: Optional[Optimizer] = None, - scheduler_g: Optional[LRScheduler] = None, - scheduler_d: Optional[LRScheduler] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - strict: bool = False, -) -> Dict[str, Any]: - logging.info(f"Loading checkpoint from {filename}") - checkpoint = torch.load(filename, map_location="cpu") - - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - - checkpoint.pop("model") - - if model_avg is not None and "model_avg" in checkpoint: - logging.info("Loading averaged model") - 
-        model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
-        checkpoint.pop("model_avg")
-
-    def load(name, obj):
-        s = checkpoint.get(name, None)
-        if obj and s:
-            obj.load_state_dict(s)
-            checkpoint.pop(name)
-
-    load("optimizer_g", optimizer_g)
-    load("optimizer_d", optimizer_d)
-    load("scheduler_g", scheduler_g)
-    load("scheduler_d", scheduler_d)
-    load("grad_scaler", scaler)
-    load("sampler", sampler)
-
-    return checkpoint
-
-
-def save_checkpoint(
-    filename: Path,
-    model: Union[nn.Module, DDP],
-    model_avg: Optional[nn.Module] = None,
-    params: Optional[Dict[str, Any]] = None,
-    optimizer_g: Optional[Optimizer] = None,
-    optimizer_d: Optional[Optimizer] = None,
-    scheduler_g: Optional[LRScheduler] = None,
-    scheduler_d: Optional[LRScheduler] = None,
-    scaler: Optional[GradScaler] = None,
-    sampler: Optional[CutSampler] = None,
-    rank: int = 0,
-) -> None:
-    """Save training information to a file.
-
-    Args:
-      filename:
-        The checkpoint filename.
-      model:
-        The model to be saved. We only save its `state_dict()`.
-      model_avg:
-        The stored model averaged from the start of training.
-      params:
-        User defined parameters, e.g., epoch, loss.
-      optimizer:
-        The optimizer to be saved. We only save its `state_dict()`.
-      scheduler:
-        The scheduler to be saved. We only save its `state_dict()`.
-      scalar:
-        The GradScaler to be saved. We only save its `state_dict()`.
-      rank:
-        Used in DDP. We save checkpoint only for the node whose rank is 0.
-    Returns:
-      Return None.
-    """
-    if rank != 0:
-        return
-
-    logging.info(f"Saving checkpoint to {filename}")
-
-    if isinstance(model, DDP):
-        model = model.module
-
-    checkpoint = {
-        "model": model.state_dict(),
-        "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
-        "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
-        "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
-        "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
-        "grad_scaler": scaler.state_dict() if scaler is not None else None,
-        "sampler": sampler.state_dict() if sampler is not None else None,
-    }
-
-    if model_avg is not None:
-        checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
-
-    if params:
-        for k, v in params.items():
-            assert k not in checkpoint
-            checkpoint[k] = v
-
-    torch.save(checkpoint, filename)
-
-
-def _get_cosine_schedule_with_warmup_lr_lambda(
-    current_step: int,
-    *,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    num_cycles: float,
-    min_lr_rate: float = 0.0,
-):
-    if current_step < num_warmup_steps:
-        return float(current_step) / float(max(1, num_warmup_steps))
-    progress = float(current_step - num_warmup_steps) / float(
-        max(1, num_training_steps - num_warmup_steps)
-    )
-    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
-    factor = factor * (1 - min_lr_rate) + min_lr_rate
-    return max(0, factor)
-
-
-def get_cosine_schedule_with_warmup(
-    optimizer: Optimizer,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    num_cycles: float = 0.5,
-    last_epoch: int = -1,
-):
-    """
-    Create a schedule with a learning rate that decreases following the values of the cosine function between the
-    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
-    initial lr set in the optimizer.
-
-    Args:
-        optimizer ([`~torch.optim.Optimizer`]):
-            The optimizer for which to schedule the learning rate.
-        num_warmup_steps (`int`):
-            The number of steps for the warmup phase.
-        num_training_steps (`int`):
-            The total number of training steps.
-        num_cycles (`float`, *optional*, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-
-    Return:
-        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-
-    lr_lambda = partial(
-        _get_cosine_schedule_with_warmup_lr_lambda,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_training_steps,
-        num_cycles=num_cycles,
-    )
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
diff --git a/egs/ljspeech/TTS/vocos/utils.py b/egs/ljspeech/TTS/vocos/utils.py
new file mode 120000
index 0000000000..789a8b72b0
--- /dev/null
+++ b/egs/ljspeech/TTS/vocos/utils.py
@@ -0,0 +1 @@
+../../../libritts/TTS/vocos/utils.py
\ No newline at end of file
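
A few notes on the pieces touched above, with small usage sketches that are not part of the patch.

The MDCT/IMDCT modules that appear in the removed lines form a critically sampled analysis/synthesis pair with a half-frame hop: MDCT maps a (B, T) waveform to (B, L, N) coefficients with N = frame_len // 2, and IMDCT folds the frames back with 50% overlap-add. A minimal round-trip sketch, assuming the classes remain importable from a module named spectral_ops (the module name follows upstream Vocos and is an assumption here; adjust the import to wherever this recipe keeps them):

import torch

# Assumption: the MDCT/IMDCT classes shown above live in spectral_ops.py,
# as in upstream Vocos.
from spectral_ops import MDCT, IMDCT

frame_len = 1024                        # analysis window length; hop is frame_len // 2
mdct = MDCT(frame_len, padding="same")
imdct = IMDCT(frame_len, padding="same")

audio = torch.randn(2, 16000)           # (B, T) waveform batch
coeffs = mdct(audio)                    # (B, L, N) with N = frame_len // 2
recon = imdct(coeffs)                   # (B, T') with T' close to T for "same" padding

# With this window and 50% overlap the pair is designed for near-perfect
# reconstruction, so the round trip should recover the input up to boundary effects.
print(coeffs.shape, recon.shape)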
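
The only functional change in train.py replaces fmap_r=fmap_gs_mrd with fmap_r=fmap_rs_mrd. Before the fix, the multi-resolution-discriminator feature-matching term compared the generated feature maps with themselves, so it evaluated to zero and contributed no gradient. A self-contained sketch of the pattern (assuming a conventional L1 feature-matching loss; the repo's model.feat_matching_loss may differ in detail):

import torch


def feat_matching_loss(fmap_r, fmap_g):
    # HiFi-GAN-style feature matching: mean absolute difference between real and
    # generated feature maps, summed over discriminators and layers.
    loss = 0.0
    for dr, dg in zip(fmap_r, fmap_g):
        for r, g in zip(dr, dg):
            loss = loss + torch.mean(torch.abs(r - g))
    return loss


fmap_real = [[torch.randn(2, 8, 16)]]   # feature maps from real audio
fmap_fake = [[torch.randn(2, 8, 16)]]   # feature maps from generated audio

print(feat_matching_loss(fmap_fake, fmap_fake))  # old call pattern: always 0
print(feat_matching_loss(fmap_real, fmap_fake))  # fixed call pattern: non-zero signal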
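
The ljspeech utils.py is deleted and replaced by a symlink to the libritts copy, so save_checkpoint/load_checkpoint keep the signatures shown in the removed lines. A hedged usage sketch (the file name and params below are illustrative only):

import torch
from pathlib import Path

# Assumption: utils resolves through the new symlink to the shared libritts copy.
from utils import load_checkpoint, save_checkpoint

model = torch.nn.Linear(4, 4)
optimizer_g = torch.optim.AdamW(model.parameters(), lr=2e-4)

# Only rank 0 writes the file; other ranks return immediately.
save_checkpoint(
    Path("checkpoint-10.pt"),
    model=model,
    params={"cur_epoch": 10},
    optimizer_g=optimizer_g,
    rank=0,
)

# Loading restores the state dicts in place and returns the remaining user params.
info = load_checkpoint(Path("checkpoint-10.pt"), model=model, optimizer_g=optimizer_g)
print(info["cur_epoch"])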
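
get_cosine_schedule_with_warmup (also reached through the same symlink) wraps the lambda above in a LambdaLR: the learning rate ramps linearly over num_warmup_steps and then follows a half-cosine decay toward zero over the remaining steps. A small sketch with illustrative hyperparameters (the recipe's actual warmup and total step counts come from train.py):

import torch

# Assumption: utils resolves through the new symlink to the shared libritts copy.
from utils import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()      # one optimizer update per training step
    scheduler.step()      # multiplies the base lr by the warmup/cosine factor
    if step in (0, 99, 499, 999):
        print(step, scheduler.get_last_lr())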