From 3df16b3f2b655a9fa700af2cafd17a8c79662307 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 22 Oct 2023 23:14:00 +0800 Subject: [PATCH 01/16] first commit --- .../tts/local/compute_spectrogram_ljspeech.py | 100 ++ .../tts/local/display_manifest_statistics.py | 73 ++ egs/ljspeech/tts/local/split_subsets.py | 79 ++ egs/ljspeech/tts/local/validate_manifest.py | 70 ++ egs/ljspeech/tts/prepare.sh | 77 ++ egs/ljspeech/tts/shared/parse_options.sh | 97 ++ egs/ljspeech/tts/vits/commons.py | 161 +++ egs/ljspeech/tts/vits/duration_predictor.py | 194 ++++ egs/ljspeech/tts/vits/features.py | 416 ++++++++ egs/ljspeech/tts/vits/flow.py | 311 ++++++ egs/ljspeech/tts/vits/generator.py | 524 ++++++++++ egs/ljspeech/tts/vits/hifigan.py | 933 ++++++++++++++++++ egs/ljspeech/tts/vits/loss.py | 332 +++++++ egs/ljspeech/tts/vits/models.py | 534 ++++++++++ .../tts/vits/monotonic_align/__init__.py | 81 ++ .../tts/vits/monotonic_align/core.pyx | 51 + .../tts/vits/monotonic_align/setup.py | 31 + egs/ljspeech/tts/vits/posterior_encoder.py | 117 +++ egs/ljspeech/tts/vits/residual_coupling.py | 229 +++++ egs/ljspeech/tts/vits/symbols.py | 17 + egs/ljspeech/tts/vits/text_encoder.py | 534 ++++++++++ egs/ljspeech/tts/vits/train.py | 896 +++++++++++++++++ egs/ljspeech/tts/vits/transform.py | 217 ++++ egs/ljspeech/tts/vits/tts_datamodule.py | 306 ++++++ egs/ljspeech/tts/vits/utils.py | 470 +++++++++ egs/ljspeech/tts/vits/vits.py | 567 +++++++++++ egs/ljspeech/tts/vits/wavenet.py | 349 +++++++ 27 files changed, 7766 insertions(+) create mode 100755 egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py create mode 100755 egs/ljspeech/tts/local/display_manifest_statistics.py create mode 100755 egs/ljspeech/tts/local/split_subsets.py create mode 100755 egs/ljspeech/tts/local/validate_manifest.py create mode 100755 egs/ljspeech/tts/prepare.sh create mode 100755 egs/ljspeech/tts/shared/parse_options.sh create mode 100644 egs/ljspeech/tts/vits/commons.py create mode 100644 egs/ljspeech/tts/vits/duration_predictor.py create mode 100644 egs/ljspeech/tts/vits/features.py create mode 100644 egs/ljspeech/tts/vits/flow.py create mode 100644 egs/ljspeech/tts/vits/generator.py create mode 100644 egs/ljspeech/tts/vits/hifigan.py create mode 100644 egs/ljspeech/tts/vits/loss.py create mode 100644 egs/ljspeech/tts/vits/models.py create mode 100644 egs/ljspeech/tts/vits/monotonic_align/__init__.py create mode 100644 egs/ljspeech/tts/vits/monotonic_align/core.pyx create mode 100644 egs/ljspeech/tts/vits/monotonic_align/setup.py create mode 100644 egs/ljspeech/tts/vits/posterior_encoder.py create mode 100644 egs/ljspeech/tts/vits/residual_coupling.py create mode 100644 egs/ljspeech/tts/vits/symbols.py create mode 100644 egs/ljspeech/tts/vits/text_encoder.py create mode 100755 egs/ljspeech/tts/vits/train.py create mode 100644 egs/ljspeech/tts/vits/transform.py create mode 100644 egs/ljspeech/tts/vits/tts_datamodule.py create mode 100644 egs/ljspeech/tts/vits/utils.py create mode 100644 egs/ljspeech/tts/vits/vits.py create mode 100644 egs/ljspeech/tts/vits/wavenet.py diff --git a/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py new file mode 100755 index 0000000000..3603af07df --- /dev/null +++ b/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/spectrogram. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, LilcomChunkyWriter, load_manifest +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_spectrogram_ljspeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(4, os.cpu_count()) + + sampling_rate = 22050 + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_ljspeech() diff --git a/egs/ljspeech/tts/local/display_manifest_statistics.py b/egs/ljspeech/tts/local/display_manifest_statistics.py new file mode 100755 index 0000000000..93f0044f0e --- /dev/null +++ b/egs/ljspeech/tts/local/display_manifest_statistics.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in vits/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz" + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Cut statistics: + ╒═══════════════════════════╤══════════╕ + │ Cuts count: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Total duration (hh:mm:ss) │ 23:55:18 │ + ├───────────────────────────┼──────────┤ + │ mean │ 6.6 │ + ├───────────────────────────┼──────────┤ + │ std │ 2.2 │ + ├───────────────────────────┼──────────┤ + │ min │ 1.1 │ + ├───────────────────────────┼──────────┤ + │ 25% │ 5.0 │ + ├───────────────────────────┼──────────┤ + │ 50% │ 6.8 │ + ├───────────────────────────┼──────────┤ + │ 75% │ 8.4 │ + ├───────────────────────────┼──────────┤ + │ 99% │ 10.0 │ + ├───────────────────────────┼──────────┤ + │ 99.5% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ 99.9% │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ max │ 10.1 │ + ├───────────────────────────┼──────────┤ + │ Recordings available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Features available: │ 13100 │ + ├───────────────────────────┼──────────┤ + │ Supervisions available: │ 13100 │ + ╘═══════════════════════════╧══════════╛ +""" diff --git a/egs/ljspeech/tts/local/split_subsets.py b/egs/ljspeech/tts/local/split_subsets.py new file mode 100755 index 0000000000..328cdd6910 --- /dev/null +++ b/egs/ljspeech/tts/local/split_subsets.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
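+# Note: the duration statistics printed by ./local/display_manifest_statistics.py
+# are meant to guide the minimum/maximum-duration filtering done in
+# vits/train.py (see remove_short_and_long_utt() there). A minimal,
+# illustrative sketch of such a filter is shown below in comments; the
+# 1.0 s / 15.0 s bounds are assumptions read off the table above, not the
+# values hard-coded in train.py.
+#
+#   from lhotse import CutSet, load_manifest_lazy
+#
+#   def remove_short_and_long_utt_example(cuts: CutSet) -> CutSet:
+#       # Keep only cuts whose duration falls inside the observed range.
+#       return cuts.filter(lambda c: 1.0 <= c.duration <= 15.0)
+#
+#   cuts = load_manifest_lazy("data/spectrogram/ljspeech_cuts_train.jsonl.gz")
+#   cuts = remove_short_and_long_utt_example(cuts)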
+""" +This script split the LJSpeech dataset cuts into three sets: + - training, 12500 + - validation, 100 + - test, 500 +The numbers are from https://arxiv.org/pdf/2106.06103.pdf + +Usage example: + python3 ./local/split_subsets.py ./data/spectrogram +""" + +import argparse +import logging +import random +from pathlib import Path + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest_dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest_dir = Path(args.manifest_dir) + prefix = "ljspeech" + suffix = "jsonl.gz" + all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") + + cut_ids = list(all_cuts.ids) + random.shuffle(cut_ids) + + train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500]) + valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500:12500 + 100]) + test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100:]) + assert len(train_cuts) == 12500, "expected 12500 cuts for training but got len(train_cuts)" + assert len(valid_cuts) == 100, "expected 100 cuts but for validation but got len(valid_cuts)" + assert len(test_cuts) == 500, "expected 500 cuts for test but got len(test_cuts)" + + train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}") + valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}") + test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}") + + logging.info("Splitted into three sets: training (12500), validation (100), and test (500)") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/tts/local/validate_manifest.py b/egs/ljspeech/tts/local/validate_manifest.py new file mode 100755 index 0000000000..cd466303ed --- /dev/null +++ b/egs/ljspeech/tts/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/ljspeech_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/tts/prepare.sh new file mode 100755 index 0000000000..f78964c347 --- /dev/null +++ b/egs/ljspeech/tts/prepare.sh @@ -0,0 +1,77 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=1 +stage=-1 +stop_stage=100 + +# dl_dir=$PWD/download +dl_dir=/star-data/zengwei/download/ljspeech/ + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LJSpeech, + # you can create a symlink + # + # ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech + # + if [ ! -d $dl_dir/LJSpeech-1.1 ]; then + lhotse download ljspeech $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LJSpeech manifest" + # We assume that you have downloaded the LJSpeech corpus + # to $dl_dir/LJSpeech + mkdir -p data/manifests + if [ ! -e data/manifests/.ljspeech.done ]; then + lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests + touch data/manifests/.ljspeech.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute spectrogram for LJSpeech" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.ljspeech.done ]; then + ./local/compute_spectrogram_ljspeech.py + touch data/spectrogram/.ljspeech.done + fi + + if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then + log "Validating data/fbank for LJSpeech" + python3 ./local/validate_manifest.py \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Split the LJSpeech cuts into three sets" + if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then + ./local/split_subsets.py data/spectrogram + touch data/spectrogram/.ljspeech_split.done + fi +fi + + diff --git a/egs/ljspeech/tts/shared/parse_options.sh b/egs/ljspeech/tts/shared/parse_options.sh new file mode 100755 index 0000000000..71fb9e5ea1 --- /dev/null +++ b/egs/ljspeech/tts/shared/parse_options.sh @@ -0,0 +1,97 @@ +#!/usr/bin/env bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. 
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/ljspeech/tts/vits/commons.py b/egs/ljspeech/tts/vits/commons.py new file mode 100644 index 0000000000..9ad0444b61 --- /dev/null +++ b/egs/ljspeech/tts/vits/commons.py @@ -0,0 +1,161 @@ +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (num_timescales - 1)) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = 
torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2,3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/egs/ljspeech/tts/vits/duration_predictor.py b/egs/ljspeech/tts/vits/duration_predictor.py new file mode 100644 index 0000000000..5e8d670bdc --- /dev/null +++ b/egs/ljspeech/tts/vits/duration_predictor.py @@ -0,0 +1,194 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Stochastic duration predictor modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional + +import torch +import torch.nn.functional as F + +from flow import ( + ConvFlow, + DilatedDepthSeparableConv, + ElementwiseAffineFlow, + FlipFlow, + LogFlow, +) + + +class StochasticDurationPredictor(torch.nn.Module): + """Stochastic duration predictor module. + + This is a module of stochastic duration predictor described in `Conditional + Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + channels: int = 192, + kernel_size: int = 3, + dropout_rate: float = 0.5, + flows: int = 4, + dds_conv_layers: int = 3, + global_channels: int = -1, + ): + """Initialize StochasticDurationPredictor module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + dropout_rate (float): Dropout rate. + flows (int): Number of flows. + dds_conv_layers (int): Number of conv layers in DDS conv. 
+ global_channels (int): Number of global conditioning channels. + + """ + super().__init__() + + self.pre = torch.nn.Conv1d(channels, channels, 1) + self.dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.proj = torch.nn.Conv1d(channels, channels, 1) + + self.log_flow = LogFlow() + self.flows = torch.nn.ModuleList() + self.flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.flows += [FlipFlow()] + + self.post_pre = torch.nn.Conv1d(1, channels, 1) + self.post_dds = DilatedDepthSeparableConv( + channels, + kernel_size, + layers=dds_conv_layers, + dropout_rate=dropout_rate, + ) + self.post_proj = torch.nn.Conv1d(channels, channels, 1) + self.post_flows = torch.nn.ModuleList() + self.post_flows += [ElementwiseAffineFlow(2)] + for i in range(flows): + self.post_flows += [ + ConvFlow( + 2, + channels, + kernel_size, + layers=dds_conv_layers, + ) + ] + self.post_flows += [FlipFlow()] + + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + w: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + noise_scale: float = 1.0, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T_text). + x_mask (Tensor): Mask tensor (B, 1, T_text). + w (Optional[Tensor]): Duration tensor (B, 1, T_text). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1) + inverse (bool): Whether to inverse the flow. + noise_scale (float): Noise scale value. + + Returns: + Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,). + If inverse, log-duration tensor (B, 1, T_text). + + """ + x = x.detach() # stop gradient + x = self.pre(x) + if g is not None: + x = x + self.global_conv(g.detach()) # stop gradient + x = self.dds(x, x_mask) + x = self.proj(x) * x_mask + + if not inverse: + assert w is not None, "w must be provided." 
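+            # What follows implements the variational "posterior" part of the
+            # stochastic duration predictor: a Gaussian noise sample e_q of
+            # shape (B, 2, T_text) is transformed by the posterior flows
+            # (conditioned on the text hidden states x and the encoded
+            # durations h_w) into (z_u, z1); u = sigmoid(z_u) dequantizes the
+            # integer durations so that z0 = w - u is continuous, and logq
+            # accumulates the log-density of this posterior sample, which is
+            # added to the flow NLL in the returned value.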
+ h_w = self.post_pre(w) + h_w = self.post_dds(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = ( + torch.randn( + w.size(0), + 2, + w.size(2), + ).to(device=x.device, dtype=x.dtype) + * x_mask + ) + z_q = e_q + logdet_tot_q = 0.0 + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum( + (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) + - logdet_tot_q + ) + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in self.flows: + z, logdet = flow(z, x_mask, g=x, inverse=inverse) + logdet_tot = logdet_tot + logdet + nll = ( + torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) + - logdet_tot + ) + return nll + logq # (B,) + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = ( + torch.randn( + x.size(0), + 2, + x.size(2), + ).to(device=x.device, dtype=x.dtype) + * noise_scale + ) + for flow in flows: + z = flow(z, x_mask, g=x, inverse=inverse) + z0, z1 = z.split(1, 1) + logw = z0 + return logw diff --git a/egs/ljspeech/tts/vits/features.py b/egs/ljspeech/tts/vits/features.py new file mode 100644 index 0000000000..b43c7cf46d --- /dev/null +++ b/egs/ljspeech/tts/vits/features.py @@ -0,0 +1,416 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Tuple + +import librosa +import numpy as np +import torch +from torch import nn + +from icefall.utils import make_pad_mask + + +# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py +class Stft(nn.Module): + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.normalized = normalized + self.onesided = onesided + if window is not None and not hasattr(torch, f"{window}_window"): + raise ValueError(f"{window} window is not implemented") + self.window = window + + def extra_repr(self): + return ( + f"n_fft={self.n_fft}, " + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + f"normalized={self.normalized}, " + f"onesided={self.onesided}" + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT forward function. 
+ + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + if input.dim() == 3: + multi_channel = True + # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) + input = input.transpose(1, 2).reshape(-1, input.size(1)) + else: + multi_channel = False + + # NOTE(kamo): + # The default behaviour of torch.stft is compatible with librosa.stft + # about padding and scaling. + # Note that it's different from scipy.signal.stft + + # output: (Batch, Freq, Frames, 2=real_imag) + # or (Batch, Channel, Freq, Frames, 2=real_imag) + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func( + self.win_length, dtype=input.dtype, device=input.device + ) + else: + window = None + + # For the compatibility of ARM devices, which do not support + # torch.stft() due to the lack of MKL (on older pytorch versions), + # there is an alternative replacement implementation with librosa. + # Note: pytorch >= 1.10.0 now has native support for FFT and STFT + # on all cpu targets including ARM. + if input.is_cuda or torch.backends.mkl.is_available(): + stft_kwargs = dict( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + window=window, + normalized=self.normalized, + onesided=self.onesided, + ) + stft_kwargs["return_complex"] = True + output = torch.stft(input, **stft_kwargs) + output = torch.view_as_real(output) + else: + if self.training: + raise NotImplementedError( + "stft is implemented with librosa on this device, which does not " + "support the training mode." + ) + + # use stft_kwargs to flexibly control different PyTorch versions' kwargs + # note: librosa does not support a win_length that is < n_ftt + # but the window can be manually padded (see below). 
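+            # In this librosa fallback the window is zero-padded to n_fft and
+            # converted to a NumPy array (librosa.stft() expects NumPy input);
+            # the real and imaginary parts of each frame are then stacked so
+            # that the output matches the (Batch, Freq, Frames, 2) layout of
+            # torch.view_as_real() used in the torch.stft() branch above.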
+ stft_kwargs = dict( + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.hop_length, + center=self.center, + window=window, + pad_mode="reflect", + ) + + if window is not None: + # pad the given window to n_fft + n_pad_left = (self.n_fft - window.shape[0]) // 2 + n_pad_right = self.n_fft - window.shape[0] - n_pad_left + stft_kwargs["window"] = torch.cat( + [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0 + ).numpy() + else: + win_length = ( + self.win_length if self.win_length is not None else self.n_fft + ) + stft_kwargs["window"] = torch.ones(win_length) + + output = [] + # iterate over istances in a batch + for i, instance in enumerate(input): + stft = librosa.stft(input[i].numpy(), **stft_kwargs) + output.append(torch.tensor(np.stack([stft.real, stft.imag], -1))) + output = torch.stack(output, 0) + if not self.onesided: + len_conj = self.n_fft - output.shape[1] + conj = output[:, 1 : 1 + len_conj].flip(1) + conj[:, :, :, -1].data *= -1 + output = torch.cat([output, conj], 1) + if self.normalized: + output = output * (stft_kwargs["window"].shape[0] ** (-0.5)) + + # output: (Batch, Freq, Frames, 2=real_imag) + # -> (Batch, Frames, Freq, 2=real_imag) + output = output.transpose(1, 2) + if multi_channel: + # output: (Batch * Channel, Frames, Freq, 2=real_imag) + # -> (Batch, Frame, Channel, Freq, 2=real_imag) + output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( + 1, 2 + ) + + if ilens is not None: + if self.center: + pad = self.n_fft // 2 + ilens = ilens + 2 * pad + + olens = ( + torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc") + + 1 + ) + output.masked_fill_(make_pad_mask(olens), 0.0) + else: + olens = None + + return output, olens + + +# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py +class LinearSpectrogram(nn.Module): + """Linear amplitude spectrogram. + + Stft -> amplitude-spec + """ + + def __init__( + self, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + self.n_fft = n_fft + + def output_size(self) -> int: + return self.n_fft // 2 + 1 + + def get_parameters(self) -> Dict[str, Any]: + """Return the parameters required by Vocoder.""" + return dict( + n_fft=self.n_fft, + n_shift=self.hop_length, + win_length=self.win_length, + window=self.window, + ) + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. 
Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # STFT -> Power spectrum -> Amp spectrum + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) + return input_amp, feats_lens + + +# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py +class LogMel(nn.Module): + """Convert STFT to fbank feats + + The arguments is same as librosa.filters.mel + + Args: + fs: number > 0 [scalar] sampling rate of the incoming signal + n_fft: int > 0 [scalar] number of FFT components + n_mels: int > 0 [scalar] number of Mel bands to generate + fmin: float >= 0 [scalar] lowest frequency (in Hz) + fmax: float >= 0 [scalar] highest frequency (in Hz). + If `None`, use `fmax = fs / 2.0` + htk: use HTK formula instead of Slaney + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = None, + fmax: float = None, + htk: bool = False, + log_base: float = None, + ): + super().__init__() + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + _mel_options = dict( + sr=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + ) + self.mel_options = _mel_options + self.log_base = log_base + + # Note(kamo): The mel matrix of librosa is different from kaldi. + melmat = librosa.filters.mel(**_mel_options) + # melmat: (D2, D1) -> (D1, D2) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + + def extra_repr(self): + return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) + + def forward( + self, + feat: torch.Tensor, + ilens: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) + mel_feat = torch.matmul(feat, self.melmat) + mel_feat = torch.clamp(mel_feat, min=1e-10) + + if self.log_base is None: + logmel_feat = mel_feat.log() + elif self.log_base == 2.0: + logmel_feat = mel_feat.log2() + elif self.log_base == 10.0: + logmel_feat = mel_feat.log10() + else: + logmel_feat = mel_feat.log() / torch.log(self.log_base) + + # Zero padding + if ilens is not None: + logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0) + else: + ilens = feat.new_full( + [feat.size(0)], fill_value=feat.size(1), dtype=torch.long + ) + return logmel_feat, ilens + + +# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py +class LogMelFbank(nn.Module): + """Conventional frontend structure for TTS. 
+ + Stft -> amplitude-spec -> Log-Mel-Fbank + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: Optional[int] = 80, + fmax: Optional[int] = 7600, + htk: bool = False, + log_base: Optional[float] = 10.0, + ): + super().__init__() + + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self.logmel = LogMel( + fs=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + log_base=log_base, + ) + + def output_size(self) -> int: + return self.n_mels + + def get_parameters(self) -> Dict[str, Any]: + """Return the parameters required by Vocoder""" + return dict( + fs=self.fs, + n_fft=self.n_fft, + n_shift=self.hop_length, + window=self.window, + n_mels=self.n_mels, + win_length=self.win_length, + fmin=self.fmin, + fmax=self.fmax, + ) + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # NOTE(kamo): We use different definition for log-spec between TTS and ASR + # TTS: log_10(abs(stft)) + # ASR: log_e(power(stft)) + + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) + input_feats, _ = self.logmel(input_amp, feats_lens) + return input_feats, feats_lens diff --git a/egs/ljspeech/tts/vits/flow.py b/egs/ljspeech/tts/vits/flow.py new file mode 100644 index 0000000000..04fb99b427 --- /dev/null +++ b/egs/ljspeech/tts/vits/flow.py @@ -0,0 +1,311 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Basic Flow modules used in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +import math +from typing import Optional, Tuple, Union + +import torch + +from transform import piecewise_rational_quadratic_transform + + +class FlipFlow(torch.nn.Module): + """Flip flow module.""" + + def forward( + self, x: torch.Tensor, *args, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Flipped tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + x = torch.flip(x, [1]) + if not inverse: + logdet = x.new_zeros(x.size(0)) + return x, logdet + else: + return x + + +class LogFlow(torch.nn.Module): + """Log flow module.""" + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + inverse: bool = False, + eps: float = 1e-5, + **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. 
+ + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + inverse (bool): Whether to inverse the flow. + eps (float): Epsilon for log. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = torch.log(torch.clamp_min(x, eps)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class ElementwiseAffineFlow(torch.nn.Module): + """Elementwise affine flow module.""" + + def __init__(self, channels: int): + """Initialize ElementwiseAffineFlow module. + + Args: + channels (int): Number of channels. + + """ + super().__init__() + self.channels = channels + self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1))) + self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1))) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_lengths (Tensor): Length tensor (B,). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + if not inverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class Transpose(torch.nn.Module): + """Transpose module for torch.nn.Sequential().""" + + def __init__(self, dim1: int, dim2: int): + """Initialize Transpose module.""" + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Transpose.""" + return x.transpose(self.dim1, self.dim2) + + +class DilatedDepthSeparableConv(torch.nn.Module): + """Dilated depth-separable conv module.""" + + def __init__( + self, + channels: int, + kernel_size: int, + layers: int, + dropout_rate: float = 0.0, + eps: float = 1e-5, + ): + """Initialize DilatedDepthSeparableConv module. + + Args: + channels (int): Number of channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + dropout_rate (float): Dropout rate. + eps (float): Epsilon for layer norm. + + """ + super().__init__() + + self.convs = torch.nn.ModuleList() + for i in range(layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Conv1d( + channels, + channels, + 1, + ), + Transpose(1, 2), + torch.nn.LayerNorm( + channels, + eps=eps, + elementwise_affine=True, + ), + Transpose(1, 2), + torch.nn.GELU(), + torch.nn.Dropout(dropout_rate), + ) + ] + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_mask (Tensor): Mask tensor (B, 1, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, channels, T). 
+ + """ + if g is not None: + x = x + g + for f in self.convs: + y = f(x * x_mask) + x = x + y + return x * x_mask + + +class ConvFlow(torch.nn.Module): + """Convolutional flow module.""" + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + layers: int, + bins: int = 10, + tail_bound: float = 5.0, + ): + """Initialize ConvFlow module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size. + layers (int): Number of layers. + bins (int): Number of bins. + tail_bound (float): Tail bound value. + + """ + super().__init__() + self.half_channels = in_channels // 2 + self.hidden_channels = hidden_channels + self.bins = bins + self.tail_bound = tail_bound + + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.dds_conv = DilatedDepthSeparableConv( + hidden_channels, + kernel_size, + layers, + dropout_rate=0.0, + ) + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * (bins * 3 - 1), + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + x_mask (Tensor): Mask tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + xa, xb = x.split(x.size(1) // 2, 1) + h = self.input_conv(xa) + h = self.dds_conv(h, x_mask, g=g) + h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T) + + b, c, t = xa.shape + # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1) + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) + + # TODO(kan-bayashi): Understand this calculation + denom = math.sqrt(self.hidden_channels) + unnorm_widths = h[..., : self.bins] / denom + unnorm_heights = h[..., self.bins : 2 * self.bins] / denom + unnorm_derivatives = h[..., 2 * self.bins :] + xb, logdet_abs = piecewise_rational_quadratic_transform( + xb, + unnorm_widths, + unnorm_heights, + unnorm_derivatives, + inverse=inverse, + tails="linear", + tail_bound=self.tail_bound, + ) + x = torch.cat([xa, xb], 1) * x_mask + logdet = torch.sum(logdet_abs * x_mask, [1, 2]) + if not inverse: + return x, logdet + else: + return x diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py new file mode 100644 index 0000000000..dbf503944f --- /dev/null +++ b/egs/ljspeech/tts/vits/generator.py @@ -0,0 +1,524 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Generator module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. 
+ +""" + + +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + +from duration_predictor import StochasticDurationPredictor +from hifigan import HiFiGANGenerator +from posterior_encoder import PosteriorEncoder +from residual_coupling import ResidualAffineCouplingBlock +from text_encoder import TextEncoder +from utils import get_random_segments + + +class VITSGenerator(torch.nn.Module): + """Generator module in VITS, `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + aux_channels: int = 513, + hidden_channels: int = 192, + spks: Optional[int] = None, + langs: Optional[int] = None, + spk_embed_dim: Optional[int] = None, + global_channels: int = -1, + segment_size: int = 32, + text_encoder_attention_heads: int = 2, + text_encoder_ffn_expand: int = 4, + text_encoder_blocks: int = 6, + text_encoder_dropout_rate: float = 0.1, + decoder_kernel_size: int = 7, + decoder_channels: int = 512, + decoder_upsample_scales: List[int] = [8, 8, 2, 2], + decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + decoder_resblock_kernel_sizes: List[int] = [3, 7, 11], + decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_weight_norm_in_decoder: bool = True, + posterior_encoder_kernel_size: int = 5, + posterior_encoder_layers: int = 16, + posterior_encoder_stacks: int = 1, + posterior_encoder_base_dilation: int = 1, + posterior_encoder_dropout_rate: float = 0.0, + use_weight_norm_in_posterior_encoder: bool = True, + flow_flows: int = 4, + flow_kernel_size: int = 5, + flow_base_dilation: int = 1, + flow_layers: int = 4, + flow_dropout_rate: float = 0.0, + use_weight_norm_in_flow: bool = True, + use_only_mean_in_flow: bool = True, + stochastic_duration_predictor_kernel_size: int = 3, + stochastic_duration_predictor_dropout_rate: float = 0.5, + stochastic_duration_predictor_flows: int = 4, + stochastic_duration_predictor_dds_conv_layers: int = 3, + ): + """Initialize VITS generator module. + + Args: + vocabs (int): Input vocabulary size. + aux_channels (int): Number of acoustic feature channels. + hidden_channels (int): Number of hidden channels. + spks (Optional[int]): Number of speakers. If set to > 1, assume that the + sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): Number of languages. If set to > 1, assume that the + lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, + assume that spembs will be provided as the input. + global_channels (int): Number of global conditioning channels. + segment_size (int): Segment size for decoder. + text_encoder_attention_heads (int): Number of heads in conformer block + of text encoder. + text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block + of text encoder. + text_encoder_blocks (int): Number of conformer blocks in text encoder. + text_encoder_dropout_rate (float): Dropout rate in conformer block of + text encoder. + decoder_kernel_size (int): Decoder kernel size. + decoder_channels (int): Number of decoder initial channels. + decoder_upsample_scales (List[int]): List of upsampling scales in decoder. + decoder_upsample_kernel_sizes (List[int]): List of kernel size for + upsampling layers in decoder. 
+ decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks + in decoder. + decoder_resblock_dilations (List[List[int]]): List of list of dilations for + resblocks in decoder. + use_weight_norm_in_decoder (bool): Whether to apply weight normalization in + decoder. + posterior_encoder_kernel_size (int): Posterior encoder kernel size. + posterior_encoder_layers (int): Number of layers of posterior encoder. + posterior_encoder_stacks (int): Number of stacks of posterior encoder. + posterior_encoder_base_dilation (int): Base dilation of posterior encoder. + posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder. + use_weight_norm_in_posterior_encoder (bool): Whether to apply weight + normalization in posterior encoder. + flow_flows (int): Number of flows in flow. + flow_kernel_size (int): Kernel size in flow. + flow_base_dilation (int): Base dilation in flow. + flow_layers (int): Number of layers in flow. + flow_dropout_rate (float): Dropout rate in flow + use_weight_norm_in_flow (bool): Whether to apply weight normalization in + flow. + use_only_mean_in_flow (bool): Whether to use only mean in flow. + stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic + duration predictor. + stochastic_duration_predictor_dropout_rate (float): Dropout rate in + stochastic duration predictor. + stochastic_duration_predictor_flows (int): Number of flows in stochastic + duration predictor. + stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv + layers in stochastic duration predictor. + + """ + super().__init__() + self.segment_size = segment_size + self.text_encoder = TextEncoder( + vocabs=vocabs, + d_model=hidden_channels, + num_heads=text_encoder_attention_heads, + dim_feedforward=hidden_channels * text_encoder_ffn_expand, + num_layers=text_encoder_blocks, + dropout=text_encoder_dropout_rate, + ) + self.decoder = HiFiGANGenerator( + in_channels=hidden_channels, + out_channels=1, + channels=decoder_channels, + global_channels=global_channels, + kernel_size=decoder_kernel_size, + upsample_scales=decoder_upsample_scales, + upsample_kernel_sizes=decoder_upsample_kernel_sizes, + resblock_kernel_sizes=decoder_resblock_kernel_sizes, + resblock_dilations=decoder_resblock_dilations, + use_weight_norm=use_weight_norm_in_decoder, + ) + self.posterior_encoder = PosteriorEncoder( + in_channels=aux_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + kernel_size=posterior_encoder_kernel_size, + layers=posterior_encoder_layers, + stacks=posterior_encoder_stacks, + base_dilation=posterior_encoder_base_dilation, + global_channels=global_channels, + dropout_rate=posterior_encoder_dropout_rate, + use_weight_norm=use_weight_norm_in_posterior_encoder, + ) + self.flow = ResidualAffineCouplingBlock( + in_channels=hidden_channels, + hidden_channels=hidden_channels, + flows=flow_flows, + kernel_size=flow_kernel_size, + base_dilation=flow_base_dilation, + layers=flow_layers, + global_channels=global_channels, + dropout_rate=flow_dropout_rate, + use_weight_norm=use_weight_norm_in_flow, + use_only_mean=use_only_mean_in_flow, + ) + # TODO(kan-bayashi): Add deterministic version as an option + self.duration_predictor = StochasticDurationPredictor( + channels=hidden_channels, + kernel_size=stochastic_duration_predictor_kernel_size, + dropout_rate=stochastic_duration_predictor_dropout_rate, + flows=stochastic_duration_predictor_flows, + dds_conv_layers=stochastic_duration_predictor_dds_conv_layers, + global_channels=global_channels, + 
) + + self.upsample_factor = int(np.prod(decoder_upsample_scales)) + self.spks = None + if spks is not None and spks > 1: + assert global_channels > 0 + self.spks = spks + self.global_emb = torch.nn.Embedding(spks, global_channels) + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + assert global_channels > 0 + self.spk_embed_dim = spk_embed_dim + self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels) + self.langs = None + if langs is not None and langs > 1: + assert global_channels > 0 + self.langs = langs + self.lang_emb = torch.nn.Embedding(langs, global_channels) + + # delayed import + from monotonic_align import maximum_path + + self.maximum_path = maximum_path + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + ]: + """Calculate forward propagation. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Tensor: Waveform tensor (B, 1, segment_size * upsample_factor). + Tensor: Duration negative log-likelihood (NLL) tensor (B,). + Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text). + Tensor: Segments start index tensor (B,). + Tensor: Text mask tensor (B, 1, T_text). + Tensor: Feature mask tensor (B, 1, T_feats). + tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + - Tensor: Posterior encoder hidden representation (B, H, T_feats). + - Tensor: Flow hidden representation (B, H, T_feats). + - Tensor: Expanded text encoder projected mean (B, H, T_feats). + - Tensor: Expanded text encoder projected scale (B, H, T_feats). + - Tensor: Posterior encoder projected mean (B, H, T_feats). + - Tensor: Posterior encoder projected scale (B, H, T_feats). 
+ + """ + # forward text encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + + # calculate global conditioning + g = None + if self.spks is not None: + # speaker one-hot vector embedding: (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # pretreined speaker embedding, e.g., X-vector (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # language one-hot vector embedding: (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = ( + self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ) + .unsqueeze(1) + .detach() + ) + + # forward duration predictor + w = attn.sum(2) # (B, 1, T_text) + dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) + dur_nll = dur_nll / torch.sum(x_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + # get random segments + z_segments, z_start_idxs = get_random_segments( + z, + feats_lengths, + self.segment_size, + ) + + # forward decoder with random segments + wav = self.decoder(z_segments, g=g) + + return ( + wav, + dur_nll, + attn, + z_start_idxs, + x_mask, + y_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) + + def inference( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: Optional[torch.Tensor] = None, + feats_lengths: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + dur: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text,). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, aux_channels, T_feats,). + feats_lengths (Tensor): Feature length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). 
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided, + skip the prediction of durations (i.e., teacher forcing). + noise_scale (float): Noise scale parameter for flow. + noise_scale_dur (float): Noise scale parameter for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length of acoustic feature sequence. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Tensor: Generated waveform tensor (B, T_wav). + Tensor: Monotonic attention weight tensor (B, T_feats, T_text). + Tensor: Duration tensor (B, T_text). + + """ + # encoder + x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + g = None + if self.spks is not None: + # (B, global_channels, 1) + g = self.global_emb(sids.view(-1)).unsqueeze(-1) + if self.spk_embed_dim is not None: + # (B, global_channels, 1) + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + if self.langs is not None: + # (B, global_channels, 1) + g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1) + if g is None: + g = g_ + else: + g = g + g_ + + if use_teacher_forcing: + # forward posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g) + + # forward flow + z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats) + + # monotonic alignment search + s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text) + # (B, 1, T_text) + neg_x_ent_1 = torch.sum( + -0.5 * math.log(2 * math.pi) - logs_p, + [1], + keepdim=True, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_2 = torch.matmul( + -0.5 * (z_p**2).transpose(1, 2), + s_p_sq_r, + ) + # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) + neg_x_ent_3 = torch.matmul( + z_p.transpose(1, 2), + (m_p * s_p_sq_r), + ) + # (B, 1, T_text) + neg_x_ent_4 = torch.sum( + -0.5 * (m_p**2) * s_p_sq_r, + [1], + keepdim=True, + ) + # (B, T_feats, T_text) + neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4 + # (B, 1, T_feats, T_text) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + # monotonic attention weight: (B, 1, T_feats, T_text) + attn = self.maximum_path( + neg_x_ent, + attn_mask.squeeze(1), + ).unsqueeze(1) + dur = attn.sum(2) # (B, 1, T_text) + + # forward decoder with random segments + wav = self.decoder(z * y_mask, g=g) + else: + # duration + if dur is None: + logw = self.duration_predictor( + x, + x_mask, + g=g, + inverse=True, + noise_scale=noise_scale_dur, + ) + w = torch.exp(logw) * x_mask * alpha + dur = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() + y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = self._generate_path(dur, attn_mask) + + # expand the length to match with the feature sequence + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + m_p = torch.matmul( + attn.squeeze(1), + m_p.transpose(1, 2), + ).transpose(1, 2) + # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) + logs_p = torch.matmul( + attn.squeeze(1), + logs_p.transpose(1, 2), + ).transpose(1, 2) + + # decoder + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, inverse=True) + wav = self.decoder((z * y_mask)[:, :, :max_len], g=g) + + 
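+        # shapes at this point: wav (B, 1, T_wav), attn (B, 1, T_feats, T_text),
+        # dur (B, 1, T_text); squeeze(1) below drops the singleton channel axis.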
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1) + + def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate path a.k.a. monotonic attention. + + Args: + dur (Tensor): Duration tensor (B, 1, T_text). + mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text). + + Returns: + Tensor: Path tensor (B, 1, T_feats, T_text). + + """ + b, _, t_y, t_x = mask.shape + cum_dur = torch.cumsum(dur, -1) + cum_dur_flat = cum_dur.view(b * t_x) + path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + # path will be like (t_x = 3, t_y = 5): + # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/tts/vits/hifigan.py b/egs/ljspeech/tts/vits/hifigan.py new file mode 100644 index 0000000000..a87cb2fce7 --- /dev/null +++ b/egs/ljspeech/tts/vits/hifigan.py @@ -0,0 +1,933 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFi-GAN Modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +import copy +import logging +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + + +class HiFiGANGenerator(torch.nn.Module): + """HiFiGAN generator module.""" + + def __init__( + self, + in_channels: int = 80, + out_channels: int = 1, + channels: int = 512, + global_channels: int = -1, + kernel_size: int = 7, + upsample_scales: List[int] = [8, 8, 2, 2], + upsample_kernel_sizes: List[int] = [16, 16, 4, 4], + resblock_kernel_sizes: List[int] = [3, 7, 11], + resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + use_additional_convs: bool = True, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + ): + """Initialize HiFiGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + channels (int): Number of hidden representation channels. + global_channels (int): Number of global conditioning channels. + kernel_size (int): Kernel size of initial and final conv layer. + upsample_scales (List[int]): List of upsampling scales. + upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. + resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. + resblock_dilations (List[List[int]]): List of list of dilations for residual + blocks. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + + """ + super().__init__() + + # check hyperparameters are valid + assert kernel_size % 2 == 1, "Kernel size must be odd number." 
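+        # each upsampling stage needs its own kernel size, and each resblock
+        # kernel size needs its own dilation list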
+ assert len(upsample_scales) == len(upsample_kernel_sizes) + assert len(resblock_dilations) == len(resblock_kernel_sizes) + + # define modules + self.upsample_factor = int(np.prod(upsample_scales) * out_channels) + self.num_upsamples = len(upsample_kernel_sizes) + self.num_blocks = len(resblock_kernel_sizes) + self.input_conv = torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ) + self.upsamples = torch.nn.ModuleList() + self.blocks = torch.nn.ModuleList() + for i in range(len(upsample_kernel_sizes)): + assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2**i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] + for j in range(len(resblock_kernel_sizes)): + self.blocks += [ + ResidualBlock( + kernel_size=resblock_kernel_sizes[j], + channels=channels // (2 ** (i + 1)), + dilations=resblock_dilations[j], + bias=bias, + use_additional_convs=use_additional_convs, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + ) + ] + self.output_conv = torch.nn.Sequential( + # NOTE(kan-bayashi): follow official implementation but why + # using different slope parameter here? (0.1 vs. 0.01) + torch.nn.LeakyReLU(), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.Tanh(), + ) + if global_channels > 0: + self.global_conv = torch.nn.Conv1d(global_channels, channels, 1) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + c = self.input_conv(c) + if g is not None: + c = c + self.global_conv(g) + for i in range(self.num_upsamples): + c = self.upsamples[i](c) + cs = 0.0 # initialize + for j in range(self.num_blocks): + cs += self.blocks[i * self.num_blocks + j](c) + c = cs / self.num_blocks + c = self.output_conv(c) + + return c + + def reset_parameters(self): + """Reset parameters. + + This initialization follows the official implementation manner. 
+ https://github.com/jik876/hifi-gan/blob/master/models.py + + """ + + def _reset_parameters(m: torch.nn.Module): + if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + m.weight.data.normal_(0.0, 0.01) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.ConvTranspose1d + ): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def inference( + self, c: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Perform inference. + + Args: + c (torch.Tensor): Input tensor (T, in_channels). + g (Optional[Tensor]): Global conditioning tensor (global_channels, 1). + + Returns: + Tensor: Output tensor (T ** upsample_factor, out_channels). + + """ + if g is not None: + g = g.unsqueeze(0) + c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g) + return c.squeeze(0).transpose(1, 0) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in HiFiGAN.""" + + def __init__( + self, + kernel_size: int = 3, + channels: int = 512, + dilations: List[int] = [1, 3, 5], + bias: bool = True, + use_additional_convs: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels for convolution layer. + dilations (List[int]): List of dilation factors. + use_additional_convs (bool): Whether to use additional convolution layers. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + + """ + super().__init__() + self.use_additional_convs = use_additional_convs + self.convs1 = torch.nn.ModuleList() + if use_additional_convs: + self.convs2 = torch.nn.ModuleList() + assert kernel_size % 2 == 1, "Kernel size must be odd number." + for dilation in dilations: + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation, + bias=bias, + padding=(kernel_size - 1) // 2 * dilation, + ), + ) + ] + if use_additional_convs: + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + bias=bias, + padding=(kernel_size - 1) // 2, + ), + ) + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, channels, T). 
+ + """ + for idx in range(len(self.convs1)): + xt = self.convs1[idx](x) + if self.use_additional_convs: + xt = self.convs2[idx](xt) + x = xt + x + return x + + +class HiFiGANPeriodDiscriminator(torch.nn.Module): + """HiFiGAN period discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + period: int = 3, + kernel_sizes: List[int] = [5, 3], + channels: int = 32, + downsample_scales: List[int] = [3, 3, 3, 3, 1], + max_downsample_channels: int = 1024, + bias: bool = True, + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initialize HiFiGANPeriodDiscriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + period (int): Period. + kernel_sizes (list): Kernel sizes of initial conv layers and the final conv + layer. + channels (int): Number of initial channels. + downsample_scales (List[int]): List of downsampling scales. + max_downsample_channels (int): Number of maximum downsampling channels. + use_additional_convs (bool): Whether to use additional conv layers in + residual blocks. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. + If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number." + assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number." + + self.period = period + self.convs = torch.nn.ModuleList() + in_chs = in_channels + out_chs = channels + for downsample_scale in downsample_scales: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv2d( + in_chs, + out_chs, + (kernel_sizes[0], 1), + (downsample_scale, 1), + padding=((kernel_sizes[0] - 1) // 2, 0), + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Use downsample_scale + 1? + out_chs = min(out_chs * 4, max_downsample_channels) + self.output_conv = torch.nn.Conv2d( + out_chs, + out_channels, + (kernel_sizes[1] - 1, 1), + 1, + padding=((kernel_sizes[1] - 1) // 2, 0), + ) + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + if use_spectral_norm: + self.apply_spectral_norm() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, in_channels, T). + + Returns: + list: List of each layer's tensors. 
+ + """ + # transform 1d to 2d -> (B, C, T/P, P) + b, c, t = x.shape + if t % self.period != 0: + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t += n_pad + x = x.view(b, c, t // self.period, self.period) + + # forward conv + outs = [] + for layer in self.convs: + x = layer(x) + outs += [x] + x = self.output_conv(x) + x = torch.flatten(x, 1, -1) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv2d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + +class HiFiGANMultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN multi-period discriminator module.""" + + def __init__( + self, + periods: List[int] = [2, 3, 5, 7, 11], + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initialize HiFiGANMultiPeriodDiscriminator module. + + Args: + periods (List[int]): List of periods. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + for period in periods: + params = copy.deepcopy(discriminator_params) + params["period"] = period + self.discriminators += [HiFiGANPeriodDiscriminator(**params)] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each + layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + + return outs + + +class HiFiGANScaleDiscriminator(torch.nn.Module): + """HiFi-GAN scale discriminator module.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_sizes: List[int] = [15, 41, 5, 3], + channels: int = 128, + max_downsample_channels: int = 1024, + max_groups: int = 16, + bias: int = True, + downsample_scales: List[int] = [2, 2, 4, 4, 1], + nonlinear_activation: str = "LeakyReLU", + nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1}, + use_weight_norm: bool = True, + use_spectral_norm: bool = False, + ): + """Initilize HiFiGAN scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (List[int]): List of four kernel sizes. The first will be used + for the first conv layer, and the second is for downsampling part, and + the remaining two are for the last two output layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling + layers. 
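+            max_groups (int): Maximum number of groups in downsampling conv layers.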
+ bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (List[int]): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation + function. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_spectral_norm (bool): Whether to use spectral norm. If set to true, it + will be applied to all of the conv layers. + + """ + super().__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 4 + for ks in kernel_sizes: + assert ks % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_channels, + channels, + # NOTE(kan-bayashi): Use always the same kernel size + kernel_sizes[0], + bias=bias, + padding=(kernel_sizes[0] - 1) // 2, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + out_chs = channels + # NOTE(kan-bayashi): Remove hard coding? + groups = 4 + for downsample_scale in downsample_scales: + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[1], + stride=downsample_scale, + padding=(kernel_sizes[1] - 1) // 2, + groups=groups, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + ) + ] + in_chs = out_chs + # NOTE(kan-bayashi): Remove hard coding? + out_chs = min(in_chs * 2, max_downsample_channels) + # NOTE(kan-bayashi): Remove hard coding? + groups = min(groups * 4, max_groups) + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, + out_chs, + kernel_size=kernel_sizes[2], + stride=1, + padding=(kernel_sizes[2] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, + out_channels, + kernel_size=kernel_sizes[3], + stride=1, + padding=(kernel_sizes[3] - 1) // 2, + bias=bias, + ), + ] + + if use_weight_norm and use_spectral_norm: + raise ValueError("Either use use_weight_norm or use_spectral_norm.") + + # apply weight norm + self.use_weight_norm = use_weight_norm + if use_weight_norm: + self.apply_weight_norm() + + # apply spectral norm + self.use_spectral_norm = use_spectral_norm + if use_spectral_norm: + self.apply_spectral_norm() + + # backward compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[Tensor]: List of output tensors of each layer. 
+ + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def apply_spectral_norm(self): + """Apply spectral normalization module from all of the layers.""" + + def _apply_spectral_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d): + torch.nn.utils.spectral_norm(m) + logging.debug(f"Spectral norm is applied to {m}.") + + self.apply(_apply_spectral_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def remove_spectral_norm(self): + """Remove spectral normalization module from all of the layers.""" + + def _remove_spectral_norm(m): + try: + logging.debug(f"Spectral norm is removed from {m}.") + torch.nn.utils.remove_spectral_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_spectral_norm) + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Fix the compatibility of weight / spectral normalization issue. + + Some pretrained models are trained with configs that use weight / spectral + normalization, but actually, the norm is not applied. This causes the mismatch + of the parameters with configs. To solve this issue, when parameter mismatch + happens in loading pretrained model, we remove the norm from the current model. + + See also: + - https://github.com/espnet/espnet/pull/5240 + - https://github.com/espnet/espnet/pull/5249 + - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409 + + """ + current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)] + if self.use_weight_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems weight norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_weight_norm() + self.use_weight_norm = False + for k in current_module_keys: + if k.endswith("weight_g") or k.endswith("weight_v"): + del state_dict[k] + + if self.use_spectral_norm and any( + [k.endswith("weight") for k in current_module_keys] + ): + logging.warning( + "It seems spectral norm is not applied in the pretrained model but the" + " current model uses it. To keep the compatibility, we remove the norm" + " from the current model. This may cause unexpected behavior due to the" + " parameter mismatch in finetuning. 
To avoid this issue, please change" + " the following parameters in config to false:\n" + " - discriminator_params.follow_official_norm\n" + " - discriminator_params.scale_discriminator_params.use_weight_norm\n" + " - discriminator_params.scale_discriminator_params.use_spectral_norm\n" + "\n" + "See also:\n" + " - https://github.com/espnet/espnet/pull/5240\n" + " - https://github.com/espnet/espnet/pull/5249" + ) + self.remove_spectral_norm() + self.use_spectral_norm = False + for k in current_module_keys: + if ( + k.endswith("weight_u") + or k.endswith("weight_v") + or k.endswith("weight_orig") + ): + del state_dict[k] + + +class HiFiGANMultiScaleDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale discriminator module.""" + + def __init__( + self, + scales: int = 3, + downsample_pooling: str = "AvgPool1d", + # follow the official implementation setting + downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = False, + ): + """Initilize HiFiGAN multi-scale discriminator module. + + Args: + scales (int): Number of multi-scales. + downsample_pooling (str): Pooling module name for downsampling of the + inputs. + downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling + module. + discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm + and the other discriminators use weight norm. + + """ + super().__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for i in range(scales): + params = copy.deepcopy(discriminator_params) + if follow_official_norm: + if i == 0: + params["use_weight_norm"] = False + params["use_spectral_norm"] = True + else: + params["use_weight_norm"] = True + params["use_spectral_norm"] = False + self.discriminators += [HiFiGANScaleDiscriminator(**params)] + self.pooling = None + if scales > 1: + self.pooling = getattr(torch.nn, downsample_pooling)( + **downsample_pooling_params + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[torch.Tensor]]: List of list of each discriminator outputs, + which consists of eachlayer output tensors. 
+ + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + if self.pooling is not None: + x = self.pooling(x) + + return outs + + +class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module): + """HiFi-GAN multi-scale + multi-period discriminator module.""" + + def __init__( + self, + # Multi-scale discriminator related + scales: int = 3, + scale_downsample_pooling: str = "AvgPool1d", + scale_downsample_pooling_params: Dict[str, Any] = { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + scale_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + }, + follow_official_norm: bool = True, + # Multi-period discriminator related + periods: List[int] = [2, 3, 5, 7, 11], + period_discriminator_params: Dict[str, Any] = { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + ): + """Initilize HiFiGAN multi-scale + multi-period discriminator module. + + Args: + scales (int): Number of multi-scales. + scale_downsample_pooling (str): Pooling module name for downsampling of the + inputs. + scale_downsample_pooling_params (dict): Parameters for the above pooling + module. + scale_discriminator_params (dict): Parameters for hifi-gan scale + discriminator module. + follow_official_norm (bool): Whether to follow the norm setting of the + official implementaion. The first discriminator uses spectral norm and + the other discriminators use weight norm. + periods (list): List of periods. + period_discriminator_params (dict): Parameters for hifi-gan period + discriminator module. The period parameter will be overwritten. + + """ + super().__init__() + self.msd = HiFiGANMultiScaleDiscriminator( + scales=scales, + downsample_pooling=scale_downsample_pooling, + downsample_pooling_params=scale_downsample_pooling_params, + discriminator_params=scale_discriminator_params, + follow_official_norm=follow_official_norm, + ) + self.mpd = HiFiGANMultiPeriodDiscriminator( + periods=periods, + discriminator_params=period_discriminator_params, + ) + + def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List[List[Tensor]]: List of list of each discriminator outputs, + which consists of each layer output tensors. Multi scale and + multi period ones are concatenated. + + """ + msd_outs = self.msd(x) + mpd_outs = self.mpd(x) + return msd_outs + mpd_outs diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/tts/vits/loss.py new file mode 100644 index 0000000000..d322f5e053 --- /dev/null +++ b/egs/ljspeech/tts/vits/loss.py @@ -0,0 +1,332 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""HiFiGAN-related loss modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
+ +""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributions as D +import torch.nn.functional as F + +from lhotse.features.kaldi import Wav2LogFilterBank + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize GeneratorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calcualate generator adversarial loss. + + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs.. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return -x.mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "mse", + ): + """Initialize DiscriminatorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. 
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(x - 1, x.new_zeros(x.size()))) + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size()))) + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. 
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramLoss(torch.nn.Module): + """Mel-spectrogram loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + frame_length: int = 1024, # in samples + frame_shift: int = 256, # in samples + n_mels: int = 80, + use_fft_mag: bool = True, + ): + super().__init__() + self.wav_to_mel = Wav2LogFilterBank( + sampling_rate=sampling_rate, + frame_length=frame_length / sampling_rate, # in second + frame_shift=frame_shift / sampling_rate, # in second + use_fft_mag=use_fft_mag, + num_filters=n_mels, + ) + + def forward( + self, + y_hat: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + """Calculate Mel-spectrogram loss. + + Args: + y_hat (Tensor): Generated waveform tensor (B, 1, T). + y (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_hat = self.wav_to_mel(y_hat.squeeze(1)) + mel = self.wav_to_mel(y.squeeze(1)) + mel_loss = F.l1_loss(mel_hat, mel) + + return mel_loss + + +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py + +"""VITS-related loss modules. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + + +class KLDivergenceLoss(torch.nn.Module): + """KL divergence loss.""" + + def forward( + self, + z_p: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + z_mask: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss. + + Args: + z_p (Tensor): Flow hidden representation (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). + z_mask (Tensor): Mask tensor (B, 1, T_feats). + + Returns: + Tensor: KL divergence loss. + + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + loss = kl / torch.sum(z_mask) + + return loss + + +class KLDivergenceLossWithoutFlow(torch.nn.Module): + """KL divergence loss without flow.""" + + def forward( + self, + m_q: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + ) -> torch.Tensor: + """Calculate KL divergence loss without flow. + + Args: + m_q (Tensor): Posterior encoder projected mean (B, H, T_feats). + logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). + m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). + logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). 
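+
+        Returns:
+            Tensor: KL divergence loss value.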
+ """ + posterior_norm = D.Normal(m_q, torch.exp(logs_q)) + prior_norm = D.Normal(m_p, torch.exp(logs_p)) + loss = D.kl_divergence(posterior_norm, prior_norm).mean() + return loss diff --git a/egs/ljspeech/tts/vits/models.py b/egs/ljspeech/tts/vits/models.py new file mode 100644 index 0000000000..f5acdeb2be --- /dev/null +++ b/egs/ljspeech/tts/vits/models.py @@ -0,0 +1,534 @@ +import copy +import math +import torch +from torch import nn +from torch.nn import functional as F + +import commons +import modules +import attentions +import monotonic_align + +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from commons import init_weights, get_padding + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = modules.Log() + self.flows = nn.ModuleList() + self.flows.append(modules.ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(modules.Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(modules.ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(modules.Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow 
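+      # sample 2-channel noise, run the remaining flows in inverse order, and
+      # keep the first channel as the predicted log-duration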
+ z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_1 = modules.LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_2 = modules.LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + 
kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, 
padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2,3,5,7,11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + + +class SynthesizerTrn(nn.Module): + """ + Synthesizer for Training + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + n_speakers=0, + gin_channels=0, + use_sdp=True, + **kwargs): + + super().__init__() + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, 
upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None): + + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = commons.generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + + def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): + 
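# Sketch: the four neg_cent* terms computed under torch.no_grad() above are just an
# expanded diagonal-Gaussian log-density log N(z_p[:, :, t]; m_p[:, :, s], exp(logs_p[:, :, s]))
# evaluated for every (feature frame t, text position s) pair. Quick numerical check on
# random tensors (shapes are illustrative).
import math
import torch

b, d, t_s, t_t = 2, 4, 5, 7                      # batch, channels, text len, feature len
m_p = torch.randn(b, d, t_s)
logs_p = torch.randn(b, d, t_s) * 0.1
z_p = torch.randn(b, d, t_t)

s_p_sq_r = torch.exp(-2 * logs_p)
neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)
neg_cent2 = torch.matmul(-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r)
neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))
neg_cent4 = torch.sum(-0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True)
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4     # (b, t_t, t_s)

# direct evaluation of the same Gaussian log-density for comparison
diff = z_p.unsqueeze(-1) - m_p.unsqueeze(2)                  # (b, d, t_t, t_s)
direct = (-0.5 * math.log(2 * math.pi) - logs_p.unsqueeze(2)
          - 0.5 * diff**2 * s_p_sq_r.unsqueeze(2)).sum(1)    # (b, t_t, t_s)
assert torch.allclose(neg_cent, direct, atol=1e-5)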
        assert self.n_speakers > 0, "n_speakers must be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
+
diff --git a/egs/ljspeech/tts/vits/monotonic_align/__init__.py b/egs/ljspeech/tts/vits/monotonic_align/__init__.py
new file mode 100644
index 0000000000..2b35654f51
--- /dev/null
+++ b/egs/ljspeech/tts/vits/monotonic_align/__init__.py
@@ -0,0 +1,81 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
+
+"""Maximum path calculation module.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import warnings
+
+import numpy as np
+import torch
+from numba import njit, prange
+
+try:
+    from .core import maximum_path_c
+
+    is_cython_available = True
+except ImportError:
+    is_cython_available = False
+    warnings.warn(
+        "Cython version is not available. Falling back to the experimental numba version. "
+        "If you want to use the cython version, please build it as follows: "
+        "`cd egs/ljspeech/tts/vits/monotonic_align; python setup.py build_ext --inplace`"
+    )
+
+
+def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+    """Calculate maximum path.
+
+    Args:
+        neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
+        attn_mask (Tensor): Attention mask (B, T_feats, T_text).
+
+    Returns:
+        Tensor: Maximum path tensor (B, T_feats, T_text).
+
+    """
+    device, dtype = neg_x_ent.device, neg_x_ent.dtype
+    neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
+    path = np.zeros(neg_x_ent.shape, dtype=np.int32)
+    t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
+    t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
+    if is_cython_available:
+        maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
+    else:
+        maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
+
+    return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+@njit
+def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
+    """Calculate a single maximum path with numba."""
+    index = t_x - 1
+    for y in range(t_y):
+        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+            if x == y:
+                v_cur = max_neg_val
+            else:
+                v_cur = value[y - 1, x]
+            if x == 0:
+                if y == 0:
+                    v_prev = 0.0
+                else:
+                    v_prev = max_neg_val
+            else:
+                v_prev = value[y - 1, x - 1]
+            value[y, x] += max(v_prev, v_cur)
+
+    for y in range(t_y - 1, -1, -1):
+        path[y, index] = 1
+        if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+            index = index - 1
+
+
+@njit(parallel=True)
+def maximum_path_numba(paths, values, t_ys, t_xs):
+    """Calculate batch maximum path with numba."""
+    for i in prange(paths.shape[0]):
+        maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/tts/vits/monotonic_align/core.pyx b/egs/ljspeech/tts/vits/monotonic_align/core.pyx
new file mode 100644
index 0000000000..c02c2d02e7
--- /dev/null
+++ b/egs/ljspeech/tts/vits/monotonic_align/core.pyx
@@ -0,0 +1,51 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
+
+"""Maximum path calculation module with cython optimization.
+
+This code is copied from https://github.com/jaywalnut310/vits with modified code formatting.
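# Sketch: calling the maximum_path helper defined above on toy inputs, assuming either
# the compiled Cython extension or numba is available. The search runs on CPU/NumPy and
# the result is moved back to the caller's device and dtype. Shapes are illustrative and
# require T_feats >= T_text for a valid monotonic alignment.
import torch

b, t_feats, t_text = 1, 6, 4
neg_x_ent = torch.randn(b, t_feats, t_text)
attn_mask = torch.ones(b, t_feats, t_text)
path = maximum_path(neg_x_ent, attn_mask)   # (1, 6, 4), one-hot per feature frame
assert torch.all(path.sum(dim=2) == 1)      # each frame aligns to exactly one token
assert torch.all(path.sum(dim=1) >= 1)      # monotonic path covers every token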
+ +""" + +cimport cython + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[y - 1, x] + if x == 0: + if y == 0: + v_prev = 0.0 + else: + v_prev = max_neg_val + else: + v_prev = value[y - 1, x - 1] + value[y, x] += max(v_prev, v_cur) + + for y in range(t_y - 1, -1, -1): + path[y, index] = 1 + if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil: + cdef int b = paths.shape[0] + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i]) diff --git a/egs/ljspeech/tts/vits/monotonic_align/setup.py b/egs/ljspeech/tts/vits/monotonic_align/setup.py new file mode 100644 index 0000000000..33d75e1765 --- /dev/null +++ b/egs/ljspeech/tts/vits/monotonic_align/setup.py @@ -0,0 +1,31 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py +"""Setup cython code.""" + +from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] +setup( + name="monotonic_align", + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, +) diff --git a/egs/ljspeech/tts/vits/posterior_encoder.py b/egs/ljspeech/tts/vits/posterior_encoder.py new file mode 100644 index 0000000000..c78fd647fe --- /dev/null +++ b/egs/ljspeech/tts/vits/posterior_encoder.py @@ -0,0 +1,117 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Posterior encoder module in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple + +import torch + +from icefall.utils import make_pad_mask +from wavenet import WaveNet, Conv1d + + +class PosteriorEncoder(torch.nn.Module): + """Posterior encoder module in VITS. + + This is a module of posterior encoder described in `Conditional Variational + Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_. + + .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + """ + + def __init__( + self, + in_channels: int = 513, + out_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + layers: int = 16, + stacks: int = 1, + base_dilation: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + ): + """Initilialize PosteriorEncoder module. 
+ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size in WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of repeat stacking of WaveNet. + base_dilation (int): Base dilation factor. + global_channels (int): Number of global conditioning channels. + dropout_rate (float): Dropout rate. + bias (bool): Whether to use bias parameters in conv. + use_weight_norm (bool): Whether to apply weight norm. + + """ + super().__init__() + + # define modules + self.input_conv = Conv1d(in_channels, hidden_channels, 1) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + self.proj = Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_feats). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Encoded hidden representation tensor (B, out_channels, T_feats). + Tensor: Projected mean tensor (B, out_channels, T_feats). + Tensor: Projected scale tensor (B, out_channels, T_feats). + Tensor: Mask tensor for input tensor (B, 1, T_feats). + + """ + x_mask = ( + (~make_pad_mask(x_lengths)) + .unsqueeze(1) + .to( + dtype=x.dtype, + device=x.device, + ) + ) + x = self.input_conv(x) * x_mask + x = self.encoder(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + + return z, m, logs, x_mask diff --git a/egs/ljspeech/tts/vits/residual_coupling.py b/egs/ljspeech/tts/vits/residual_coupling.py new file mode 100644 index 0000000000..48e7483164 --- /dev/null +++ b/egs/ljspeech/tts/vits/residual_coupling.py @@ -0,0 +1,229 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Residual affine coupling modules in VITS. + +This code is based on https://github.com/jaywalnut310/vits. + +""" + +from typing import Optional, Tuple, Union + +import torch + +from flow import FlipFlow +from wavenet import WaveNet + + +class ResidualAffineCouplingBlock(torch.nn.Module): + """Residual affine coupling block module. + + This is a module of residual affine coupling block, which used as "Flow" in + `Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`_. + + .. 
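# Sketch: what the (B, 1, T_feats) mask built from make_pad_mask looks like for a toy
# batch. Valid frames are 1 and padded frames are 0, which is why every convolution
# output in the forward pass above is multiplied by x_mask.
import torch
from icefall.utils import make_pad_mask

x_lengths = torch.tensor([5, 3])
x_mask = (~make_pad_mask(x_lengths)).unsqueeze(1).to(torch.float32)
print(x_mask)
# tensor([[[1., 1., 1., 1., 1.]],
#         [[1., 1., 1., 0., 0.]]])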
_`Conditional Variational Autoencoder with Adversarial Learning for End-to-End + Text-to-Speech`: https://arxiv.org/abs/2006.04558 + + """ + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + flows: int = 4, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 4, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initilize ResidualAffineCouplingBlock module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + flows (int): Number of flows. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. + + """ + super().__init__() + + self.flows = torch.nn.ModuleList() + for i in range(flows): + self.flows += [ + ResidualAffineCouplingLayer( + in_channels=in_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + base_dilation=base_dilation, + layers=layers, + stacks=1, + global_channels=global_channels, + dropout_rate=dropout_rate, + use_weight_norm=use_weight_norm, + bias=bias, + use_only_mean=use_only_mean, + ) + ] + self.flows += [FlipFlow()] + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + + """ + if not inverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, inverse=inverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, inverse=inverse) + return x + + +class ResidualAffineCouplingLayer(torch.nn.Module): + """Residual affine coupling layer.""" + + def __init__( + self, + in_channels: int = 192, + hidden_channels: int = 192, + kernel_size: int = 5, + base_dilation: int = 1, + layers: int = 5, + stacks: int = 1, + global_channels: int = -1, + dropout_rate: float = 0.0, + use_weight_norm: bool = True, + bias: bool = True, + use_only_mean: bool = True, + ): + """Initialzie ResidualAffineCouplingLayer module. + + Args: + in_channels (int): Number of input channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size for WaveNet. + base_dilation (int): Base dilation factor for WaveNet. + layers (int): Number of layers of WaveNet. + stacks (int): Number of stacks of WaveNet. + global_channels (int): Number of global channels. + dropout_rate (float): Dropout rate. + use_weight_norm (bool): Whether to use weight normalization in WaveNet. + bias (bool): Whether to use bias paramters in WaveNet. + use_only_mean (bool): Whether to estimate only mean. 
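# Sketch: a round-trip check that the coupling block is invertible, i.e. running the
# flow forward and then with inverse=True recovers the input up to numerical error.
# This assumes it is run from egs/ljspeech/tts/vits so that the local flow.py and
# wavenet.py modules (mirroring the ESPnet implementations) are importable.
import torch

flow = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192, flows=4)
x = torch.randn(2, 192, 30)
x_mask = torch.ones(2, 1, 30)
z = flow(x, x_mask)                      # forward direction
x_rec = flow(z, x_mask, inverse=True)    # inverse direction
assert torch.allclose(x, x_rec, atol=1e-4)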
+ + """ + assert in_channels % 2 == 0, "in_channels should be divisible by 2" + super().__init__() + self.half_channels = in_channels // 2 + self.use_only_mean = use_only_mean + + # define modules + self.input_conv = torch.nn.Conv1d( + self.half_channels, + hidden_channels, + 1, + ) + self.encoder = WaveNet( + in_channels=-1, + out_channels=-1, + kernel_size=kernel_size, + layers=layers, + stacks=stacks, + base_dilation=base_dilation, + residual_channels=hidden_channels, + aux_channels=-1, + gate_channels=hidden_channels * 2, + skip_channels=hidden_channels, + global_channels=global_channels, + dropout_rate=dropout_rate, + bias=bias, + use_weight_norm=use_weight_norm, + use_first_conv=False, + use_last_conv=False, + scale_residual=False, + scale_skip_connect=True, + ) + if use_only_mean: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels, + 1, + ) + else: + self.proj = torch.nn.Conv1d( + hidden_channels, + self.half_channels * 2, + 1, + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + inverse: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + x_lengths (Tensor): Length tensor (B,). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + inverse (bool): Whether to inverse the flow. + + Returns: + Tensor: Output tensor (B, in_channels, T). + Tensor: Log-determinant tensor for NLL (B,) if not inverse. + + """ + xa, xb = x.split(x.size(1) // 2, dim=1) + h = self.input_conv(xa) * x_mask + h = self.encoder(h, x_mask, g=g) + stats = self.proj(h) * x_mask + if not self.use_only_mean: + m, logs = stats.split(stats.size(1) // 2, dim=1) + else: + m = stats + logs = torch.zeros_like(m) + + if not inverse: + xb = m + xb * torch.exp(logs) * x_mask + x = torch.cat([xa, xb], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + xb = (xb - m) * torch.exp(-logs) * x_mask + x = torch.cat([xa, xb], 1) + return x diff --git a/egs/ljspeech/tts/vits/symbols.py b/egs/ljspeech/tts/vits/symbols.py new file mode 100644 index 0000000000..70c2868f4f --- /dev/null +++ b/egs/ljspeech/tts/vits/symbols.py @@ -0,0 +1,17 @@ +# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py +""" from https://github.com/keithito/tacotron """ + +''' +Defines the set of symbols used in text input to the model. +''' +_pad = '_' +_punctuation = ';:,.!?¡¿—…"«»“” ' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + + +# Export all symbols: +symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) + +# Special symbol ids +SPACE_ID = symbol_table.index(" ") diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/tts/vits/text_encoder.py new file mode 100644 index 0000000000..fbf9b16a30 --- /dev/null +++ b/egs/ljspeech/tts/vits/text_encoder.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Text encoder module in VITS. + +This code is based on + - https://github.com/jaywalnut310/vits + - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py +""" + +import copy +import math +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn + +from icefall.utils import make_pad_mask + + +class TextEncoder(torch.nn.Module): + """Text encoder module in VITS. + + This is a module of text encoder described in `Conditional Variational Autoencoder + with Adversarial Learning for End-to-End Text-to-Speech`. + """ + + def __init__( + self, + vocabs: int, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + num_layers: int = 6, + dropout: float = 0.1, + ): + """Initialize TextEncoder module. + + Args: + vocabs (int): Vocabulary size. + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + super().__init__() + self.d_model = d_model + + # define modules + self.emb = torch.nn.Embedding(vocabs, d_model) + torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + + self.encoder = Transformer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + num_layers=num_layers, + dropout=dropout, + ) + + self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input index tensor (B, T_text). + x_lengths (Tensor): Length tensor (B,). + + Returns: + Tensor: Encoded hidden representation (B, attention_dim, T_text). + Tensor: Projected mean tensor (B, attention_dim, T_text). + Tensor: Projected scale tensor (B, attention_dim, T_text). + Tensor: Mask tensor for input tensor (B, 1, T_text). 
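# Sketch: turning raw text into the integer ids that the TextEncoder embedding expects,
# using the symbol_table defined in symbols.py. The recipe's own batching helper is
# prepare_token_batch in utils.py (imported by train.py but not shown in this excerpt);
# text_to_ids below is only a simplified, hypothetical mapping for a single string.
from symbols import symbol_table

symbol_to_id = {s: i for i, s in enumerate(symbol_table)}

def text_to_ids(text: str) -> list:
    """Map characters to ids, skipping anything outside the symbol set."""
    return [symbol_to_id[ch] for ch in text if ch in symbol_to_id]

print(text_to_ids("Printing, then?"))   # a list of ints in [0, len(symbol_table))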
+ + """ + # (B, T_text, embed_dim) + x = self.emb(x) * math.sqrt(self.d_model) + + assert x.size(1) == x_lengths.max().item() + + # (B, T_text) + pad_mask = make_pad_mask(x_lengths) + + # encoder assume the channel last (B, T_text, embed_dim) + x = self.encoder(x, key_padding_mask=pad_mask) + + # convert the channel first (B, embed_dim, T_text) + x = x.transpose(1, 2) + non_pad_mask = (~pad_mask).unsqueeze(1) + stats = self.proj(x) * non_pad_mask + m, logs = stats.split(stats.size(1) // 2, dim=1) + + return x, m, logs, non_pad_mask + + +class Transformer(nn.Module): + """ + Args: + d_model (int): attention dimension + num_heads (int): number of attention heads + dim_feedforward (int): feedforward dimention + num_layers (int): number of encoder layers + dropout (float): dropout rate + """ + + def __init__( + self, + d_model: int = 192, + num_heads: int = 2, + dim_feedforward: int = 768, + num_layers: int = 6, + dropout: float = 0.1, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.d_model = d_model + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + num_heads=num_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + self.encoder = TransformerEncoder(encoder_layer, num_layers) + self.after_norm = nn.LayerNorm(d_model) + + def forward( + self, x: Tensor, key_padding_mask: Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + lengths: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + """ + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x = self.encoder( + x, pos_emb, key_padding_mask=key_padding_mask + ) # (T, N, C) + + x = self.after_norm(x) + + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + return x + + +class TransformerEncoderLayer(nn.Module): + """ + TransformerEncoderLayer is made up of self-attn and feedforward. + + Args: + d_model: the number of expected features in the input. + num_heads: the number of heads in the multi-head attention models. + dim_feedforward: the dimension of the feed-forward network model. + dropout: the dropout value (default=0.1). + """ + + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int, + dropout: float = 0.1, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the transformer encoder layer. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). 
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + # multi-head self-attention module + src_attn = self.self_attn( + self.norm_mha(src), + pos_emb=pos_emb, + key_padding_mask=key_padding_mask, + ) + src = src + self.dropout(src_attn) + + # feed-forward module + src = src + self.dropout(self.feed_forward(self.norm_ff(src))) + + src = self.norm_final(src) + + return src + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the TransformerEncoderLayer class. + num_layers: the number of sub-encoder-layers in the encoder. + """ + + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim). + pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). + key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + key_padding_mask=key_padding_mask, + ) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + x_size = x.size(1) + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vector and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). 
+ """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, seq_len, 2*seq_len-1). + + Returns: + Tensor: tensor of shape (batch, head, seq_len, seq_len) + """ + (batch_size, num_heads, seq_len, n) = x.shape + + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Args: + x: Input tensor of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim) + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + Its shape is (batch_size, seq_len). + + Outputs: + A tensor of shape (seq_len, batch_size, embed_dim). 
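# Sketch: what rel_shift does. Scores are laid out over relative positions (last dim of
# size 2*T-1, centre index T-1 meaning offset 0); the strided view picks out entry
# (i, j) -> relative offset j - i, turning (B, H, T, 2T-1) into (B, H, T, T).
import torch

attn = RelPositionMultiheadAttention(embed_dim=8, num_heads=2)
T = 4
# encode the relative offset itself as the value: x[..., k] = k - (T - 1)
x = (torch.arange(2 * T - 1, dtype=torch.float32) - (T - 1)).view(1, 1, 1, 2 * T - 1)
x = x.expand(1, 1, T, 2 * T - 1).contiguous()
shifted = attn.rel_shift(x)              # (1, 1, T, T)
i = torch.arange(T).view(T, 1)
j = torch.arange(T).view(1, T)
assert torch.equal(shifted[0, 0], (j - i).to(torch.float32))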
+ """ + seq_len, batch_size, _ = x.shape + scaling = float(self.head_dim) ** -0.5 + + q, k, v = self.in_proj(x).chunk(3, dim=-1) + + q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim) + v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1) + + q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim) + + p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim) + # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1) + p = p.permute(0, 2, 3, 1) + + # (batch_size, num_head, seq_len, head_dim) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len) + + # compute matrix b and matrix d + matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1) + matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len) + + # (batch_size, num_head, seq_len, seq_len) + attn_output_weights = (matrix_ac + matrix_bd) * scaling + attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, seq_len) + attn_output_weights = attn_output_weights.view( + batch_size, self.num_heads, seq_len, seq_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + batch_size * self.num_heads, seq_len, seq_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + # (batch_size * num_head, seq_len, head_dim) + attn_output = torch.bmm(attn_output_weights, v) + assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim) + + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim) + ) + # (seq_len, batch_size, embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Swish(nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def _test_text_encoder(): + vocabs = 500 + d_model = 192 + batch_size = 5 + seq_len = 100 + + m = TextEncoder(vocabs=vocabs, d_model=d_model) + x, m, logs, mask = m( + x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)), + x_lengths=torch.full((batch_size,), seq_len), + ) + print(x.shape, m.shape, logs.shape, mask.shape) + + +if __name__ == "__main__": + _test_text_encoder() diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py new file mode 100755 index 0000000000..8fd2a596ab --- /dev/null +++ b/egs/ljspeech/tts/vits/train.py @@ -0,0 +1,896 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + +import torch +import torch.multiprocessing as mp +import 
torch.nn as nn +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch.optim import Optimizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + +from symbols import symbol_table +from utils import ( + MetricsTracker, + prepare_token_batch, + save_checkpoint, + save_checkpoint_with_global_batch_idx, +) +from vits import VITS + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + # "reset_interval": 200, + "valid_interval": 500, + "env_info": get_env_info(), + "sampling_rate": 22050, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "vocab_size": len(symbol_table), + "mel_loss_params": { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. 
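# Sketch: how the lambda_* scaling factors above are meant to combine the generator's
# loss terms. The actual weighting happens inside the VITS model wrapper (vits.py and
# loss.py, part of this patch but not shown in this excerpt); the formula below is the
# standard VITS objective and is only illustrative.
lambda_adv, lambda_mel, lambda_feat_match, lambda_dur, lambda_kl = 1.0, 45.0, 2.0, 1.0, 1.0

def generator_loss(adv, mel, feat_match, dur, kl):
    return (lambda_adv * adv + lambda_mel * mel + lambda_feat_match * feat_match
            + lambda_dur * dur + lambda_kl * kl)

# e.g. with unit component losses the Mel term dominates: 1 + 45 + 2 + 1 + 1 = 50
print(generator_loss(1.0, 1.0, 1.0, 1.0, 1.0))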
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + mel_loss_params=params.mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["text"]) + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. 
+ + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["text"]) + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + # if batch_idx % 100 == 0 and params.use_fp16: + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. 
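# Sketch: the distilled two-optimizer pattern used in the loop above -- one GradScaler
# drives both the discriminator step and the generator step, and scaler.update() is
# called once per iteration after both steps. Toy linear modules stand in for the real
# generator/discriminator; this is not the recipe's training loop.
import torch
from torch.cuda.amp import GradScaler, autocast

net_d = torch.nn.Linear(4, 1)
net_g = torch.nn.Linear(4, 4)
opt_d = torch.optim.AdamW(net_d.parameters(), lr=2e-4, betas=(0.8, 0.99))
opt_g = torch.optim.AdamW(net_g.parameters(), lr=2e-4, betas=(0.8, 0.99))
scaler = GradScaler(enabled=False)             # enable when training with --use-fp16

x = torch.randn(8, 4)
with autocast(enabled=False):
    loss_d = net_d(net_g(x).detach()).mean()   # discriminator step uses detached fakes
opt_d.zero_grad()
scaler.scale(loss_d).backward()
scaler.step(opt_d)

with autocast(enabled=False):
    loss_g = -net_d(net_g(x)).mean()           # generator step gets gradients through D
opt_g.zero_grad()
scaler.scale(loss_g).backward()
scaler.step(opt_g)
scaler.update()                                # once per iteration, after both steps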
The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + # if batch_idx % params.log_interval == 0: + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), + lr=params.lr, + betas=(0.8, 0.99), + eps=1e-9, + weight_decay=0, + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), + lr=params.lr, + betas=(0.8, 0.99), + eps=1e-9, + weight_decay=0, + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
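# Sketch: with ExponentialLR(gamma=0.999875) stepped once per epoch (see the per-epoch
# scheduler_g.step()/scheduler_d.step() calls at the end of the epoch loop below), the
# learning rate decays very slowly. Rough numbers for the default base lr of 2e-4:
base_lr, gamma = 2.0e-4, 0.999875
for epoch in (1, 100, 500, 1000):
    print(epoch, base_lr * gamma ** epoch)
# ~1.9998e-04 after 1 epoch, ~1.975e-04 after 100, ~1.879e-04 after 500, ~1.765e-04 after 1000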
Duration: {c.duration}" + # ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = ljspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/tts/vits/transform.py b/egs/ljspeech/tts/vits/transform.py new file mode 100644 index 0000000000..6858de2ab0 --- /dev/null +++ b/egs/ljspeech/tts/vits/transform.py @@ -0,0 +1,217 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py +"""Flow-related transformation. + +This code is derived from https://github.com/bayesiains/nflows. 
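Each sub-network gets its own AdamW optimizer and ExponentialLR schedule, and the schedulers are stepped once per epoch rather than per batch, so with gamma=0.999875 the learning rate only decays to roughly 0.999875**1000 ≈ 0.88 of its initial value after 1000 epochs. A condensed sketch of that wiring, using generator, discriminator and params as defined in run() above (lr comes from the --lr option):

import torch

optimizer_g = torch.optim.AdamW(
    generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9, weight_decay=0
)
optimizer_d = torch.optim.AdamW(
    discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9, weight_decay=0
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)

for epoch in range(1, params.num_epochs + 1):
    ...  # train_one_epoch(...) runs the per-batch D/G updates shown earlier
    scheduler_g.step()  # decay once per epoch
    scheduler_d.step()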
+ +""" + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +# TODO(kan-bayashi): Documentation and type hint +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +# TODO(kan-bayashi): Documentation and type hint +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * 
widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = _searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = _searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet + + +def _searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 diff --git a/egs/ljspeech/tts/vits/tts_datamodule.py b/egs/ljspeech/tts/vits/tts_datamodule.py new file mode 100644 index 0000000000..bd67aa6b13 --- /dev/null +++ b/egs/ljspeech/tts/vits/tts_datamodule.py @@ -0,0 +1,306 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei 
Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + SpeechSynthesisDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. 
Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
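The data module is driven entirely by its argparse options, so a training dataloader can be obtained outside of train.py in a few lines. A sketch, assuming the spectrogram manifests produced by prepare.sh already sit under data/spectrogram and the snippet runs from the vits directory:

import argparse

from tts_datamodule import LJSpeechTtsDataModule

parser = argparse.ArgumentParser()
LJSpeechTtsDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "200", "--num-buckets", "30"])

ljspeech = LJSpeechTtsDataModule(args)
train_cuts = ljspeech.train_cuts().filter(lambda c: 1.0 <= c.duration <= 20.0)
train_dl = ljspeech.train_dataloaders(train_cuts)

batch = next(iter(train_dl))
# keys consumed by the training code: "audio", "audio_lens",
# "features", "features_lens", and the raw "text" strings
print(batch["features"].shape, batch["audio"].shape)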
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + ) + else: + validate = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + ) + else: + test = SpeechSynthesisDataset( + return_tokens=False, + feature_input_strategy=eval(self.args.input_strategy)(), + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/tts/vits/utils.py new file mode 100644 index 0000000000..0020975816 --- /dev/null +++ b/egs/ljspeech/tts/vits/utils.py @@ -0,0 +1,470 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Function to get random segments.""" + +from typing import Any, Dict, List, Optional, Tuple, Union +import collections +import logging +import re +import warnings + +import numpy as np +import torch +import torch.nn as nn +import torch.distributed as dist +from lhotse.dataset.sampling.base import CutSampler +from pathlib import Path +from phonemizer import phonemize +from symbols import symbol_table +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as 
DDP +from torch.nn.utils.rnn import pad_sequence +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from unidecode import unidecode + + +def get_random_segments( + x: torch.Tensor, + x_lengths: torch.Tensor, + segment_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get random segments. + + Args: + x (Tensor): Input tensor (B, C, T). + x_lengths (Tensor): Length tensor (B,). + segment_size (int): Segment size. + + Returns: + Tensor: Segmented tensor (B, C, segment_size). + Tensor: Start index tensor (B,). + + """ + b, c, t = x.size() + max_start_idx = x_lengths - segment_size + max_start_idx[max_start_idx < 0] = 0 + start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to( + dtype=torch.long, + ) + segments = get_segments(x, start_idxs, segment_size) + + return segments, start_idxs + + +def get_segments( + x: torch.Tensor, + start_idxs: torch.Tensor, + segment_size: int, +) -> torch.Tensor: + """Get segments. + + Args: + x (Tensor): Input tensor (B, C, T). + start_idxs (Tensor): Start index tensor (B,). + segment_size (int): Segment size. + + Returns: + Tensor: Segmented tensor (B, C, segment_size). + + """ + b, c, t = x.size() + segments = x.new_zeros(b, c, segment_size) + for i, start_idx in enumerate(start_idxs): + segments[i] = x[i, :, start_idx : start_idx + segment_size] + return segments + + +# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py +def force_gatherable(data, device): + """Change object to gatherable in torch.nn.DataParallel recursively + + The difference from to_device() is changing to torch.Tensor if float or int + value is found. + + The restriction to the returned value in DataParallel: + The object must be + - torch.cuda.Tensor + - 1 or more dimension. 0-dimension-tensor sends warning. + or a list, tuple, dict. + + """ + if isinstance(data, dict): + return {k: force_gatherable(v, device) for k, v in data.items()} + # DataParallel can't handle NamedTuple well + elif isinstance(data, tuple) and type(data) is not tuple: + return type(data)(*[force_gatherable(o, device) for o in data]) + elif isinstance(data, (list, tuple, set)): + return type(data)(force_gatherable(v, device) for v in data) + elif isinstance(data, np.ndarray): + return force_gatherable(torch.from_numpy(data), device) + elif isinstance(data, torch.Tensor): + if data.dim() == 0: + # To 1-dim array + data = data[None] + return data.to(device) + elif isinstance(data, float): + return torch.tensor([data], dtype=torch.float, device=device) + elif isinstance(data, int): + return torch.tensor([data], dtype=torch.long, device=device) + elif data is None: + return None + else: + warnings.warn(f"{type(data)} may not be gatherable by DataParallel") + return data + + +# The following codes are based on https://github.com/jaywalnut310/vits + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def text_clean(text): + '''Pipeline for English text, including abbreviation expansion. + punctuation + stress. + + Returns: + A string of phonemes. + ''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + phonemes = phonemize( + text, + language='en-us', + backend='espeak', + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + phonemes = collapse_whitespace(phonemes) + return phonemes + + +# Mappings from symbol to numeric ID and vice versa: +symbol_to_id = {s: i for i, s in enumerate(symbol_table)} +id_to_symbol = {i: s for i, s in enumerate(symbol_table)} + + +# def text_to_sequence(text: str) -> List[int]: +# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. +# ''' +# cleaned_text = text_clean(text) +# sequence = [symbol_to_id[symbol] for symbol in cleaned_text] +# return sequence +# +# +# def sequence_to_text(sequence: List[int]) -> str: +# '''Converts a sequence of IDs back to a string''' +# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence) +# return result + + +def intersperse(sequence, item=0): + result = [item] * (len(sequence) * 2 + 1) + result[1::2] = sequence + return result + + +def prepare_token_batch( + texts: List[str], + intersperse_blank: bool = True, + blank_id: int = 0, + pad_id: int = 0, +) -> torch.Tensor: + """Convert a list of text strings into a batch of symbol tokens with padding. + Args: + texts: list of text strings + intersperse_blank: whether to intersperse blank tokens in the converted token sequence. + blank_id: index of blank token + pad_id: padding index + """ + # normalize text + normalized_texts = [] + for text in texts: + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + normalized_texts.append(text) + + # convert to phonemes + phonemes = phonemize( + normalized_texts, + language='en-us', + backend='espeak', + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + + # convert to symbol ids + lengths = [] + sequences = [] + for idx, sequence in enumerate(phonemes): + try: + sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)] + except RuntimeError: + print(text[idx]) + print(normalized_texts[idx]) + if intersperse_blank: + sequence = intersperse(sequence, blank_id) + sequences.append(torch.tensor(sequence, dtype=torch.int64)) + lengths.append(len(sequence)) + + sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id) + lengths = torch.tensor(lengths, dtype=torch.int64) + return sequences, lengths + + +class MetricsTracker(collections.defaultdict): + def __init__(self): + # Passing the type 'int' to the base-class constructor + # makes undefined items default to int() which is zero. 
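A small illustration of this text front end: the raw transcript is normalized, phonemized with the espeak backend, mapped to symbol ids, and blank tokens are interspersed before padding. It assumes phonemizer and espeak are installed; the exact token lengths depend on the phonemizer output:

from utils import intersperse, prepare_token_batch

# intersperse() places a blank id between (and around) every symbol id
assert intersperse([5, 9, 2], 0) == [0, 5, 0, 9, 0, 2, 0]

tokens, token_lens = prepare_token_batch(
    ["Dr. Smith lives near St. Mary's church.", "Hello world."]
)
print(tokens.shape)   # (2, T_max) int64 tensor, padded with pad_id=0
print(token_lens)     # per-utterance lengths before padding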
+ # This class will play a role as metrics tracker. + # It can record many metrics, including but not limited to loss. + super(MetricsTracker, self).__init__(int) + + def __add__(self, other: "MetricsTracker") -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v + for k, v in other.items(): + ans[k] = ans[k] + v + return ans + + def __mul__(self, alpha: float) -> "MetricsTracker": + ans = MetricsTracker() + for k, v in self.items(): + ans[k] = v * alpha + return ans + + def __str__(self) -> str: + ans = "" + for k, v in self.norm_items(): + norm_value = "%.4g" % v + ans += str(k) + "=" + str(norm_value) + ", " + samples = "%.2f" % self["samples"] + ans += "over" + str(samples) + " samples." + return ans + + def norm_items(self) -> List[Tuple[str, float]]: + """ + Returns a list of pairs, like: + [('loss_1', 0.1), ('loss_2', 0.07)] + """ + samples = self["samples"] if "samples" in self else 1 + ans = [] + for k, v in self.items(): + if k == "samples": + continue + norm_value = float(v) / samples + ans.append((k, norm_value)) + return ans + + def reduce(self, device): + """ + Reduce using torch.distributed, which I believe ensures that + all processes get the total. + """ + keys = sorted(self.keys()) + s = torch.tensor([float(self[k]) for k in keys], device=device) + dist.all_reduce(s, op=dist.ReduceOp.SUM) + for k, v in zip(keys, s.cpu().tolist()): + self[k] = v + + def write_summary( + self, + tb_writer: SummaryWriter, + prefix: str, + batch_idx: int, + ) -> None: + """Add logging information to a TensorBoard writer. + + Args: + tb_writer: a TensorBoard writer + prefix: a prefix for the name of the loss, e.g. "train/valid_", + or "train/current_" + batch_idx: The current batch index, used as the x-axis of the plot. + """ + for k, v in self.norm_items(): + tb_writer.add_scalar(prefix + k, v, batch_idx) + + +# checkpoint saving and loading +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def save_checkpoint( + filename: Path, + model: Union[nn.Module, DDP], + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRSchedulerType] = None, + scheduler_d: Optional[LRSchedulerType] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +) -> None: + """Save training information to a file. + + Args: + filename: + The checkpoint filename. + model: + The model to be saved. We only save its `state_dict()`. + model_avg: + The stored model averaged from the start of training. + params: + User defined parameters, e.g., epoch, loss. + optimizer_g: + The optimizer for generator used in the training. + Its `state_dict` will be saved. + optimizer_d: + The optimizer for discriminator used in the training. + Its `state_dict` will be saved. + scheduler_g: + The learning rate scheduler for generator used in the training. + Its `state_dict` will be saved. + scheduler_d: + The learning rate scheduler for discriminator used in the training. + Its `state_dict` will be saved. + scalar: + The GradScaler to be saved. We only save its `state_dict()`. + rank: + Used in DDP. We save checkpoint only for the node whose rank is 0. + Returns: + Return None. 
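MetricsTracker is a defaultdict(int) whose entries can be summed across batches and then normalized by the accumulated "samples" count; a quick illustration:

from utils import MetricsTracker

m1 = MetricsTracker()
m1["samples"] = 2
m1["generator_loss"] = 3.0

m2 = MetricsTracker()
m2["samples"] = 4
m2["generator_loss"] = 9.0

total = m1 + m2
# every key except "samples" is divided by the total number of samples
assert total.norm_items() == [("generator_loss", 2.0)]  # 12.0 / 6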
+ """ + if rank != 0: + return + + logging.info(f"Saving checkpoint to {filename}") + + if isinstance(model, DDP): + model = model.module + + checkpoint = { + "model": model.state_dict(), + "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None, + "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None, + "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None, + "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None, + "grad_scaler": scaler.state_dict() if scaler is not None else None, + "sampler": sampler.state_dict() if sampler is not None else None, + } + + if params: + for k, v in params.items(): + assert k not in checkpoint + checkpoint[k] = v + + torch.save(checkpoint, filename) + + +def save_checkpoint_with_global_batch_idx( + out_dir: Path, + global_batch_idx: int, + model: Union[nn.Module, DDP], + params: Optional[Dict[str, Any]] = None, + optimizer_g: Optional[Optimizer] = None, + optimizer_d: Optional[Optimizer] = None, + scheduler_g: Optional[LRSchedulerType] = None, + scheduler_d: Optional[LRSchedulerType] = None, + scaler: Optional[GradScaler] = None, + sampler: Optional[CutSampler] = None, + rank: int = 0, +): + """Save training info after processing given number of batches. + + Args: + out_dir: + The directory to save the checkpoint. + global_batch_idx: + The number of batches processed so far from the very start of the + training. The saved checkpoint will have the following filename: + f'out_dir / checkpoint-{global_batch_idx}.pt' + model: + The neural network model whose `state_dict` will be saved in the + checkpoint. + params: + A dict of training configurations to be saved. + optimizer_g: + The optimizer for generator used in the training. + Its `state_dict` will be saved. + optimizer_d: + The optimizer for discriminator used in the training. + Its `state_dict` will be saved. + scheduler_g: + The learning rate scheduler for generator used in the training. + Its `state_dict` will be saved. + scheduler_d: + The learning rate scheduler for discriminator used in the training. + Its `state_dict` will be saved. + scaler: + The scaler used for mix precision training. Its `state_dict` will + be saved. + sampler: + The sampler used in the training dataset. + rank: + The rank ID used in DDP training of the current node. Set it to 0 + if DDP is not used. 
+ """ + out_dir = Path(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + filename = out_dir / f"checkpoint-{global_batch_idx}.pt" + save_checkpoint( + filename=filename, + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + scaler=scaler, + sampler=sampler, + rank=rank, + ) diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py new file mode 100644 index 0000000000..da9d144f22 --- /dev/null +++ b/egs/ljspeech/tts/vits/vits.py @@ -0,0 +1,567 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VITS module for GAN-TTS task.""" + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from hifigan import ( + HiFiGANMultiPeriodDiscriminator, + HiFiGANMultiScaleDiscriminator, + HiFiGANMultiScaleMultiPeriodDiscriminator, + HiFiGANPeriodDiscriminator, + HiFiGANScaleDiscriminator, +) +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + KLDivergenceLoss, + MelSpectrogramLoss, +) +from utils import get_segments +from generator import VITSGenerator + + +AVAILABLE_GENERATERS = { + "vits_generator": VITSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA +} + + +class VITS(nn.Module): + """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech` + """ + + def __init__( + self, + # generator related + vocab_size: int, + feature_dim: int = 513, + sampling_rate: int = 22050, + generator_type: str = "vits_generator", + generator_params: Dict[str, Any] = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": None, + "global_channels": -1, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + }, + # discriminator related + discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any] = { + "scales": 1, + "scale_downsample_pooling": "AvgPool1d", + 
"scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": {"negative_slope": 0.1}, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + # loss related + generator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + discriminator_adv_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "loss_type": "mse", + }, + feat_match_loss_params: Dict[str, Any] = { + "average_by_discriminators": False, + "average_by_layers": False, + "include_final_outputs": True, + }, + mel_loss_params: Dict[str, Any] = { + "frame_shift": 256, + "frame_length": 1024, + "n_mels": 80, + }, + lambda_adv: float = 1.0, + lambda_mel: float = 45.0, + lambda_feat_match: float = 2.0, + lambda_dur: float = 1.0, + lambda_kl: float = 1.0, + cache_generator_outputs: bool = True, + ): + """Initialize VITS module. + + Args: + idim (int): Input vocabrary size. + odim (int): Acoustic feature dimension. The actual output channels will + be 1 since VITS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): Generator type. + generator_params (Dict[str, Any]): Parameter dict for generator. + discriminator_type (str): Discriminator type. + discriminator_params (Dict[str, Any]): Parameter dict for discriminator. + generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator + adversarial loss. + discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for + discriminator adversarial loss. + feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss. + mel_loss_params (Dict[str, Any]): Parameter dict for mel loss. + lambda_adv (float): Loss scaling coefficient for adversarial loss. + lambda_mel (float): Loss scaling coefficient for mel spectrogram loss. + lambda_feat_match (float): Loss scaling coefficient for feat match loss. + lambda_dur (float): Loss scaling coefficient for duration loss. + lambda_kl (float): Loss scaling coefficient for KL divergence loss. + cache_generator_outputs (bool): Whether to cache generator outputs. + + """ + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "vits_generator": + # NOTE(kan-bayashi): Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. 
+ generator_params.update(vocabs=vocab_size, aux_channels=feature_dim) + self.generator = generator_class( + **generator_params, + ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, + ) + self.generator_adv_loss = GeneratorAdversarialLoss( + **generator_adv_loss_params, + ) + self.discriminator_adv_loss = DiscriminatorAdversarialLoss( + **discriminator_adv_loss_params, + ) + self.feat_match_loss = FeatureMatchLoss( + **feat_match_loss_params, + ) + mel_loss_params.update(sampling_rate=sampling_rate) + self.mel_loss = MelSpectrogramLoss( + **mel_loss_params, + ) + self.kl_loss = KLDivergenceLoss() + + # coefficients + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_kl = lambda_kl + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.sampling_rate = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + forward_generator: bool = True, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + forward_generator (bool): Whether to forward generator. + + Returns: + Dict[str, Any]: + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. + - weight (Tensor): Weight tensor to summarize losses. + - optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + speech=speech, + speech_lengths=speech_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + + def _forward_generator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """Perform generator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). 
+ feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + mel_loss = self.mel_loss(speech_hat_, speech_) + kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = torch.sum(dur_nll.float()) + adv_loss = self.generator_adv_loss(p_hat) + feat_match_loss = self.feat_match_loss(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + stats = dict( + generator_loss=loss.item(), + generator_mel_loss=mel_loss.item(), + generator_kl_loss=kl_loss.item(), + generator_dur_loss=dur_loss.item(), + generator_adv_loss=adv_loss.item(), + generator_feat_match_loss=feat_match_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discrminator( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + feats: torch.Tensor, + feats_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + ) -> Dict[str, Any]: + """Perform discriminator forward. + + Args: + text (Tensor): Text index tensor (B, T_text). + text_lengths (Tensor): Text length tensor (B,). + feats (Tensor): Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): Feature length tensor (B,). + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). + + Returns: + Dict[str, Any]: + * loss (Tensor): Loss scalar tensor. 
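The window matching here is the heart of the segment-based training: start_idxs index latent frames, so the matching ground-truth waveform window must start at start_idxs times the decoder's total upsampling factor, which with the default decoder_upsample_scales is 8 * 8 * 2 * 2 = 256, i.e. exactly the 256-sample frame shift. A standalone illustration with the helpers from utils.py:

import torch

from utils import get_random_segments, get_segments

upsample_factor = 256    # product of decoder_upsample_scales [8, 8, 2, 2]
segment_size = 32        # latent frames fed to the HiFi-GAN decoder

z = torch.randn(2, 192, 120)                    # (B, C, T_feats) latent frames
z_lengths = torch.tensor([120, 90])
wav = torch.randn(2, 1, 120 * upsample_factor)  # (B, 1, T_wav)

z_seg, start_idxs = get_random_segments(z, z_lengths, segment_size)
wav_seg = get_segments(wav, start_idxs * upsample_factor, segment_size * upsample_factor)

assert z_seg.shape == (2, 192, segment_size)
assert wav_seg.shape == (2, 1, segment_size * upsample_factor)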
+ * stats (Dict[str, float]): Statistics to be monitored. + * weight (Tensor): Weight tensor to summarize losses. + * optim_idx (int): Optimizer index (0 for G and 1 for D). + + """ + # setup + feats = feats.transpose(1, 2) + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not reuse_cache: + self._cache = outs + + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * self.generator.upsample_factor, + segment_size=self.generator.segment_size * self.generator.upsample_factor, + ) + + # calculate discriminator outputs + p_hat = self.discriminator(speech_hat_.detach()) + p = self.discriminator(speech_) + + # calculate losses + with autocast(enabled=False): + real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) + loss = real_loss + fake_loss + + stats = dict( + discriminator_loss=loss.item(), + discriminator_real_loss=real_loss.item(), + discriminator_fake_loss=fake_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def inference( + self, + text: torch.Tensor, + feats: Optional[torch.Tensor] = None, + sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Dict[str, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (T_text,). + feats (Tensor): Feature tensor (T_feats, aux_channels). + sids (Tensor): Speaker index tensor (1,). + spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,). + lids (Tensor): Language index tensor (1,). + durations (Tensor): Ground-truth duration tensor (T_text,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + use_teacher_forcing (bool): Whether to use teacher forcing. + + Returns: + Dict[str, Tensor]: + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). 
+ + """ + # setup + text = text[None] + text_lengths = torch.tensor( + [text.size(1)], + dtype=torch.long, + device=text.device, + ) + if sids is not None: + sids = sids.view(1) + if lids is not None: + lids = lids.view(1) + if durations is not None: + durations = durations.view(1, 1, -1) + + # inference + if use_teacher_forcing: + assert feats is not None + feats = feats[None].transpose(1, 2) + feats_lengths = torch.tensor( + [feats.size(2)], + dtype=torch.long, + device=feats.device, + ) + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + sids=sids, + spembs=spembs, + lids=lids, + max_len=max_len, + use_teacher_forcing=use_teacher_forcing, + ) + else: + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + sids=sids, + spembs=spembs, + lids=lids, + dur=durations, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0]) diff --git a/egs/ljspeech/tts/vits/wavenet.py b/egs/ljspeech/tts/vits/wavenet.py new file mode 100644 index 0000000000..cbb44a8f40 --- /dev/null +++ b/egs/ljspeech/tts/vits/wavenet.py @@ -0,0 +1,349 @@ +# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""WaveNet modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +import math +import logging + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +class WaveNet(torch.nn.Module): + """WaveNet with global conditioning.""" + + def __init__( + self, + in_channels: int = 1, + out_channels: int = 1, + kernel_size: int = 3, + layers: int = 30, + stacks: int = 3, + base_dilation: int = 2, + residual_channels: int = 64, + aux_channels: int = -1, + gate_channels: int = 128, + skip_channels: int = 64, + global_channels: int = -1, + dropout_rate: float = 0.0, + bias: bool = True, + use_weight_norm: bool = True, + use_first_conv: bool = False, + use_last_conv: bool = False, + scale_residual: bool = False, + scale_skip_connect: bool = False, + ): + """Initialize WaveNet module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + base_dilation (int): Base dilation factor. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for local conditioning feature. + global_channels (int): Number of channels for global conditioning feature. + dropout_rate (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. If set to true, it will + be applied to all of the conv layers. + use_first_conv (bool): Whether to use the first conv layers. + use_last_conv (bool): Whether to use the last conv layers. + scale_residual (bool): Whether to scale the residual outputs. + scale_skip_connect (bool): Whether to scale the skip connection outputs. 
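Roughly, synthesis with a trained model goes through VITS.inference(); the sketch below assumes model is a loaded VITS instance in eval mode and token_ids is a 1-D id sequence produced by the text front end:

import torch

text = torch.tensor(token_ids, dtype=torch.long)  # (T_text,)
with torch.no_grad():
    out = model.inference(
        text=text, noise_scale=0.667, noise_scale_dur=0.8, alpha=1.0
    )

wav = out["wav"]              # (T_wav,) waveform at model.sampling_rate (22050 Hz)
durations = out["duration"]   # predicted per-token durations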
+ + """ + super().__init__() + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + self.base_dilation = base_dilation + self.use_first_conv = use_first_conv + self.use_last_conv = use_last_conv + self.scale_skip_connect = scale_skip_connect + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + if self.use_first_conv: + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = base_dilation ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + global_channels=global_channels, + dilation=dilation, + dropout_rate=dropout_rate, + bias=bias, + scale_residual=scale_residual, + ) + self.conv_layers += [conv] + + # define output layers + if self.use_last_conv: + self.last_conv = torch.nn.Sequential( + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T) if use_first_conv else + (B, residual_channels, T). + x_mask (Optional[Tensor]): Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning features (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning features (B, global_channels, 1). + + Returns: + Tensor: Output tensor (B, out_channels, T) if use_last_conv else + (B, residual_channels, T). 
+ + """ + # encode to hidden representation + if self.use_first_conv: + x = self.first_conv(x) + + # residual block + skips = 0.0 + for f in self.conv_layers: + x, h = f(x, x_mask=x_mask, c=c, g=g) + skips = skips + h + x = skips + if self.scale_skip_connect: + x = x * math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + if self.use_last_conv: + x = self.last_conv(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m: torch.nn.Module): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m: torch.nn.Module): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size( + layers: int, + stacks: int, + kernel_size: int, + base_dilation: int, + ) -> int: + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self) -> int: + """Return receptive field size.""" + return self._get_receptive_field_size( + self.layers, self.stacks, self.kernel_size, self.base_dilation + ) + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super().__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels: int, out_channels: int, bias: bool): + """Initialize 1x1 Conv1d module.""" + super().__init__( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size: int = 3, + residual_channels: int = 64, + gate_channels: int = 128, + skip_channels: int = 64, + aux_channels: int = 80, + global_channels: int = -1, + dropout_rate: float = 0.0, + dilation: int = 1, + bias: bool = True, + scale_residual: bool = False, + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Number of local conditioning channels. + dropout (float): Dropout probability. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + scale_residual (bool): Whether to scale the residual outputs. + + """ + super().__init__() + self.dropout_rate = dropout_rate + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.scale_residual = scale_residual + + # check + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
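The receptive field grows with the sum of the dilation schedule: with the class defaults (kernel_size=3, layers=30, stacks=3, base_dilation=2) each stack cycles through dilations 1, 2, 4, ..., 512, which works out to 6139 samples. A quick check against the helper above:

kernel_size, layers, stacks, base_dilation = 3, 30, 3, 2
layers_per_cycle = layers // stacks
dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
assert sum(dilations) == 3 * 1023                      # three cycles of 1 + 2 + ... + 512
assert (kernel_size - 1) * sum(dilations) + 1 == 6139  # == WaveNet().receptive_field_size
# the posterior-encoder settings listed earlier (kernel_size=5, layers=16,
# stacks=1, base_dilation=1) give only (5 - 1) * 16 + 1 = 65 frames by the same formula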
+ assert gate_channels % 2 == 0 + + # dilation conv + padding = (kernel_size - 1) // 2 * dilation + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # global conditioning + if global_channels > 0: + self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False) + else: + self.conv1x1_glo = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + + # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency + # (integrate res 1x1 + skip 1x1 convs) + self.conv1x1_out = Conv1d1x1( + gate_out_channels, residual_channels + skip_channels, bias=bias + ) + + def forward( + self, + x: torch.Tensor, + x_mask: Optional[torch.Tensor] = None, + c: Optional[torch.Tensor] = None, + g: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T). + c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T). + g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). + + """ + residual = x + x = F.dropout(x, p=self.dropout_rate, training=self.training) + x = self.conv(x) + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + # global conditioning + if g is not None: + g = self.conv1x1_glo(g) + ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ga, xb + gb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # residual + skip 1x1 conv + x = self.conv1x1_out(x) + if x_mask is not None: + x = x * x_mask + + # split integrated conv results + x, s = x.split([self.residual_channels, self.skip_channels], dim=1) + + # for residual connection + x = x + residual + if self.scale_residual: + x = x * math.sqrt(0.5) + + return x, s From b719581e2f6b8d0cc549617424e7e42a2651fee6 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sat, 28 Oct 2023 21:16:43 +0800 Subject: [PATCH 02/16] replace phonimizer with g2p --- egs/ljspeech/tts/local/prepare_token_file.py | 116 ++++++ egs/ljspeech/tts/local/split_subsets.py | 3 +- egs/ljspeech/tts/prepare.sh | 45 ++- egs/ljspeech/tts/vits/generator.py | 4 +- egs/ljspeech/tts/vits/infer.py | 366 +++++++++++++++++++ egs/ljspeech/tts/vits/loss.py | 6 +- egs/ljspeech/tts/vits/tokenizer.py | 80 ++++ egs/ljspeech/tts/vits/train.py | 273 +++++++++----- egs/ljspeech/tts/vits/tts_datamodule.py | 15 +- egs/ljspeech/tts/vits/utils.py | 96 +++-- egs/ljspeech/tts/vits/vits.py | 59 ++- 11 files changed, 936 insertions(+), 127 deletions(-) create mode 100755 egs/ljspeech/tts/local/prepare_token_file.py create mode 100755 egs/ljspeech/tts/vits/infer.py create mode 100644 egs/ljspeech/tts/vits/tokenizer.py diff --git a/egs/ljspeech/tts/local/prepare_token_file.py b/egs/ljspeech/tts/local/prepare_token_file.py new file mode 100755 index 0000000000..17a5588992 --- /dev/null +++ b/egs/ljspeech/tts/local/prepare_token_file.py @@ -0,0 +1,116 @@ 
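The gated activation in `ResidualBlock.forward` above follows the usual WaveNet convention: the dilated convolution doubles the channel count, and the two halves act as a content part and a gate. A minimal sketch with toy sizes (assumes only PyTorch; the names and sizes here are illustrative, the real ones come from the WaveNet config above):

import torch

B, C, T = 2, 4, 10                        # toy batch, residual_channels, frames
gate_out = torch.randn(B, 2 * C, T)       # output of the dilated conv (gate_channels = 2 * C)
xa, xb = gate_out.split(C, dim=1)         # content half and gate half
gated = torch.tanh(xa) * torch.sigmoid(xb)    # (B, C, T), bounded in (-1, 1)
# conv1x1_out then maps this to residual_channels + skip_channels,
# which forward() splits into the residual path and the skip path.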
+#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and generate the file that maps tokens to IDs. +""" + +import argparse +import logging +from collections import Counter +from pathlib import Path +from typing import Dict + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-file", + type=Path, + default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"), + help="Path to the manifest file", + ) + + parser.add_argument( + "--tokens", + type=Path, + default=Path("data/tokens.txt"), + help="Path to the tokens", + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_token2id(manifest_file: Path) -> Dict[str, int]: + """Return a dict that maps token to IDs.""" + extra_tokens = { + "": 0, # blank + "": 1, # sos and eos symbols. 
+ "": 2, # OOV + } + cut_set = load_manifest(manifest_file) + g2p = g2p_en.G2p() + counter = Counter() + + for cut in cut_set: + # Each cut only contain one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens = g2p(text) + for t in tokens: + counter[t] += 1 + + # Sort by the number of occurrences in descending order + tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1]) + + for token, idx in extra_tokens.items(): + tokens_and_counts.insert(idx, (token, None)) + + token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)} + return token2id + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + manifest_file = Path(args.manifest_file) + out_file = Path(args.tokens) + + token2id = get_token2id(manifest_file) + write_mapping(out_file, token2id) diff --git a/egs/ljspeech/tts/local/split_subsets.py b/egs/ljspeech/tts/local/split_subsets.py index 328cdd6910..b2afca9712 100755 --- a/egs/ljspeech/tts/local/split_subsets.py +++ b/egs/ljspeech/tts/local/split_subsets.py @@ -52,7 +52,8 @@ def main(): manifest_dir = Path(args.manifest_dir) prefix = "ljspeech" suffix = "jsonl.gz" - all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") + # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") + all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}") cut_ids = list(all_cuts.ids) random.shuffle(cut_ids) diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/tts/prepare.sh index f78964c347..4f4685951a 100755 --- a/egs/ljspeech/tts/prepare.sh +++ b/egs/ljspeech/tts/prepare.sh @@ -66,11 +66,50 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi fi +# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then +# log "Stage 3: Phonemize the transcripts for LJSpeech" +# if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then +# ./local/phonemize_text.py data/spectrogram +# touch data/spectrogram/.ljspeech_phonemized.done +# fi +# fi + +# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then +# log "Stage 4: Split the LJSpeech cuts into three sets" +# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then +# ./local/split_subsets.py data/spectrogram +# touch data/spectrogram/.ljspeech_split.done +# fi +# fi + if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Split the LJSpeech cuts into three sets" + log "Stage 3: Split the LJSpeech cuts into train, valid and test sets" if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then - ./local/split_subsets.py data/spectrogram - touch data/spectrogram/.ljspeech_split.done + lhotse subset --last 600 \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ + data/spectrogram/ljspeech_cuts_test.jsonl.gz + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_train.jsonl.gz + touch data/spectrogram/.ljspeech_split.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Generate token file" + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py \ + --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ + --tokens data/tokens.txt fi fi diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index dbf503944f..a74440c958 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -515,10 +515,12 @@ def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: cum_dur_flat = cum_dur.view(b * t_x) path = torch.arange(t_y, dtype=dur.dtype, device=dur.device) path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) - path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + # path = path.view(b, t_x, t_y).to(dtype=mask.dtype) + path = path.view(b, t_x, t_y).to(dtype=torch.float) # path will be like (t_x = 3, t_y = 5): # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + # path = path.to(dtype=mask.dtype) return path.unsqueeze(1).transpose(2, 3) * mask diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py new file mode 100755 index 0000000000..89fc729626 --- /dev/null +++ b/egs/ljspeech/tts/vits/infer.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
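The train/valid/test split performed with the `lhotse subset` commands in Stage 3 above can also be expressed through lhotse's Python API. A rough equivalent (assumes lhotse is installed and the manifests above exist; the manifest is loaded eagerly so that `last=` and `len()` work):

from lhotse import load_manifest

cuts = load_manifest("data/spectrogram/ljspeech_cuts_all.jsonl.gz")
validtest = cuts.subset(last=600)
valid = validtest.subset(first=100)
test = validtest.subset(last=500)
train = cuts.subset(first=len(cuts) - 600)

train.to_file("data/spectrogram/ljspeech_cuts_train.jsonl.gz")
valid.to_file("data/spectrogram/ljspeech_cuts_valid.jsonl.gz")
test.to_file("data/spectrogram/ljspeech_cuts_test.jsonl.gz")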
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torchaudio + +from train2 import get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) +from tts_datamodule import LJSpeechTtsDataModule +from utils import prepare_token_batch + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + # Background worker save audios to disk. + def _save_worker( + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"), + audio[i:i + 1, :audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"), + audio_pred[i:i + 1, :audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 10 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + # We only want one background worker so that serialization is deterministic. + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["text"]) + text = batch["text"] + tokens, tokens_lens = prepare_token_batch(text) + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + + # import pdb + # pdb.set_trace() + + futures.append( + executor.submit( + _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + infer_dataset( + dl=test_dl, + params=params, + model=model, + ) + + # save_results( + # params=params, + # test_set_name=test_set, + # results_dict=results_dict, + # ) + + logging.info("Done!") + + +# torch.set_num_threads(1) +# torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/tts/vits/loss.py index d322f5e053..0d27af6435 100644 --- a/egs/ljspeech/tts/vits/loss.py +++ b/egs/ljspeech/tts/vits/loss.py @@ -241,7 +241,8 @@ def forward( self, y_hat: torch.Tensor, y: torch.Tensor, - ) -> torch.Tensor: + return_mel: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: """Calculate Mel-spectrogram loss. Args: @@ -259,6 +260,9 @@ def forward( mel = self.wav_to_mel(y.squeeze(1)) mel_loss = F.l1_loss(mel_hat, mel) + if return_mel: + return mel_loss, (mel_hat, mel) + return mel_loss diff --git a/egs/ljspeech/tts/vits/tokenizer.py b/egs/ljspeech/tts/vits/tokenizer.py new file mode 100644 index 0000000000..5a513a0d98 --- /dev/null +++ b/egs/ljspeech/tts/vits/tokenizer.py @@ -0,0 +1,80 @@ +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
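For reference, the `--epoch`/`--avg` branch in `main()` above simply averages the last `avg` epoch checkpoints ending at `--epoch` (clipped at epoch 1). The file selection it performs is equivalent to this small helper (the helper name is mine, for illustration only):

from pathlib import Path
from typing import List

def epoch_checkpoints_to_average(exp_dir: Path, epoch: int, avg: int) -> List[Path]:
    # e.g. epoch=30, avg=3 -> [epoch-28.pt, epoch-29.pt, epoch-30.pt]
    start = epoch - avg + 1
    return [exp_dir / f"epoch-{i}.pt" for i in range(start, epoch + 1) if i >= 1]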
+ +from typing import Dict, List + +import g2p_en +import tacotron_cleaner.cleaners + +from utils import intersperse + + +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + self.token2id[token] = id + + self.blank_id = self.token2id[""] + self.oov_id = self.token2id[""] + self.vocab_size = len(self.token2id) + + self.g2p = g2p_en.G2p() + + def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): + """ + Args: + texts: + A list of transcripts. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for text in texts: + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens = self.g2p(text) + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + + token_ids_list.append(token_ids) + + return token_ids_list diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py index 8fd2a596ab..01cd6137ef 100755 --- a/egs/ljspeech/tts/vits/train.py +++ b/egs/ljspeech/tts/vits/train.py @@ -1,10 +1,32 @@ #!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import argparse import logging from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Union +import k2 import torch import torch.multiprocessing as mp import torch.nn as nn @@ -27,10 +49,10 @@ str2bool, ) -from symbols import symbol_table +from tokenizer import Tokenizer from utils import ( MetricsTracker, - prepare_token_batch, + plot_feature, save_checkpoint, save_checkpoint_with_global_batch_idx, ) @@ -101,6 +123,13 @@ def get_parser(): """, ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to tokens.txt.""", + ) + parser.add_argument( "--lr", type=float, default=2.0e-4, help="The base learning rate." 
) @@ -213,16 +242,16 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 50, + "log_interval": 10, + "draw_interval": 500, # "reset_interval": 200, - "valid_interval": 500, + "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 22050, + "frame_shift": 256, + "frame_length": 1024, "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "vocab_size": len(symbol_table), "mel_loss_params": { - "frame_shift": 256, - "frame_length": 1024, "n_mels": 80, }, "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss @@ -287,11 +316,16 @@ def load_checkpoint_if_available( def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = params.mel_loss_params + mel_loss_params.update( + frame_length=params.frame_length, + frame_shift=params.frame_shift, + ) model = VITS( vocab_size=params.vocab_size, feature_dim=params.feature_dim, sampling_rate=params.sampling_rate, - mel_loss_params=params.mel_loss_params, + mel_loss_params=mel_loss_params, lambda_adv=params.lambda_adv, lambda_mel=params.lambda_mel, lambda_feat_match=params.lambda_feat_match, @@ -301,79 +335,30 @@ def get_model(params: AttributeDict) -> nn.Module: return model -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - - # used to summary the stats over iterations - tot_loss = MetricsTracker() - - with torch.no_grad(): - for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["text"]) - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) - - loss_info = MetricsTracker() - loss_info['samples'] = batch_size - - # forward discriminator - loss_d, stats_d = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=False, - ) - assert loss_d.requires_grad is False - for k, v in stats_d.items(): - loss_info[k] = v * batch_size - - # forward generator - loss_g, stats_g = model( - text=tokens, - text_lengths=tokens_lens, - feats=features, - feats_lengths=features_lens, - speech=audio, - speech_lengths=audio_lens, - forward_generator=True, - ) - assert loss_g.requires_grad is False - for k, v in stats_g.items(): - loss_info[k] = v * batch_size - - # summary stats - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(device) +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + text = batch["text"] - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value + tokens = tokenizer.texts_to_token_ids(text) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = 
tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) - return tot_loss + return audio, audio_lens, features, features_lens, tokens, tokens_lens def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], + tokenizer: Tokenizer, optimizer_g: Optimizer, optimizer_d: Optimizer, scheduler_g: LRSchedulerType, @@ -442,18 +427,13 @@ def save_bad_model(suffix: str = ""): params.batch_idx_train += 1 batch_size = len(batch["text"]) - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) loss_info = MetricsTracker() loss_info['samples'] = batch_size + return_sample = params.batch_idx_train % params.log_interval == 0 try: with autocast(enabled=params.use_fp16): # forward discriminator @@ -483,9 +463,13 @@ def save_bad_model(suffix: str = ""): speech=audio, speech_lengths=audio_lens, forward_generator=True, + return_sample=return_sample, ) for k, v in stats_g.items(): - loss_info[k] = v * batch_size + if "return_sample" not in k: + loss_info[k] = v * batch_size + if return_sample: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"] # update generator optimizer_g.zero_grad() scaler.scale(loss_g).backward() @@ -577,13 +561,27 @@ def save_bad_model(suffix: str = ""): tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) + if return_sample: + tb_writer.add_audio( + "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/speech_", speech_, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_image( + "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC' + ) + tb_writer.add_image( + "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' + ) # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") - valid_info = compute_validation_loss( + valid_info, (speech_hat, speech) = compute_validation_loss( params=params, model=model, + tokenizer=tokenizer, valid_dl=valid_dl, world_size=world_size, ) @@ -596,6 +594,12 @@ def save_bad_model(suffix: str = ""): valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + tb_writer.add_audio( + "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate + ) + tb_writer.add_audio( + "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate + ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value @@ -604,9 +608,87 @@ def save_bad_model(suffix: str = ""): params.best_train_loss = params.train_loss +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over 
iterations + tot_loss = MetricsTracker() + return_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["text"]) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) + + loss_info = MetricsTracker() + loss_info['samples'] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()]) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) + audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() + return_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, return_sample + + def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, optimizer_g: torch.optim.Optimizer, optimizer_d: torch.optim.Optimizer, params: AttributeDict, @@ -620,14 +702,8 @@ def scan_pessimistic_batches_for_oom( batches, crit_values = find_pessimistic_batches(train_dl.sampler) for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] - audio = batch["audio"].to(device) - features = batch["features"].to(device) - audio_lens = batch["audio_lens"].to(device) - features_lens = batch["features_lens"].to(device) - text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) - tokens = tokens.to(device) - tokens_lens = tokens_lens.to(device) + audio, audio_lens, features, features_lens, tokens, tokens_lens = \ + prepare_input(batch, tokenizer, device) try: # for discriminator with autocast(enabled=params.use_fp16): @@ -702,6 +778,11 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + logging.info(params) logging.info("About to create model") @@ -728,14 +809,14 @@ def run(rank, world_size, args): lr=params.lr, betas=(0.8, 0.99), eps=1e-9, - weight_decay=0, + # weight_decay=0, ) optimizer_d = torch.optim.AdamW( discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9, - weight_decay=0, + # weight_decay=0, ) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) @@ -804,6 +885,7 @@ def remove_short_and_long_utt(c: Cut): 
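The token batching used by `prepare_input` and `compute_validation_loss` above relies on k2's ragged tensors to obtain per-utterance lengths and a padded (B, T) matrix in one pass. A small self-contained illustration of that pattern (assumes k2 and torch are installed; the token ids are made up):

import k2

token_ids = [[5, 2, 9], [7, 1], [3, 3, 3, 3]]            # one id list per utterance
tokens = k2.RaggedTensor(token_ids)
row_splits = tokens.shape.row_splits(1)                  # tensor([0, 3, 5, 9])
tokens_lens = row_splits[1:] - row_splits[:-1]           # tensor([3, 2, 4])
padded = tokens.pad(mode="constant", padding_value=0)    # shape (3, 4), blank-padded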
scan_pessimistic_batches_for_oom( model=model, train_dl=train_dl, + tokenizer=tokenizer, optimizer_g=optimizer_g, optimizer_d=optimizer_d, params=params, @@ -815,6 +897,8 @@ def remove_short_and_long_utt(c: Cut): scaler.load_state_dict(checkpoints["grad_scaler"]) for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) @@ -826,6 +910,7 @@ def remove_short_and_long_utt(c: Cut): train_one_epoch( params=params, model=model, + tokenizer=tokenizer, optimizer_g=optimizer_g, optimizer_d=optimizer_d, scheduler_g=scheduler_g, diff --git a/egs/ljspeech/tts/vits/tts_datamodule.py b/egs/ljspeech/tts/vits/tts_datamodule.py index bd67aa6b13..40e9c19ddf 100644 --- a/egs/ljspeech/tts/vits/tts_datamodule.py +++ b/egs/ljspeech/tts/vits/tts_datamodule.py @@ -131,7 +131,14 @@ def add_arguments(cls, parser: argparse.ArgumentParser): default=True, help="Whether to drop last batch. Used by sampler.", ) - + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) group.add_argument( "--num-workers", type=int, @@ -163,6 +170,7 @@ def train_dataloaders( train = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: @@ -176,6 +184,7 @@ def train_dataloaders( train = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) if self.args.bucketing_sampler: @@ -229,11 +238,13 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: validate = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) else: validate = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) valid_sampler = DynamicBucketingSampler( cuts_valid, @@ -264,11 +275,13 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: test = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, ) else: test = SpeechSynthesisDataset( return_tokens=False, feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, ) test_sampler = DynamicBucketingSampler( cuts, diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/tts/vits/utils.py index 0020975816..582856eee0 100644 --- a/egs/ljspeech/tts/vits/utils.py +++ b/egs/ljspeech/tts/vits/utils.py @@ -211,6 +211,7 @@ def intersperse(sequence, item=0): def prepare_token_batch( texts: List[str], + phonemes: Optional[List[str]] = None, intersperse_blank: bool = True, blank_id: int = 0, pad_id: int = 0, @@ -222,41 +223,50 @@ def prepare_token_batch( blank_id: index of blank token pad_id: padding index """ - # normalize text - normalized_texts = [] - for text in texts: - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - normalized_texts.append(text) - - # convert to phonemes - phonemes = phonemize( - normalized_texts, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) + if phonemes is None: + # normalize text + normalized_texts = [] + for text in 
texts: + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_abbreviations(text) + normalized_texts.append(text) + + # convert to phonemes + phonemes = phonemize( + normalized_texts, + language='en-us', + backend='espeak', + strip=True, + preserve_punctuation=True, + with_stress=True, + ) + phonemes = [collapse_whitespace(sequence) for sequence in phonemes] # convert to symbol ids lengths = [] sequences = [] + skip = False for idx, sequence in enumerate(phonemes): try: - sequence = [symbol_to_id[symbol] for symbol in collapse_whitespace(sequence)] - except RuntimeError: - print(text[idx]) - print(normalized_texts[idx]) + sequence = [symbol_to_id[symbol] for symbol in sequence] + except Exception: + # print(texts[idx]) + # print(normalized_texts[idx]) + print(phonemes[idx]) + skip = True if intersperse_blank: sequence = intersperse(sequence, blank_id) - sequences.append(torch.tensor(sequence, dtype=torch.int64)) + try: + sequences.append(torch.tensor(sequence, dtype=torch.int64)) + except Exception: + print(sequence) + skip = True lengths.append(len(sequence)) sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id) lengths = torch.tensor(lengths, dtype=torch.int64) - return sequences, lengths + return sequences, lengths, skip class MetricsTracker(collections.defaultdict): @@ -287,7 +297,7 @@ def __str__(self) -> str: norm_value = "%.4g" % v ans += str(k) + "=" + str(norm_value) + ", " samples = "%.2f" % self["samples"] - ans += "over" + str(samples) + " samples." + ans += "over " + str(samples) + " samples." return ans def norm_items(self) -> List[Tuple[str, float]]: @@ -468,3 +478,41 @@ def save_checkpoint_with_global_batch_idx( sampler=sampler, rank=rank, ) + + +# def plot_feature(feature): +# """ +# Display the feature matrix as an image. Requires matplotlib to be installed. 
+# """ +# import matplotlib.pyplot as plt +# +# feature = np.flip(feature.transpose(1, 0), 0) +# return plt.matshow(feature) + +MATPLOTLIB_FLAG = False + + +def plot_feature(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index da9d144f22..441e915df6 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -241,6 +241,7 @@ def forward( feats_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, + return_sample: bool = False, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, @@ -276,6 +277,7 @@ def forward( feats_lengths=feats_lengths, speech=speech, speech_lengths=speech_lengths, + return_sample=return_sample, sids=sids, spembs=spembs, lids=lids, @@ -301,6 +303,7 @@ def _forward_generator( feats_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, + return_sample: bool = False, sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, @@ -367,7 +370,12 @@ def _forward_generator( # calculate losses with autocast(enabled=False): - mel_loss = self.mel_loss(speech_hat_, speech_) + if not return_sample: + mel_loss = self.mel_loss(speech_hat_, speech_) + else: + mel_loss, (mel_hat_, mel_) = self.mel_loss( + speech_hat_, speech_, return_mel=True + ) kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) dur_loss = torch.sum(dur_nll.float()) adv_loss = self.generator_adv_loss(p_hat) @@ -389,6 +397,14 @@ def _forward_generator( generator_feat_match_loss=feat_match_loss.item(), ) + if return_sample: + stats["return_sample"] = ( + speech_hat_[0].data.cpu().numpy(), + speech_[0].data.cpu().numpy(), + mel_hat_[0].data.cpu().numpy(), + mel_[0].data.cpu().numpy(), + ) + # reset cache if reuse_cache or not self.training: self._cache = None @@ -564,4 +580,43 @@ def inference( alpha=alpha, max_len=max_len, ) - return dict(wav=wav.view(-1), att_w=att_w[0], duration=dur[0]) + return wav.view(-1), att_w[0], dur[0] + + def inference_batch( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + durations: Optional[torch.Tensor] = None, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + max_len: Optional[int] = None, + use_teacher_forcing: bool = False, + ) -> Dict[str, torch.Tensor]: + """Run inference. + + Args: + text (Tensor): Input text index tensor (B, T_text). + text_lengths (Tensor): Input text index tensor (B,). + noise_scale (float): Noise scale value for flow. + noise_scale_dur (float): Noise scale value for duration predictor. + alpha (float): Alpha parameter to control the speed of generated speech. + max_len (Optional[int]): Maximum length. + + Returns: + Dict[str, Tensor]: + * wav (Tensor): Generated waveform tensor (B, T_wav). 
+ * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). + + """ + # inference + wav, att_w, dur = self.generator.inference( + text=text, + text_lengths=text_lengths, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + max_len=max_len, + ) + return wav, att_w, dur From 8d09f8e6bfa642eae3635104d1884014ad165d23 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 5 Nov 2023 18:25:47 +0800 Subject: [PATCH 03/16] use Conformer as text encoder --- egs/ljspeech/tts/vits/generator.py | 3 + egs/ljspeech/tts/vits/infer.py | 27 +++++- egs/ljspeech/tts/vits/text_encoder.py | 116 +++++++++++++++++++++++++- egs/ljspeech/tts/vits/vits.py | 1 + 4 files changed, 143 insertions(+), 4 deletions(-) diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index a74440c958..fc0d45cfd6 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -44,6 +44,7 @@ def __init__( segment_size: int = 32, text_encoder_attention_heads: int = 2, text_encoder_ffn_expand: int = 4, + text_encoder_cnn_module_kernel: int = 5, text_encoder_blocks: int = 6, text_encoder_dropout_rate: float = 0.1, decoder_kernel_size: int = 7, @@ -89,6 +90,7 @@ def __init__( of text encoder. text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block of text encoder. + text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder. text_encoder_blocks (int): Number of conformer blocks in text encoder. text_encoder_dropout_rate (float): Dropout rate in conformer block of text encoder. @@ -135,6 +137,7 @@ def __init__( d_model=hidden_channels, num_heads=text_encoder_attention_heads, dim_feedforward=hidden_channels * text_encoder_ffn_expand, + cnn_module_kernel=text_encoder_cnn_module_kernel, num_layers=text_encoder_blocks, dropout=text_encoder_dropout_rate, ) diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py index 89fc729626..623cc3ec9a 100755 --- a/egs/ljspeech/tts/vits/infer.py +++ b/egs/ljspeech/tts/vits/infer.py @@ -103,11 +103,13 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple +import k2 import torch import torch.nn as nn import torchaudio -from train2 import get_model, get_params +from train import get_model, get_params, prepare_input +from tokenizer import Tokenizer from icefall.checkpoint import ( average_checkpoints, @@ -124,7 +126,6 @@ write_error_stats, ) from tts_datamodule import LJSpeechTtsDataModule -from utils import prepare_token_batch LOG_EPS = math.log(1e-10) @@ -169,6 +170,13 @@ def get_parser(): help="The experiment dir", ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to tokens.txt.""", + ) + return parser @@ -176,6 +184,7 @@ def infer_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + tokenizer: Tokenizer, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -236,10 +245,16 @@ def _save_worker( # We only want one background worker so that serialization is deterministic. 
for batch_idx, batch in enumerate(dl): batch_size = len(batch["text"]) + text = batch["text"] - tokens, tokens_lens = prepare_token_batch(text) + tokens = tokenizer.texts_to_token_ids(text) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) audio = batch["audio"] audio_lens = batch["audio_lens"].tolist() @@ -296,6 +311,11 @@ def main(): if torch.cuda.is_available(): device = torch.device("cuda", 0) + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + logging.info(f"Device: {device}") logging.info(params) @@ -348,6 +368,7 @@ def main(): dl=test_dl, params=params, model=model, + tokenizer=tokenizer, ) # save_results( diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/tts/vits/text_encoder.py index fbf9b16a30..9ba8e1768c 100644 --- a/egs/ljspeech/tts/vits/text_encoder.py +++ b/egs/ljspeech/tts/vits/text_encoder.py @@ -45,6 +45,7 @@ def __init__( d_model: int = 192, num_heads: int = 2, dim_feedforward: int = 768, + cnn_module_kernel: int = 5, num_layers: int = 6, dropout: float = 0.1, ): @@ -55,6 +56,7 @@ def __init__( d_model (int): attention dimension num_heads (int): number of attention heads dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size num_layers (int): number of encoder layers dropout (float): dropout rate """ @@ -69,6 +71,7 @@ def __init__( d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, num_layers=num_layers, dropout=dropout, ) @@ -119,6 +122,7 @@ class Transformer(nn.Module): d_model (int): attention dimension num_heads (int): number of attention heads dim_feedforward (int): feedforward dimention + cnn_module_kernel (int): convolution kernel size num_layers (int): number of encoder layers dropout (float): dropout rate """ @@ -128,6 +132,7 @@ def __init__( d_model: int = 192, num_heads: int = 2, dim_feedforward: int = 768, + cnn_module_kernel: int = 5, num_layers: int = 6, dropout: float = 0.1, ) -> None: @@ -142,6 +147,7 @@ def __init__( d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward, + cnn_module_kernel=cnn_module_kernel, dropout=dropout, ) self.encoder = TransformerEncoder(encoder_layer, num_layers) @@ -187,12 +193,22 @@ def __init__( d_model: int, num_heads: int, dim_feedforward: int, + cnn_module_kernel: int, dropout: float = 0.1, ) -> None: super(TransformerEncoderLayer, self).__init__() + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout) + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), Swish(), @@ -200,10 +216,13 @@ def __init__( nn.Linear(dim_feedforward, d_model), ) - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + 
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.ff_scale = 0.5 self.dropout = nn.Dropout(dropout) def forward( @@ -220,6 +239,9 @@ def forward( pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim). key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len) """ + # macaron style feed-forward module + src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src))) + # multi-head self-attention module src_attn = self.self_attn( self.norm_mha(src), @@ -228,6 +250,9 @@ def forward( ) src = src + self.dropout(src_attn) + # convolution module + src = src + self.dropout(self.conv_module(self.norm_conv(src))) + # feed-forward module src = src + self.dropout(self.feed_forward(self.norm_ff(src))) @@ -508,6 +533,95 @@ def forward( return attn_output +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.LayerNorm(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward( + self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
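+
+        # The first pointwise conv doubles the channel dimension so that GLU can
+        # split it into a linear half and a gate half: glu([a; b]) = a * sigmoid(b).
+        # The depthwise conv that follows mixes information along time only
+        # (groups=channels), and the second pointwise conv mixes channels again.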
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if src_key_padding_mask is not None: + x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = self.depthwise_conv(x) + # x is (batch, channels, time) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + class Swish(nn.Module): """Construct an Swish object.""" diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index 441e915df6..27d9b4c7a1 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -61,6 +61,7 @@ def __init__( "segment_size": 32, "text_encoder_attention_heads": 2, "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, "text_encoder_blocks": 6, "text_encoder_dropout_rate": 0.1, "decoder_kernel_size": 7, From 04c6ecbaa1fcf41d78398707c438556a8f766c9c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 5 Nov 2023 22:47:04 +0800 Subject: [PATCH 04/16] modify training script, clean codes --- egs/ljspeech/tts/local/split_subsets.py | 80 --- egs/ljspeech/tts/prepare.sh | 20 +- egs/ljspeech/tts/vits/commons.py | 161 ------ egs/ljspeech/tts/vits/duration_predictor.py | 2 +- egs/ljspeech/tts/vits/features.py | 416 --------------- egs/ljspeech/tts/vits/flow.py | 3 +- egs/ljspeech/tts/vits/generator.py | 2 +- egs/ljspeech/tts/vits/hifigan.py | 2 +- egs/ljspeech/tts/vits/infer.py | 207 +------- egs/ljspeech/tts/vits/loss.py | 6 +- egs/ljspeech/tts/vits/models.py | 534 -------------------- egs/ljspeech/tts/vits/posterior_encoder.py | 2 +- egs/ljspeech/tts/vits/residual_coupling.py | 2 +- egs/ljspeech/tts/vits/symbols.py | 17 - egs/ljspeech/tts/vits/text_encoder.py | 2 + egs/ljspeech/tts/vits/tokenizer.py | 1 - egs/ljspeech/tts/vits/train.py | 232 +++------ egs/ljspeech/tts/vits/transform.py | 3 +- egs/ljspeech/tts/vits/utils.py | 347 ++----------- egs/ljspeech/tts/vits/vits.py | 60 +-- egs/ljspeech/tts/vits/wavenet.py | 2 +- 21 files changed, 183 insertions(+), 1918 deletions(-) delete mode 100755 egs/ljspeech/tts/local/split_subsets.py delete mode 100644 egs/ljspeech/tts/vits/commons.py delete mode 100644 egs/ljspeech/tts/vits/features.py delete mode 100644 egs/ljspeech/tts/vits/models.py delete mode 100644 egs/ljspeech/tts/vits/symbols.py diff --git a/egs/ljspeech/tts/local/split_subsets.py b/egs/ljspeech/tts/local/split_subsets.py deleted file mode 100755 index b2afca9712..0000000000 --- a/egs/ljspeech/tts/local/split_subsets.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
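Putting the pieces of `ConvolutionModule` from text_encoder.py above together, the tensor layout round-trips through (time, batch, channels) -> (batch, channels, time) -> (time, batch, channels). A minimal shape check of that data flow built from throwaway modules (assumes only PyTorch; it is a sketch of the layout, not the module itself, and the sizes are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

T, B, C, K = 50, 2, 192, 5                                   # frames, batch, channels, kernel size
x = torch.randn(T, B, C)                                     # (#time, batch, channels)
y = x.permute(1, 2, 0)                                       # (batch, channels, time)
y = F.glu(nn.Conv1d(C, 2 * C, 1)(y), dim=1)                  # pointwise conv + GLU -> (B, C, T)
y = nn.Conv1d(C, C, K, padding=(K - 1) // 2, groups=C)(y)    # depthwise conv, length preserved
y = nn.LayerNorm(C)(y.permute(0, 2, 1)).permute(0, 2, 1)     # LayerNorm over the channel dim
y = y * torch.sigmoid(y)                                     # Swish
y = nn.Conv1d(C, C, 1)(y)                                    # second pointwise conv
assert y.permute(2, 0, 1).shape == (T, B, C)                 # back to (#time, batch, channels)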
-""" -This script split the LJSpeech dataset cuts into three sets: - - training, 12500 - - validation, 100 - - test, 500 -The numbers are from https://arxiv.org/pdf/2106.06103.pdf - -Usage example: - python3 ./local/split_subsets.py ./data/spectrogram -""" - -import argparse -import logging -import random -from pathlib import Path - -from lhotse import load_manifest_lazy - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest_dir", - type=Path, - default=Path("data/spectrogram"), - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def main(): - args = get_args() - - manifest_dir = Path(args.manifest_dir) - prefix = "ljspeech" - suffix = "jsonl.gz" - # all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all.{suffix}") - all_cuts = load_manifest_lazy(manifest_dir / f"{prefix}_cuts_all_phonemized.{suffix}") - - cut_ids = list(all_cuts.ids) - random.shuffle(cut_ids) - - train_cuts = all_cuts.subset(cut_ids=cut_ids[:12500]) - valid_cuts = all_cuts.subset(cut_ids=cut_ids[12500:12500 + 100]) - test_cuts = all_cuts.subset(cut_ids=cut_ids[12500 + 100:]) - assert len(train_cuts) == 12500, "expected 12500 cuts for training but got len(train_cuts)" - assert len(valid_cuts) == 100, "expected 100 cuts but for validation but got len(valid_cuts)" - assert len(test_cuts) == 500, "expected 500 cuts for test but got len(test_cuts)" - - train_cuts.to_file(manifest_dir / f"{prefix}_cuts_train.{suffix}") - valid_cuts.to_file(manifest_dir / f"{prefix}_cuts_valid.{suffix}") - test_cuts.to_file(manifest_dir / f"{prefix}_cuts_test.{suffix}") - - logging.info("Splitted into three sets: training (12500), validation (100), and test (500)") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/tts/prepare.sh index 4f4685951a..613eb37d8d 100755 --- a/egs/ljspeech/tts/prepare.sh +++ b/egs/ljspeech/tts/prepare.sh @@ -9,8 +9,7 @@ nj=1 stage=-1 stop_stage=100 -# dl_dir=$PWD/download -dl_dir=/star-data/zengwei/download/ljspeech/ +dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -66,22 +65,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi fi -# if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then -# log "Stage 3: Phonemize the transcripts for LJSpeech" -# if [ ! -e data/spectrogram/.ljspeech_phonemized.done ]; then -# ./local/phonemize_text.py data/spectrogram -# touch data/spectrogram/.ljspeech_phonemized.done -# fi -# fi - -# if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then -# log "Stage 4: Split the LJSpeech cuts into three sets" -# if [ ! -e data/spectrogram/.ljspeech_split.done ]; then -# ./local/split_subsets.py data/spectrogram -# touch data/spectrogram/.ljspeech_split.done -# fi -# fi - if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Split the LJSpeech cuts into train, valid and test sets" if [ ! 
-e data/spectrogram/.ljspeech_split.done ]; then @@ -94,6 +77,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then lhotse subset --last 500 \ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \ data/spectrogram/ljspeech_cuts_test.jsonl.gz + rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) diff --git a/egs/ljspeech/tts/vits/commons.py b/egs/ljspeech/tts/vits/commons.py deleted file mode 100644 index 9ad0444b61..0000000000 --- a/egs/ljspeech/tts/vits/commons.py +++ /dev/null @@ -1,161 +0,0 @@ -import math -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size*dilation - dilation)/2) - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d( - length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - (num_timescales - 1)) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - 
-@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2,3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1. / norm_type) - return total_norm diff --git a/egs/ljspeech/tts/vits/duration_predictor.py b/egs/ljspeech/tts/vits/duration_predictor.py index 5e8d670bdc..c29a28479a 100644 --- a/egs/ljspeech/tts/vits/duration_predictor.py +++ b/egs/ljspeech/tts/vits/duration_predictor.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/features.py b/egs/ljspeech/tts/vits/features.py deleted file mode 100644 index b43c7cf46d..0000000000 --- a/egs/ljspeech/tts/vits/features.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from typing import Any, Dict, Optional, Tuple - -import librosa -import numpy as np -import torch -from torch import nn - -from icefall.utils import make_pad_mask - - -# From https://github.com/espnet/espnet/blob/master/espnet2/layers/stft.py -class Stft(nn.Module): - def __init__( - self, - n_fft: int = 512, - win_length: int = None, - hop_length: int = 128, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - ): - super().__init__() - self.n_fft = n_fft - if win_length is None: - self.win_length = n_fft - else: - self.win_length = win_length - self.hop_length = hop_length - self.center = center - self.normalized = normalized - self.onesided = onesided - if window is not None and not hasattr(torch, f"{window}_window"): - raise ValueError(f"{window} window is not implemented") - self.window = window - - def extra_repr(self): - return ( - f"n_fft={self.n_fft}, " - f"win_length={self.win_length}, " - f"hop_length={self.hop_length}, " - f"center={self.center}, " - f"normalized={self.normalized}, " - f"onesided={self.onesided}" - ) - - def forward( - self, input: torch.Tensor, ilens: torch.Tensor = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """STFT forward function. - - Args: - input: (Batch, Nsamples) or (Batch, Nsample, Channels) - ilens: (Batch) - Returns: - output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) - - """ - bs = input.size(0) - if input.dim() == 3: - multi_channel = True - # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) - input = input.transpose(1, 2).reshape(-1, input.size(1)) - else: - multi_channel = False - - # NOTE(kamo): - # The default behaviour of torch.stft is compatible with librosa.stft - # about padding and scaling. - # Note that it's different from scipy.signal.stft - - # output: (Batch, Freq, Frames, 2=real_imag) - # or (Batch, Channel, Freq, Frames, 2=real_imag) - if self.window is not None: - window_func = getattr(torch, f"{self.window}_window") - window = window_func( - self.win_length, dtype=input.dtype, device=input.device - ) - else: - window = None - - # For the compatibility of ARM devices, which do not support - # torch.stft() due to the lack of MKL (on older pytorch versions), - # there is an alternative replacement implementation with librosa. - # Note: pytorch >= 1.10.0 now has native support for FFT and STFT - # on all cpu targets including ARM. - if input.is_cuda or torch.backends.mkl.is_available(): - stft_kwargs = dict( - n_fft=self.n_fft, - win_length=self.win_length, - hop_length=self.hop_length, - center=self.center, - window=window, - normalized=self.normalized, - onesided=self.onesided, - ) - stft_kwargs["return_complex"] = True - output = torch.stft(input, **stft_kwargs) - output = torch.view_as_real(output) - else: - if self.training: - raise NotImplementedError( - "stft is implemented with librosa on this device, which does not " - "support the training mode." - ) - - # use stft_kwargs to flexibly control different PyTorch versions' kwargs - # note: librosa does not support a win_length that is < n_ftt - # but the window can be manually padded (see below). 
- stft_kwargs = dict( - n_fft=self.n_fft, - win_length=self.n_fft, - hop_length=self.hop_length, - center=self.center, - window=window, - pad_mode="reflect", - ) - - if window is not None: - # pad the given window to n_fft - n_pad_left = (self.n_fft - window.shape[0]) // 2 - n_pad_right = self.n_fft - window.shape[0] - n_pad_left - stft_kwargs["window"] = torch.cat( - [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0 - ).numpy() - else: - win_length = ( - self.win_length if self.win_length is not None else self.n_fft - ) - stft_kwargs["window"] = torch.ones(win_length) - - output = [] - # iterate over istances in a batch - for i, instance in enumerate(input): - stft = librosa.stft(input[i].numpy(), **stft_kwargs) - output.append(torch.tensor(np.stack([stft.real, stft.imag], -1))) - output = torch.stack(output, 0) - if not self.onesided: - len_conj = self.n_fft - output.shape[1] - conj = output[:, 1 : 1 + len_conj].flip(1) - conj[:, :, :, -1].data *= -1 - output = torch.cat([output, conj], 1) - if self.normalized: - output = output * (stft_kwargs["window"].shape[0] ** (-0.5)) - - # output: (Batch, Freq, Frames, 2=real_imag) - # -> (Batch, Frames, Freq, 2=real_imag) - output = output.transpose(1, 2) - if multi_channel: - # output: (Batch * Channel, Frames, Freq, 2=real_imag) - # -> (Batch, Frame, Channel, Freq, 2=real_imag) - output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( - 1, 2 - ) - - if ilens is not None: - if self.center: - pad = self.n_fft // 2 - ilens = ilens + 2 * pad - - olens = ( - torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc") - + 1 - ) - output.masked_fill_(make_pad_mask(olens), 0.0) - else: - olens = None - - return output, olens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/linear_spectrogram.py -class LinearSpectrogram(nn.Module): - """Linear amplitude spectrogram. - - Stft -> amplitude-spec - """ - - def __init__( - self, - n_fft: int = 1024, - win_length: int = None, - hop_length: int = 256, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - ): - super().__init__() - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.stft = Stft( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window=window, - center=center, - normalized=normalized, - onesided=onesided, - ) - self.n_fft = n_fft - - def output_size(self) -> int: - return self.n_fft // 2 + 1 - - def get_parameters(self) -> Dict[str, Any]: - """Return the parameters required by Vocoder.""" - return dict( - n_fft=self.n_fft, - n_shift=self.hop_length, - win_length=self.win_length, - window=self.window, - ) - - def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. 
Stft: time -> time-freq - input_stft, feats_lens = self.stft(input, input_lengths) - - assert input_stft.dim() >= 4, input_stft.shape - # "2" refers to the real/imag parts of Complex - assert input_stft.shape[-1] == 2, input_stft.shape - - # STFT -> Power spectrum -> Amp spectrum - # input_stft: (..., F, 2) -> (..., F) - input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 - input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) - return input_amp, feats_lens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/layers/log_mel.py -class LogMel(nn.Module): - """Convert STFT to fbank feats - - The arguments is same as librosa.filters.mel - - Args: - fs: number > 0 [scalar] sampling rate of the incoming signal - n_fft: int > 0 [scalar] number of FFT components - n_mels: int > 0 [scalar] number of Mel bands to generate - fmin: float >= 0 [scalar] lowest frequency (in Hz) - fmax: float >= 0 [scalar] highest frequency (in Hz). - If `None`, use `fmax = fs / 2.0` - htk: use HTK formula instead of Slaney - """ - - def __init__( - self, - fs: int = 16000, - n_fft: int = 512, - n_mels: int = 80, - fmin: float = None, - fmax: float = None, - htk: bool = False, - log_base: float = None, - ): - super().__init__() - - fmin = 0 if fmin is None else fmin - fmax = fs / 2 if fmax is None else fmax - _mel_options = dict( - sr=fs, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - fmax=fmax, - htk=htk, - ) - self.mel_options = _mel_options - self.log_base = log_base - - # Note(kamo): The mel matrix of librosa is different from kaldi. - melmat = librosa.filters.mel(**_mel_options) - # melmat: (D2, D1) -> (D1, D2) - self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) - - def extra_repr(self): - return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) - - def forward( - self, - feat: torch.Tensor, - ilens: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) - mel_feat = torch.matmul(feat, self.melmat) - mel_feat = torch.clamp(mel_feat, min=1e-10) - - if self.log_base is None: - logmel_feat = mel_feat.log() - elif self.log_base == 2.0: - logmel_feat = mel_feat.log2() - elif self.log_base == 10.0: - logmel_feat = mel_feat.log10() - else: - logmel_feat = mel_feat.log() / torch.log(self.log_base) - - # Zero padding - if ilens is not None: - logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens), 0.0) - else: - ilens = feat.new_full( - [feat.size(0)], fill_value=feat.size(1), dtype=torch.long - ) - return logmel_feat, ilens - - -# From https://github.com/espnet/espnet/blob/master/espnet2/tts/feats_extract/log_mel_fbank.py -class LogMelFbank(nn.Module): - """Conventional frontend structure for TTS. 
- - Stft -> amplitude-spec -> Log-Mel-Fbank - """ - - def __init__( - self, - fs: int = 16000, - n_fft: int = 1024, - win_length: int = None, - hop_length: int = 256, - window: Optional[str] = "hann", - center: bool = True, - normalized: bool = False, - onesided: bool = True, - n_mels: int = 80, - fmin: Optional[int] = 80, - fmax: Optional[int] = 7600, - htk: bool = False, - log_base: Optional[float] = 10.0, - ): - super().__init__() - - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.fmin = fmin - self.fmax = fmax - - self.stft = Stft( - n_fft=n_fft, - win_length=win_length, - hop_length=hop_length, - window=window, - center=center, - normalized=normalized, - onesided=onesided, - ) - - self.logmel = LogMel( - fs=fs, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - fmax=fmax, - htk=htk, - log_base=log_base, - ) - - def output_size(self) -> int: - return self.n_mels - - def get_parameters(self) -> Dict[str, Any]: - """Return the parameters required by Vocoder""" - return dict( - fs=self.fs, - n_fft=self.n_fft, - n_shift=self.hop_length, - window=self.window, - n_mels=self.n_mels, - win_length=self.win_length, - fmin=self.fmin, - fmax=self.fmax, - ) - - def forward( - self, input: torch.Tensor, input_lengths: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. Domain-conversion: e.g. Stft: time -> time-freq - input_stft, feats_lens = self.stft(input, input_lengths) - - assert input_stft.dim() >= 4, input_stft.shape - # "2" refers to the real/imag parts of Complex - assert input_stft.shape[-1] == 2, input_stft.shape - - # NOTE(kamo): We use different definition for log-spec between TTS and ASR - # TTS: log_10(abs(stft)) - # ASR: log_e(power(stft)) - - # input_stft: (..., F, 2) -> (..., F) - input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 - input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) - input_feats, _ = self.logmel(input_amp, feats_lens) - return input_feats, feats_lens diff --git a/egs/ljspeech/tts/vits/flow.py b/egs/ljspeech/tts/vits/flow.py index 04fb99b427..206bd5e3e5 100644 --- a/egs/ljspeech/tts/vits/flow.py +++ b/egs/ljspeech/tts/vits/flow.py @@ -1,4 +1,5 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py + # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/tts/vits/generator.py index fc0d45cfd6..664d8064f5 100644 --- a/egs/ljspeech/tts/vits/generator.py +++ b/egs/ljspeech/tts/vits/generator.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/hifigan.py b/egs/ljspeech/tts/vits/hifigan.py index a87cb2fce7..589ac30f60 100644 --- a/egs/ljspeech/tts/vits/hifigan.py +++ b/egs/ljspeech/tts/vits/hifigan.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/tts/vits/infer.py index 
623cc3ec9a..f971f85ffd 100755 --- a/egs/ljspeech/tts/vits/infer.py +++ b/egs/ljspeech/tts/vits/infer.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,118 +16,34 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method greedy_search +This script performs model inference on test set. -(2) beam search (not recommended) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(5) fast beam search (nbest) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./zipformer/decode.py \ - --epoch 28 \ - --avg 15 \ +Usage: +./vits/infer.py \ + --epoch 1000 \ --exp-dir ./zipformer/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 + --max-duration 500 """ import argparse import logging -import math -import os -from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import List import k2 import torch import torch.nn as nn import torchaudio -from train import get_model, get_params, prepare_input +from train import get_model, get_params from tokenizer import Tokenizer -from icefall.checkpoint import ( - average_checkpoints, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - make_pad_mask, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger from tts_datamodule import LJSpeechTtsDataModule -LOG_EPS = math.log(1e-10) - def get_parser(): parser = argparse.ArgumentParser( @@ -138,35 +53,16 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=30, + default=1000, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. 
- You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. """, ) - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - parser.add_argument( "--exp-dir", type=str, - default="zipformer/exp", + default="vits/exp", help="The experiment dir", ) @@ -174,7 +70,7 @@ def get_parser(): "--tokens", type=str, default="data/tokens.txt", - help="""Path to tokens.txt.""", + help="""Path to vocabulary.""", ) return parser @@ -185,8 +81,9 @@ def infer_dataset( params: AttributeDict, model: nn.Module, tokenizer: Tokenizer, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: +) -> None: """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. Args: dl: @@ -195,20 +92,8 @@ def infer_dataset( It is returned by :func:`get_params`. model: The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding-method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. + tokenizer: + Used to convert text to phonemes. """ # Background worker save audios to disk. def _save_worker( @@ -233,7 +118,7 @@ def _save_worker( device = next(model.parameters()).device num_cuts = 0 - log_interval = 10 + log_interval = 5 try: num_batches = len(dl) @@ -242,7 +127,6 @@ def _save_worker( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: - # We only want one background worker so that serialization is deterministic. 
for batch_idx, batch in enumerate(dl): batch_size = len(batch["text"]) @@ -253,7 +137,7 @@ def _save_worker( tokens_lens = row_splits[1:] - row_splits[:-1] tokens = tokens.to(device) tokens_lens = tokens_lens.to(device) - # a tensor of shape (B, T) + # tensor of shape (B, T) tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id) audio = batch["audio"] @@ -265,9 +149,6 @@ def _save_worker( # convert to samples audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() - # import pdb - # pdb.set_trace() - futures.append( executor.submit( _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred @@ -295,10 +176,7 @@ def main(): params = get_params() params.update(vars(args)) - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}" params.res_dir = params.exp_dir / "infer" / params.suffix params.save_wav_dir = params.res_dir / "wav" @@ -322,40 +200,16 @@ def main(): logging.info("About to create model") model = get_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) model.to(device) model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") # we need cut ids to display recognition results. 
args.return_cuts = True @@ -371,17 +225,8 @@ def main(): tokenizer=tokenizer, ) - # save_results( - # params=params, - # test_set_name=test_set, - # results_dict=results_dict, - # ) - logging.info("Done!") -# torch.set_num_threads(1) -# torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/tts/vits/loss.py index 0d27af6435..21aaad6e75 100644 --- a/egs/ljspeech/tts/vits/loss.py +++ b/egs/ljspeech/tts/vits/loss.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) @@ -9,7 +9,7 @@ """ -from typing import List, Optional, Tuple, Union +from typing import List, Tuple, Union import torch import torch.distributions as D @@ -266,7 +266,7 @@ def forward( return mel_loss -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py """VITS-related loss modules. diff --git a/egs/ljspeech/tts/vits/models.py b/egs/ljspeech/tts/vits/models.py deleted file mode 100644 index f5acdeb2be..0000000000 --- a/egs/ljspeech/tts/vits/models.py +++ /dev/null @@ -1,534 +0,0 @@ -import copy -import math -import torch -from torch import nn -from torch.nn import functional as F - -import commons -import modules -import attentions -import monotonic_align - -from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d -from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from commons import init_weights, get_padding - - -class StochasticDurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): - super().__init__() - filter_channels = in_channels # it needs to be removed from future version. 
- self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.log_flow = modules.Log() - self.flows = nn.ModuleList() - self.flows.append(modules.ElementwiseAffine(2)) - for i in range(n_flows): - self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.flows.append(modules.Flip()) - - self.post_pre = nn.Conv1d(1, filter_channels, 1) - self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - self.post_flows = nn.ModuleList() - self.post_flows.append(modules.ElementwiseAffine(2)) - for i in range(4): - self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) - self.post_flows.append(modules.Flip()) - - self.pre = nn.Conv1d(in_channels, filter_channels, 1) - self.proj = nn.Conv1d(filter_channels, filter_channels, 1) - self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, filter_channels, 1) - - def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): - x = torch.detach(x) - x = self.pre(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.convs(x, x_mask) - x = self.proj(x) * x_mask - - if not reverse: - flows = self.flows - assert w is not None - - logdet_tot_q = 0 - h_w = self.post_pre(w) - h_w = self.post_convs(h_w, x_mask) - h_w = self.post_proj(h_w) * x_mask - e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask - z_q = e_q - for flow in self.post_flows: - z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) - logdet_tot_q += logdet_q - z_u, z1 = torch.split(z_q, [1, 1], 1) - u = torch.sigmoid(z_u) * x_mask - z0 = (w - u) * x_mask - logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) - logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q - - logdet_tot = 0 - z0, logdet = self.log_flow(z0, x_mask) - logdet_tot += logdet - z = torch.cat([z0, z1], 1) - for flow in flows: - z, logdet = flow(z, x_mask, g=x, reverse=reverse) - logdet_tot = logdet_tot + logdet - nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot - return nll + logq # [b] - else: - flows = list(reversed(self.flows)) - flows = flows[:-2] + [flows[-1]] # remove a useless vflow - z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale - for flow in flows: - z = flow(z, x_mask, g=x, reverse=reverse) - z0, z1 = torch.split(z, [1, 1], 1) - logw = z0 - return logw - - -class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): - super().__init__() - - self.in_channels = in_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.gin_channels = gin_channels - - self.drop = nn.Dropout(p_dropout) - self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) - self.norm_1 = modules.LayerNorm(filter_channels) - self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) - self.norm_2 = modules.LayerNorm(filter_channels) - self.proj = nn.Conv1d(filter_channels, 1, 1) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, 
in_channels, 1) - - def forward(self, x, x_mask, g=None): - x = torch.detach(x) - if g is not None: - g = torch.detach(g) - x = x + self.cond(g) - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class TextEncoder(nn.Module): - def __init__(self, - n_vocab, - out_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout): - super().__init__() - self.n_vocab = n_vocab - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.emb = nn.Embedding(n_vocab, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) - - self.encoder = attentions.Encoder( - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths): - x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] - x = torch.transpose(x, 1, -1) # [b, h, t] - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - - x = self.encoder(x * x_mask, x_mask) - stats = self.proj(x) * x_mask - - m, logs = torch.split(stats, self.out_channels, dim=1) - return x, m, logs, x_mask - - -class ResidualCouplingBlock(nn.Module): - def __init__(self, - channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - n_flows=4, - gin_channels=0): - super().__init__() - self.channels = channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.n_flows = n_flows - self.gin_channels = gin_channels - - self.flows = nn.ModuleList() - for i in range(n_flows): - self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) - self.flows.append(modules.Flip()) - - def forward(self, x, x_mask, g=None, reverse=False): - if not reverse: - for flow in self.flows: - x, _ = flow(x, x_mask, g=g, reverse=reverse) - else: - for flow in reversed(self.flows): - x = flow(x, x_mask, g=g, reverse=reverse) - return x - - -class PosteriorEncoder(nn.Module): - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dilation_rate, - n_layers, - gin_channels=0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_channels = hidden_channels - self.kernel_size = kernel_size - self.dilation_rate = dilation_rate - self.n_layers = n_layers - self.gin_channels = gin_channels - - self.pre = nn.Conv1d(in_channels, hidden_channels, 1) - self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) - self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - - def forward(self, x, x_lengths, g=None): - x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - x = self.pre(x) * x_mask - x = self.enc(x, x_mask, g=g) - stats = self.proj(x) * x_mask - m, logs = torch.split(stats, self.out_channels, dim=1) - z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask - return z, m, logs, x_mask - - -class Generator(torch.nn.Module): - def __init__(self, initial_channel, 
resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): - super(Generator, self).__init__() - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) - resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append(weight_norm( - ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), - k, u, padding=(k-u)//2))) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel//(2**(i+1)) - for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - - self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) - self.ups.apply(init_weights) - - if gin_channels != 0: - self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) - - def forward(self, x, g=None): - x = self.conv_pre(x) - if g is not None: - x = x + self.cond(g) - - for i in range(self.num_upsamples): - x = F.leaky_relu(x, modules.LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i*self.num_kernels+j](x) - else: - xs += self.resblocks[i*self.num_kernels+j](x) - x = xs / self.num_kernels - x = F.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - - -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ]) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm == False else spectral_norm - self.convs = nn.ModuleList([ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - 
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ]) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class MultiPeriodDiscriminator(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(MultiPeriodDiscriminator, self).__init__() - periods = [2,3,5,7,11] - - discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] - discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] - self.discriminators = nn.ModuleList(discs) - - def forward(self, y, y_hat): - y_d_rs = [] - y_d_gs = [] - fmap_rs = [] - fmap_gs = [] - for i, d in enumerate(self.discriminators): - y_d_r, fmap_r = d(y) - y_d_g, fmap_g = d(y_hat) - y_d_rs.append(y_d_r) - y_d_gs.append(y_d_g) - fmap_rs.append(fmap_r) - fmap_gs.append(fmap_g) - - return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - - -class SynthesizerTrn(nn.Module): - """ - Synthesizer for Training - """ - - def __init__(self, - n_vocab, - spec_channels, - segment_size, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout, - resblock, - resblock_kernel_sizes, - resblock_dilation_sizes, - upsample_rates, - upsample_initial_channel, - upsample_kernel_sizes, - n_speakers=0, - gin_channels=0, - use_sdp=True, - **kwargs): - - super().__init__() - self.n_vocab = n_vocab - self.spec_channels = spec_channels - self.inter_channels = inter_channels - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - self.resblock = resblock - self.resblock_kernel_sizes = resblock_kernel_sizes - self.resblock_dilation_sizes = resblock_dilation_sizes - self.upsample_rates = upsample_rates - self.upsample_initial_channel = upsample_initial_channel - self.upsample_kernel_sizes = upsample_kernel_sizes - self.segment_size = segment_size - self.n_speakers = n_speakers - self.gin_channels = gin_channels - - self.use_sdp = use_sdp - - self.enc_p = TextEncoder(n_vocab, - inter_channels, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size, - p_dropout) - self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) - self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) - self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) - - if use_sdp: - self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) - else: - self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) - - if n_speakers > 1: - self.emb_g = nn.Embedding(n_speakers, gin_channels) - - def forward(self, x, x_lengths, y, y_lengths, sid=None): - - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - if self.n_speakers > 0: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - else: - g = None - - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) - z_p = self.flow(z, y_mask, g=g) - - with torch.no_grad(): - # negative cross-entropy - s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] - neg_cent1 = 
torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] - neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] - neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] - neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 - - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() - - w = attn.sum(2) - if self.use_sdp: - l_length = self.dp(x, x_mask, w, g=g) - l_length = l_length / torch.sum(x_mask) - else: - logw_ = torch.log(w + 1e-6) * x_mask - logw = self.dp(x, x_mask, g=g) - l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging - - # expand prior - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) - - z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) - o = self.dec(z_slice, g=g) - return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) - - def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): - x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) - if self.n_speakers > 0: - g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] - else: - g = None - - if self.use_sdp: - logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) - else: - logw = self.dp(x, x_mask, g=g) - w = torch.exp(logw) * x_mask * length_scale - w_ceil = torch.ceil(w) - y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() - y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) - attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) - attn = commons.generate_path(w_ceil, attn_mask) - - m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - - z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale - z = self.flow(z_p, y_mask, g=g, reverse=True) - o = self.dec((z * y_mask)[:,:,:max_len], g=g) - return o, attn, y_mask, (z, z_p, m_p, logs_p) - - def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): - assert self.n_speakers > 0, "n_speakers have to be larger than 0." 
- g_src = self.emb_g(sid_src).unsqueeze(-1) - g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) - z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) - z_p = self.flow(z, y_mask, g=g_src) - z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) - o_hat = self.dec(z_hat * y_mask, g=g_tgt) - return o_hat, y_mask, (z, z_p, z_hat) - diff --git a/egs/ljspeech/tts/vits/posterior_encoder.py b/egs/ljspeech/tts/vits/posterior_encoder.py index c78fd647fe..6b8a5be52f 100644 --- a/egs/ljspeech/tts/vits/posterior_encoder.py +++ b/egs/ljspeech/tts/vits/posterior_encoder.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/residual_coupling.py b/egs/ljspeech/tts/vits/residual_coupling.py index 48e7483164..2d6807cb7c 100644 --- a/egs/ljspeech/tts/vits/residual_coupling.py +++ b/egs/ljspeech/tts/vits/residual_coupling.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/egs/ljspeech/tts/vits/symbols.py b/egs/ljspeech/tts/vits/symbols.py deleted file mode 100644 index 70c2868f4f..0000000000 --- a/egs/ljspeech/tts/vits/symbols.py +++ /dev/null @@ -1,17 +0,0 @@ -# https://github.com/jaywalnut310/vits/blob/main/text/symbols.py -""" from https://github.com/keithito/tacotron """ - -''' -Defines the set of symbols used in text input to the model. -''' -_pad = '_' -_punctuation = ';:,.!?¡¿—…"«»“” ' -_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' -_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" - - -# Export all symbols: -symbol_table = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) - -# Special symbol ids -SPACE_ID = symbol_table.index(" ") diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/tts/vits/text_encoder.py index 9ba8e1768c..419fd6162e 100644 --- a/egs/ljspeech/tts/vits/text_encoder.py +++ b/egs/ljspeech/tts/vits/text_encoder.py @@ -20,6 +20,7 @@ This code is based on - https://github.com/jaywalnut310/vits - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py + - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py """ import copy @@ -67,6 +68,7 @@ def __init__( self.emb = torch.nn.Embedding(vocabs, d_model) torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5) + # We use conformer as text encoder self.encoder = Transformer( d_model=d_model, num_heads=num_heads, diff --git a/egs/ljspeech/tts/vits/tokenizer.py b/egs/ljspeech/tts/vits/tokenizer.py index 5a513a0d98..8a61511ef5 100644 --- a/egs/ljspeech/tts/vits/tokenizer.py +++ b/egs/ljspeech/tts/vits/tokenizer.py @@ -18,7 +18,6 @@ import g2p_en import tacotron_cleaner.cleaners - from utils import intersperse diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py index 01cd6137ef..c8df3c5d0a 100755 --- a/egs/ljspeech/tts/vits/train.py +++ b/egs/ljspeech/tts/vits/train.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo, -# Zengwei Yao, -# Daniel Povey) +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,9 +18,10 @@ import argparse import logging +import numpy as np from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import k2 import torch @@ -36,26 +33,17 @@ from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from tts_datamodule import LJSpeechTtsDataModule from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import load_checkpoint from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, setup_logger, str2bool from tokenizer import Tokenizer -from utils import ( - MetricsTracker, - plot_feature, - save_checkpoint, - save_checkpoint_with_global_batch_idx, -) +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint from vits import VITS LRSchedulerType = torch.optim.lr_scheduler._LRScheduler @@ -90,7 +78,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=30, + default=1000, help="Number of epochs to train.", ) @@ -104,15 +92,6 @@ def get_parser(): """, ) - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - parser.add_argument( "--exp-dir", type=str, @@ -127,7 +106,7 @@ def get_parser(): "--tokens", type=str, default="data/tokens.txt", - help="""Path to tokens.txt.""", + help="""Path to vocabulary.""", ) parser.add_argument( @@ -158,24 +137,11 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=4000, - help="""Save checkpoint after processing this number of batches" + default=20, + help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt' """, ) @@ -218,8 +184,6 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - valid_interval: Run validation if batch_idx % valid_interval is 0 - feature_dim: The model input dim. 
It has to match the one used @@ -242,18 +206,14 @@ def get_params() -> AttributeDict: "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": -1, # 0 - "log_interval": 10, - "draw_interval": 500, - # "reset_interval": 200, + "log_interval": 50, "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 22050, "frame_shift": 256, "frame_length": 1024, "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length - "mel_loss_params": { - "n_mels": 80, - }, + "n_mels": 80, "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss "lambda_mel": 45.0, # loss scaling coefficient for Mel loss "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss @@ -270,9 +230,7 @@ def load_checkpoint_if_available( ) -> Optional[Dict[str, Any]]: """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from + If params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -287,9 +245,7 @@ def load_checkpoint_if_available( Returns: Return a dict containing previously saved training info. """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: + if params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -308,19 +264,15 @@ def load_checkpoint_if_available( for k in keys: params[k] = saved_params[k] - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - return saved_params def get_model(params: AttributeDict) -> nn.Module: - mel_loss_params = params.mel_loss_params - mel_loss_params.update( - frame_length=params.frame_length, - frame_shift=params.frame_shift, - ) + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } model = VITS( vocab_size=params.vocab_size, feature_dim=params.feature_dim, @@ -381,18 +333,22 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. train_dl: Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. scaler: The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. 
world_size: @@ -404,7 +360,7 @@ def train_one_epoch( model.train() device = model.device if isinstance(model, DDP) else next(model.parameters()).device - # used to summary the stats over iterations + # used to summary the stats over iterations in one epoch tot_loss = MetricsTracker() saved_bad_model = False @@ -433,7 +389,6 @@ def save_bad_model(suffix: str = ""): loss_info = MetricsTracker() loss_info['samples'] = batch_size - return_sample = params.batch_idx_train % params.log_interval == 0 try: with autocast(enabled=params.use_fp16): # forward discriminator @@ -463,13 +418,11 @@ def save_bad_model(suffix: str = ""): speech=audio, speech_lengths=audio_lens, forward_generator=True, - return_sample=return_sample, + return_sample=params.batch_idx_train % params.log_interval == 0, ) for k, v in stats_g.items(): - if "return_sample" not in k: + if "returned_sample" not in k: loss_info[k] = v * batch_size - if return_sample: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["return_sample"] # update generator optimizer_g.zero_grad() scaler.scale(loss_g).backward() @@ -477,7 +430,6 @@ def save_bad_model(suffix: str = ""): scaler.update() # summary stats - # tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info tot_loss = tot_loss + loss_info except: # noqa save_bad_model() @@ -486,37 +438,12 @@ def save_bad_model(suffix: str = ""): if params.print_diagnostics and batch_idx == 5: return - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - # if batch_idx % 100 == 0 and params.use_fp16: if params.batch_idx_train % 100 == 0 and params.use_fp16: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. 
cur_grad_scale = scaler._scale.item() - # if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: @@ -530,7 +457,6 @@ def save_bad_model(suffix: str = ""): f"grad_scale is too small, exiting: {cur_grad_scale}" ) - # if batch_idx % params.log_interval == 0: if params.batch_idx_train % params.log_interval == 0: cur_lr_g = max(scheduler_g.get_last_lr()) cur_lr_d = max(scheduler_d.get_last_lr()) @@ -561,7 +487,8 @@ def save_bad_model(suffix: str = ""): tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) - if return_sample: + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] tb_writer.add_audio( "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate ) @@ -575,7 +502,6 @@ def save_bad_model(suffix: str = ""): "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC' ) - # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics: logging.info("Computing validation loss") valid_info, (speech_hat, speech) = compute_validation_loss( @@ -615,14 +541,14 @@ def compute_validation_loss( valid_dl: torch.utils.data.DataLoader, world_size: int = 1, rank: int = 0, -) -> MetricsTracker: +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: """Run the validation process.""" model.eval() device = model.device if isinstance(model, DDP) else next(model.parameters()).device # used to summary the stats over iterations tot_loss = MetricsTracker() - return_sample = None + returned_sample = None with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): @@ -667,12 +593,14 @@ def compute_validation_loss( # infer for first batch: if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model - audio_pred, _, duration = inner_model.inference(text=tokens[0, :tokens_lens[0].item()]) + audio_pred, _, duration = inner_model.inference( + text=tokens[0, :tokens_lens[0].item()] + ) audio_pred = audio_pred.data.cpu().numpy() audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred)) audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy() - return_sample = (audio_pred, audio_gt) + returned_sample = (audio_pred, audio_gt) if world_size > 1: tot_loss.reduce(device) @@ -682,7 +610,7 @@ def compute_validation_loss( params.best_valid_epoch = params.cur_epoch params.best_valid_loss = loss_value - return tot_loss, return_sample + return tot_loss, returned_sample def scan_pessimistic_batches_for_oom( @@ -805,18 +733,10 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer_g = torch.optim.AdamW( - generator.parameters(), - lr=params.lr, - betas=(0.8, 0.99), - eps=1e-9, - # weight_decay=0, + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 ) optimizer_d = torch.optim.AdamW( - discriminator.parameters(), - lr=params.lr, - betas=(0.8, 0.99), - eps=1e-9, - # weight_decay=0, + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 ) scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) @@ -852,16 +772,8 @@ def run(rank, world_size, args): train_cuts = ljspeech.train_cuts() - 
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds - # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold @@ -870,13 +782,10 @@ def remove_short_and_long_utt(c: Cut): # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" # ) return False - return True train_cuts = train_cuts.filter(remove_short_and_long_utt) - train_dl = ljspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) + train_dl = ljspeech.train_dataloaders(train_cuts) valid_cuts = ljspeech.valid_cuts() valid_dl = ljspeech.valid_dataloaders(valid_cuts) @@ -902,11 +811,11 @@ def remove_short_and_long_utt(c: Cut): fix_random_seed(params.seed + epoch - 1) train_dl.sampler.set_epoch(epoch - 1) + params.cur_epoch = epoch + if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - params.cur_epoch = epoch - train_one_epoch( params=params, model=model, @@ -927,27 +836,28 @@ def remove_short_and_long_utt(c: Cut): diagnostic.print_diagnostics() break - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint( - filename=filename, - params=params, - model=model, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - if rank == 0: - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) + if epoch % params.save_every_n == 0: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) # step per epoch scheduler_g.step() diff --git a/egs/ljspeech/tts/vits/transform.py b/egs/ljspeech/tts/vits/transform.py index 6858de2ab0..c20d13130a 100644 --- a/egs/ljspeech/tts/vits/transform.py +++ b/egs/ljspeech/tts/vits/transform.py @@ -1,4 +1,5 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py + """Flow-related transformation. This code is derived from https://github.com/bayesiains/nflows. 
diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/tts/vits/utils.py index 582856eee0..2a3dae9007 100644 --- a/egs/ljspeech/tts/vits/utils.py +++ b/egs/ljspeech/tts/vits/utils.py @@ -1,32 +1,35 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py - -# Copyright 2021 Tomoki Hayashi -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -"""Function to get random segments.""" - +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Dict, List, Optional, Tuple, Union import collections import logging -import re -import warnings -import numpy as np import torch import torch.nn as nn import torch.distributed as dist from lhotse.dataset.sampling.base import CutSampler from pathlib import Path -from phonemizer import phonemize -from symbols import symbol_table from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils.rnn import pad_sequence from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from unidecode import unidecode +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py def get_random_segments( x: torch.Tensor, x_lengths: torch.Tensor, @@ -55,6 +58,7 @@ def get_random_segments( return segments, start_idxs +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py def get_segments( x: torch.Tensor, start_idxs: torch.Tensor, @@ -78,195 +82,41 @@ def get_segments( return segments -# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py -def force_gatherable(data, device): - """Change object to gatherable in torch.nn.DataParallel recursively - - The difference from to_device() is changing to torch.Tensor if float or int - value is found. - - The restriction to the returned value in DataParallel: - The object must be - - torch.cuda.Tensor - - 1 or more dimension. 0-dimension-tensor sends warning. - or a list, tuple, dict. 
- - """ - if isinstance(data, dict): - return {k: force_gatherable(v, device) for k, v in data.items()} - # DataParallel can't handle NamedTuple well - elif isinstance(data, tuple) and type(data) is not tuple: - return type(data)(*[force_gatherable(o, device) for o in data]) - elif isinstance(data, (list, tuple, set)): - return type(data)(force_gatherable(v, device) for v in data) - elif isinstance(data, np.ndarray): - return force_gatherable(torch.from_numpy(data), device) - elif isinstance(data, torch.Tensor): - if data.dim() == 0: - # To 1-dim array - data = data[None] - return data.to(device) - elif isinstance(data, float): - return torch.tensor([data], dtype=torch.float, device=device) - elif isinstance(data, int): - return torch.tensor([data], dtype=torch.long, device=device) - elif data is None: - return None - else: - warnings.warn(f"{type(data)} may not be gatherable by DataParallel") - return data - - -# The following codes are based on https://github.com/jaywalnut310/vits - -# Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] - - -def expand_abbreviations(text): - for regex, replacement in _abbreviations: - text = re.sub(regex, replacement, text) - return text - - -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) - - -def convert_to_ascii(text): - return unidecode(text) - - -def text_clean(text): - '''Pipeline for English text, including abbreviation expansion. + punctuation + stress. - - Returns: - A string of phonemes. - ''' - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - phonemes = phonemize( - text, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) - phonemes = collapse_whitespace(phonemes) - return phonemes - - -# Mappings from symbol to numeric ID and vice versa: -symbol_to_id = {s: i for i, s in enumerate(symbol_table)} -id_to_symbol = {i: s for i, s in enumerate(symbol_table)} - - -# def text_to_sequence(text: str) -> List[int]: -# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. -# ''' -# cleaned_text = text_clean(text) -# sequence = [symbol_to_id[symbol] for symbol in cleaned_text] -# return sequence -# -# -# def sequence_to_text(sequence: List[int]) -> str: -# '''Converts a sequence of IDs back to a string''' -# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence) -# return result - - +# from https://github.com/jaywalnut310/vit://github.com/jaywalnut310/vits/blob/main/commons.py def intersperse(sequence, item=0): result = [item] * (len(sequence) * 2 + 1) result[1::2] = sequence return result -def prepare_token_batch( - texts: List[str], - phonemes: Optional[List[str]] = None, - intersperse_blank: bool = True, - blank_id: int = 0, - pad_id: int = 0, -) -> torch.Tensor: - """Convert a list of text strings into a batch of symbol tokens with padding. 
- Args: - texts: list of text strings - intersperse_blank: whether to intersperse blank tokens in the converted token sequence. - blank_id: index of blank token - pad_id: padding index - """ - if phonemes is None: - # normalize text - normalized_texts = [] - for text in texts: - text = convert_to_ascii(text) - text = lowercase(text) - text = expand_abbreviations(text) - normalized_texts.append(text) - - # convert to phonemes - phonemes = phonemize( - normalized_texts, - language='en-us', - backend='espeak', - strip=True, - preserve_punctuation=True, - with_stress=True, - ) - phonemes = [collapse_whitespace(sequence) for sequence in phonemes] - - # convert to symbol ids - lengths = [] - sequences = [] - skip = False - for idx, sequence in enumerate(phonemes): - try: - sequence = [symbol_to_id[symbol] for symbol in sequence] - except Exception: - # print(texts[idx]) - # print(normalized_texts[idx]) - print(phonemes[idx]) - skip = True - if intersperse_blank: - sequence = intersperse(sequence, blank_id) - try: - sequences.append(torch.tensor(sequence, dtype=torch.int64)) - except Exception: - print(sequence) - skip = True - lengths.append(len(sequence)) - - sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id) - lengths = torch.tensor(lengths, dtype=torch.int64) - return sequences, lengths, skip +# from https://github.com/jaywalnut310/vits/blob/main/utils.py +MATPLOTLIB_FLAG = False + + +def plot_feature(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger('matplotlib') + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data class MetricsTracker(collections.defaultdict): @@ -413,106 +263,3 @@ def save_checkpoint( checkpoint[k] = v torch.save(checkpoint, filename) - - -def save_checkpoint_with_global_batch_idx( - out_dir: Path, - global_batch_idx: int, - model: Union[nn.Module, DDP], - params: Optional[Dict[str, Any]] = None, - optimizer_g: Optional[Optimizer] = None, - optimizer_d: Optional[Optimizer] = None, - scheduler_g: Optional[LRSchedulerType] = None, - scheduler_d: Optional[LRSchedulerType] = None, - scaler: Optional[GradScaler] = None, - sampler: Optional[CutSampler] = None, - rank: int = 0, -): - """Save training info after processing given number of batches. - - Args: - out_dir: - The directory to save the checkpoint. - global_batch_idx: - The number of batches processed so far from the very start of the - training. The saved checkpoint will have the following filename: - f'out_dir / checkpoint-{global_batch_idx}.pt' - model: - The neural network model whose `state_dict` will be saved in the - checkpoint. - params: - A dict of training configurations to be saved. - optimizer_g: - The optimizer for generator used in the training. - Its `state_dict` will be saved. - optimizer_d: - The optimizer for discriminator used in the training. - Its `state_dict` will be saved. - scheduler_g: - The learning rate scheduler for generator used in the training. - Its `state_dict` will be saved. 
- scheduler_d: - The learning rate scheduler for discriminator used in the training. - Its `state_dict` will be saved. - scaler: - The scaler used for mix precision training. Its `state_dict` will - be saved. - sampler: - The sampler used in the training dataset. - rank: - The rank ID used in DDP training of the current node. Set it to 0 - if DDP is not used. - """ - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - filename = out_dir / f"checkpoint-{global_batch_idx}.pt" - save_checkpoint( - filename=filename, - model=model, - params=params, - optimizer_g=optimizer_g, - optimizer_d=optimizer_d, - scheduler_g=scheduler_g, - scheduler_d=scheduler_d, - scaler=scaler, - sampler=sampler, - rank=rank, - ) - - -# def plot_feature(feature): -# """ -# Display the feature matrix as an image. Requires matplotlib to be installed. -# """ -# import matplotlib.pyplot as plt -# -# feature = np.flip(feature.transpose(1, 0), 0) -# return plt.matshow(feature) - -MATPLOTLIB_FLAG = False - - -def plot_feature(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger('matplotlib') - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", - interpolation='none') - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/tts/vits/vits.py index 27d9b4c7a1..aa26a012d8 100644 --- a/egs/ljspeech/tts/vits/vits.py +++ b/egs/ljspeech/tts/vits/vits.py @@ -1,11 +1,11 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py +# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """VITS module for GAN-TTS task.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -247,7 +247,7 @@ def forward( spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, forward_generator: bool = True, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform generator forward. Args: @@ -263,12 +263,8 @@ def forward( forward_generator (bool): Whether to forward generator. Returns: - Dict[str, Any]: - - loss (Tensor): Loss scalar tensor. - - stats (Dict[str, float]): Statistics to be monitored. - - weight (Tensor): Weight tensor to summarize losses. - - optim_idx (int): Optimizer index (0 for G and 1 for D). - + - loss (Tensor): Loss scalar tensor. + - stats (Dict[str, float]): Statistics to be monitored. """ if forward_generator: return self._forward_generator( @@ -308,7 +304,7 @@ def _forward_generator( sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform generator forward. Args: @@ -323,12 +319,8 @@ def _forward_generator( lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. 
- * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). - + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. """ # setup feats = feats.transpose(1, 2) @@ -399,7 +391,7 @@ def _forward_generator( ) if return_sample: - stats["return_sample"] = ( + stats["returned_sample"] = ( speech_hat_[0].data.cpu().numpy(), speech_[0].data.cpu().numpy(), mel_hat_[0].data.cpu().numpy(), @@ -423,7 +415,7 @@ def _forward_discrminator( sids: Optional[torch.Tensor] = None, spembs: Optional[torch.Tensor] = None, lids: Optional[torch.Tensor] = None, - ) -> Dict[str, Any]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Perform discriminator forward. Args: @@ -438,12 +430,8 @@ def _forward_discrminator( lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). - + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. """ # setup feats = feats.transpose(1, 2) @@ -511,8 +499,8 @@ def inference( alpha: float = 1.0, max_len: Optional[int] = None, use_teacher_forcing: bool = False, - ) -> Dict[str, torch.Tensor]: - """Run inference. + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for single sample. Args: text (Tensor): Input text index tensor (T_text,). @@ -528,11 +516,9 @@ def inference( use_teacher_forcing (bool): Whether to use teacher forcing. Returns: - Dict[str, Tensor]: - * wav (Tensor): Generated waveform tensor (T_wav,). - * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). - * duration (Tensor): Predicted duration tensor (T_text,). - + * wav (Tensor): Generated waveform tensor (T_wav,). + * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text). + * duration (Tensor): Predicted duration tensor (T_text,). """ # setup text = text[None] @@ -593,8 +579,8 @@ def inference_batch( alpha: float = 1.0, max_len: Optional[int] = None, use_teacher_forcing: bool = False, - ) -> Dict[str, torch.Tensor]: - """Run inference. + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run inference for one batch. Args: text (Tensor): Input text index tensor (B, T_text). @@ -605,11 +591,9 @@ def inference_batch( max_len (Optional[int]): Maximum length. Returns: - Dict[str, Tensor]: - * wav (Tensor): Generated waveform tensor (B, T_wav). - * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). - * duration (Tensor): Predicted duration tensor (B, T_text). - + * wav (Tensor): Generated waveform tensor (B, T_wav). + * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text). + * duration (Tensor): Predicted duration tensor (B, T_text). 
""" # inference wav, att_w, dur = self.generator.inference( diff --git a/egs/ljspeech/tts/vits/wavenet.py b/egs/ljspeech/tts/vits/wavenet.py index cbb44a8f40..fbe1be52b0 100644 --- a/egs/ljspeech/tts/vits/wavenet.py +++ b/egs/ljspeech/tts/vits/wavenet.py @@ -1,4 +1,4 @@ -# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py +# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py # Copyright 2021 Tomoki Hayashi # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) From fc359be29d0d2456ea08d39198fd6cb2e1ec4d01 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 6 Nov 2023 10:31:17 +0800 Subject: [PATCH 05/16] minor fix --- egs/ljspeech/tts/vits/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/tts/vits/train.py index c8df3c5d0a..1a2c934fe2 100755 --- a/egs/ljspeech/tts/vits/train.py +++ b/egs/ljspeech/tts/vits/train.py @@ -141,7 +141,9 @@ def get_parser(): help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever params.cur_epoch % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/epoch-{params.cur_epoch}.pt' + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. """, ) @@ -836,7 +838,7 @@ def remove_short_and_long_utt(c: Cut): diagnostic.print_diagnostics() break - if epoch % params.save_every_n == 0: + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" save_checkpoint( filename=filename, From cd59a69957cf48acf2499ab402ee7d56b9e5ff98 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 6 Nov 2023 11:06:42 +0800 Subject: [PATCH 06/16] rename directory --- egs/ljspeech/{tts => TTS}/local/compute_spectrogram_ljspeech.py | 0 egs/ljspeech/{tts => TTS}/local/display_manifest_statistics.py | 0 egs/ljspeech/{tts => TTS}/local/prepare_token_file.py | 0 egs/ljspeech/{tts => TTS}/local/validate_manifest.py | 0 egs/ljspeech/{tts => TTS}/prepare.sh | 0 egs/ljspeech/{tts => TTS}/shared/parse_options.sh | 0 egs/ljspeech/{tts => TTS}/vits/duration_predictor.py | 0 egs/ljspeech/{tts => TTS}/vits/flow.py | 0 egs/ljspeech/{tts => TTS}/vits/generator.py | 0 egs/ljspeech/{tts => TTS}/vits/hifigan.py | 0 egs/ljspeech/{tts => TTS}/vits/infer.py | 2 +- egs/ljspeech/{tts => TTS}/vits/loss.py | 0 egs/ljspeech/{tts => TTS}/vits/monotonic_align/__init__.py | 0 egs/ljspeech/{tts => TTS}/vits/monotonic_align/core.pyx | 0 egs/ljspeech/{tts => TTS}/vits/monotonic_align/setup.py | 0 egs/ljspeech/{tts => TTS}/vits/posterior_encoder.py | 0 egs/ljspeech/{tts => TTS}/vits/residual_coupling.py | 0 egs/ljspeech/{tts => TTS}/vits/text_encoder.py | 0 egs/ljspeech/{tts => TTS}/vits/tokenizer.py | 0 egs/ljspeech/{tts => TTS}/vits/train.py | 0 egs/ljspeech/{tts => TTS}/vits/transform.py | 0 egs/ljspeech/{tts => TTS}/vits/tts_datamodule.py | 0 egs/ljspeech/{tts => TTS}/vits/utils.py | 0 egs/ljspeech/{tts => TTS}/vits/vits.py | 0 egs/ljspeech/{tts => TTS}/vits/wavenet.py | 0 25 files changed, 1 insertion(+), 1 deletion(-) rename egs/ljspeech/{tts => TTS}/local/compute_spectrogram_ljspeech.py (100%) rename egs/ljspeech/{tts => TTS}/local/display_manifest_statistics.py (100%) rename egs/ljspeech/{tts => TTS}/local/prepare_token_file.py (100%) rename egs/ljspeech/{tts => TTS}/local/validate_manifest.py (100%) rename egs/ljspeech/{tts => 
TTS}/prepare.sh (100%) rename egs/ljspeech/{tts => TTS}/shared/parse_options.sh (100%) rename egs/ljspeech/{tts => TTS}/vits/duration_predictor.py (100%) rename egs/ljspeech/{tts => TTS}/vits/flow.py (100%) rename egs/ljspeech/{tts => TTS}/vits/generator.py (100%) rename egs/ljspeech/{tts => TTS}/vits/hifigan.py (100%) rename egs/ljspeech/{tts => TTS}/vits/infer.py (99%) rename egs/ljspeech/{tts => TTS}/vits/loss.py (100%) rename egs/ljspeech/{tts => TTS}/vits/monotonic_align/__init__.py (100%) rename egs/ljspeech/{tts => TTS}/vits/monotonic_align/core.pyx (100%) rename egs/ljspeech/{tts => TTS}/vits/monotonic_align/setup.py (100%) rename egs/ljspeech/{tts => TTS}/vits/posterior_encoder.py (100%) rename egs/ljspeech/{tts => TTS}/vits/residual_coupling.py (100%) rename egs/ljspeech/{tts => TTS}/vits/text_encoder.py (100%) rename egs/ljspeech/{tts => TTS}/vits/tokenizer.py (100%) rename egs/ljspeech/{tts => TTS}/vits/train.py (100%) rename egs/ljspeech/{tts => TTS}/vits/transform.py (100%) rename egs/ljspeech/{tts => TTS}/vits/tts_datamodule.py (100%) rename egs/ljspeech/{tts => TTS}/vits/utils.py (100%) rename egs/ljspeech/{tts => TTS}/vits/vits.py (100%) rename egs/ljspeech/{tts => TTS}/vits/wavenet.py (100%) diff --git a/egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py similarity index 100% rename from egs/ljspeech/tts/local/compute_spectrogram_ljspeech.py rename to egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py diff --git a/egs/ljspeech/tts/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py similarity index 100% rename from egs/ljspeech/tts/local/display_manifest_statistics.py rename to egs/ljspeech/TTS/local/display_manifest_statistics.py diff --git a/egs/ljspeech/tts/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py similarity index 100% rename from egs/ljspeech/tts/local/prepare_token_file.py rename to egs/ljspeech/TTS/local/prepare_token_file.py diff --git a/egs/ljspeech/tts/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py similarity index 100% rename from egs/ljspeech/tts/local/validate_manifest.py rename to egs/ljspeech/TTS/local/validate_manifest.py diff --git a/egs/ljspeech/tts/prepare.sh b/egs/ljspeech/TTS/prepare.sh similarity index 100% rename from egs/ljspeech/tts/prepare.sh rename to egs/ljspeech/TTS/prepare.sh diff --git a/egs/ljspeech/tts/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh similarity index 100% rename from egs/ljspeech/tts/shared/parse_options.sh rename to egs/ljspeech/TTS/shared/parse_options.sh diff --git a/egs/ljspeech/tts/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py similarity index 100% rename from egs/ljspeech/tts/vits/duration_predictor.py rename to egs/ljspeech/TTS/vits/duration_predictor.py diff --git a/egs/ljspeech/tts/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py similarity index 100% rename from egs/ljspeech/tts/vits/flow.py rename to egs/ljspeech/TTS/vits/flow.py diff --git a/egs/ljspeech/tts/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py similarity index 100% rename from egs/ljspeech/tts/vits/generator.py rename to egs/ljspeech/TTS/vits/generator.py diff --git a/egs/ljspeech/tts/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py similarity index 100% rename from egs/ljspeech/tts/vits/hifigan.py rename to egs/ljspeech/TTS/vits/hifigan.py diff --git a/egs/ljspeech/tts/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py similarity index 99% 
rename from egs/ljspeech/tts/vits/infer.py rename to egs/ljspeech/TTS/vits/infer.py index f971f85ffd..4917a7ee99 100755 --- a/egs/ljspeech/tts/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -21,7 +21,7 @@ Usage: ./vits/infer.py \ --epoch 1000 \ - --exp-dir ./zipformer/exp \ + --exp-dir ./vits/exp \ --max-duration 500 """ diff --git a/egs/ljspeech/tts/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py similarity index 100% rename from egs/ljspeech/tts/vits/loss.py rename to egs/ljspeech/TTS/vits/loss.py diff --git a/egs/ljspeech/tts/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py similarity index 100% rename from egs/ljspeech/tts/vits/monotonic_align/__init__.py rename to egs/ljspeech/TTS/vits/monotonic_align/__init__.py diff --git a/egs/ljspeech/tts/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx similarity index 100% rename from egs/ljspeech/tts/vits/monotonic_align/core.pyx rename to egs/ljspeech/TTS/vits/monotonic_align/core.pyx diff --git a/egs/ljspeech/tts/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py similarity index 100% rename from egs/ljspeech/tts/vits/monotonic_align/setup.py rename to egs/ljspeech/TTS/vits/monotonic_align/setup.py diff --git a/egs/ljspeech/tts/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py similarity index 100% rename from egs/ljspeech/tts/vits/posterior_encoder.py rename to egs/ljspeech/TTS/vits/posterior_encoder.py diff --git a/egs/ljspeech/tts/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py similarity index 100% rename from egs/ljspeech/tts/vits/residual_coupling.py rename to egs/ljspeech/TTS/vits/residual_coupling.py diff --git a/egs/ljspeech/tts/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py similarity index 100% rename from egs/ljspeech/tts/vits/text_encoder.py rename to egs/ljspeech/TTS/vits/text_encoder.py diff --git a/egs/ljspeech/tts/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py similarity index 100% rename from egs/ljspeech/tts/vits/tokenizer.py rename to egs/ljspeech/TTS/vits/tokenizer.py diff --git a/egs/ljspeech/tts/vits/train.py b/egs/ljspeech/TTS/vits/train.py similarity index 100% rename from egs/ljspeech/tts/vits/train.py rename to egs/ljspeech/TTS/vits/train.py diff --git a/egs/ljspeech/tts/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py similarity index 100% rename from egs/ljspeech/tts/vits/transform.py rename to egs/ljspeech/TTS/vits/transform.py diff --git a/egs/ljspeech/tts/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py similarity index 100% rename from egs/ljspeech/tts/vits/tts_datamodule.py rename to egs/ljspeech/TTS/vits/tts_datamodule.py diff --git a/egs/ljspeech/tts/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py similarity index 100% rename from egs/ljspeech/tts/vits/utils.py rename to egs/ljspeech/TTS/vits/utils.py diff --git a/egs/ljspeech/tts/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py similarity index 100% rename from egs/ljspeech/tts/vits/vits.py rename to egs/ljspeech/TTS/vits/vits.py diff --git a/egs/ljspeech/tts/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py similarity index 100% rename from egs/ljspeech/tts/vits/wavenet.py rename to egs/ljspeech/TTS/vits/wavenet.py From f55e80a7c5c8215552511593d80e25084604c7d6 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 6 Nov 2023 15:05:49 +0800 Subject: [PATCH 07/16] minor fixes --- .../TTS/local/compute_spectrogram_ljspeech.py | 4 +- egs/ljspeech/TTS/local/prepare_token_file.py | 18 
++-- egs/ljspeech/TTS/local/validate_manifest.py | 2 +- egs/ljspeech/TTS/prepare.sh | 15 ++- egs/ljspeech/TTS/shared/parse_options.sh | 98 +------------------ 5 files changed, 24 insertions(+), 113 deletions(-) mode change 100755 => 120000 egs/ljspeech/TTS/shared/parse_options.sh diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py index 3603af07df..edb22b276c 100755 --- a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -21,7 +21,7 @@ This file computes fbank features of the LJSpeech dataset. It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/spectrogram. +The generated spectrogram features are saved in data/spectrogram. """ import logging @@ -75,7 +75,7 @@ def compute_spectrogram_ljspeech(): with get_executor() as ex: # Initialize the executor only once. cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") + logging.info(f"{cuts_filename} already exists - skipping.") return logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py index 17a5588992..167b73f2e5 100755 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -17,7 +17,7 @@ """ -This file reads the texts in given manifest and generate the file that maps tokens to IDs. +This file reads the texts in given manifest and generates the file that maps tokens to IDs. """ import argparse @@ -73,11 +73,11 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: def get_token2id(manifest_file: Path) -> Dict[str, int]: """Return a dict that maps token to IDs.""" - extra_tokens = { - "": 0, # blank - "": 1, # sos and eos symbols. - "": 2, # OOV - } + extra_tokens = [ + ("", None), # 0 for blank + ("", None), # 1 for sos and eos symbols. 
+ ("", None), # 2 for OOV + ] cut_set = load_manifest(manifest_file) g2p = g2p_en.G2p() counter = Counter() @@ -96,10 +96,10 @@ def get_token2id(manifest_file: Path) -> Dict[str, int]: # Sort by the number of occurrences in descending order tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1]) - for token, idx in extra_tokens.items(): - tokens_and_counts.insert(idx, (token, None)) + tokens_and_counts = extra_tokens + tokens_and_counts + + token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)} - token2id: Dict[str, int] = {token: i for i, (token, count) in enumerate(tokens_and_counts)} return token2id diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index cd466303ed..68159ae036 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -57,7 +57,7 @@ def main(): assert manifest.is_file(), f"{manifest} does not exist" cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet) + assert isinstance(cut_set, CutSet), type(cut_set) validate_for_tts(cut_set) diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 613eb37d8d..396d91b597 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -28,10 +28,13 @@ log "dl_dir: $dl_dir" if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" - # If you have pre-downloaded it to /path/to/LJSpeech, - # you can create a symlink + # The directory $dl_dir/LJSpeech-1.1 will contain: + # - wavs, which contains the audio files + # - metadata.csv, which provides the transcript text for each audio clip + + # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink # - # ln -sfv /path/to/LJSpeech $dl_dir/LJSpeech + # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1 # if [ ! -d $dl_dir/LJSpeech-1.1 ]; then lhotse download ljspeech $dl_dir @@ -58,7 +61,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then - log "Validating data/fbank for LJSpeech" + log "Validating data/spectrogram for LJSpeech" python3 ./local/validate_manifest.py \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech-validated.done @@ -90,6 +93,10 @@ fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Generate token file" + # We assume you have installed g2p_en and espnet_tts_frontend. + # If not, please install them with: + # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py \ --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \ diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh deleted file mode 100755 index 71fb9e5ea1..0000000000 --- a/egs/ljspeech/TTS/shared/parse_options.sh +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env bash - -# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); -# Arnab Ghoshal, Karel Vesely - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED -# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, -# MERCHANTABLITY OR NON-INFRINGEMENT. -# See the Apache 2 License for the specific language governing permissions and -# limitations under the License. - - -# Parse command-line options. -# To be sourced by another script (as in ". parse_options.sh"). -# Option format is: --option-name arg -# and shell variable "option_name" gets set to value "arg." -# The exception is --help, which takes no arguments, but prints the -# $help_message variable (if defined). - - -### -### The --config file options have lower priority to command line -### options, so we need to import them first... -### - -# Now import all the configs specified by command-line, in left-to-right order -for ((argpos=1; argpos<$#; argpos++)); do - if [ "${!argpos}" == "--config" ]; then - argpos_plus1=$((argpos+1)) - config=${!argpos_plus1} - [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 - . $config # source the config file. - fi -done - - -### -### Now we process the command line options -### -while true; do - [ -z "${1:-}" ] && break; # break if there are no arguments - case "$1" in - # If the enclosing script is called with --help option, print the help - # message and exit. Scripts should put help messages in $help_message - --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; - else printf "$help_message\n" 1>&2 ; fi; - exit 0 ;; - --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" - exit 1 ;; - # If the first command-line argument begins with "--" (e.g. --foo-bar), - # then work out the variable name as $name, which will equal "foo_bar". - --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; - # Next we test whether the variable in question is undefned-- if so it's - # an invalid option and we die. Note: $0 evaluates to the name of the - # enclosing script. - # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar - # is undefined. We then have to wrap this test inside "eval" because - # foo_bar is itself inside a variable ($name). - eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; - - oldval="`eval echo \\$$name`"; - # Work out whether we seem to be expecting a Boolean argument. - if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then - was_bool=true; - else - was_bool=false; - fi - - # Set the variable to the right value-- the escaped quotes make it work if - # the option had spaces, like --cmd "queue.pl -sync y" - eval $name=\"$2\"; - - # Check that Boolean-valued arguments are really Boolean. - if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then - echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 - exit 1; - fi - shift 2; - ;; - *) break; - esac -done - - -# Check for an empty argument to the --cmd option, which can easily occur as a -# result of scripting errors. -[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; - - -true; # so this script returns exit code 0. 
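For reference, a minimal sketch (not part of the patch) of the grapheme-to-phoneme pipeline that the "Generate token file" stage in prepare.sh above relies on, assuming g2p_en and espnet_tts_frontend (which provides tacotron_cleaner) are installed as noted there:

import g2p_en
import tacotron_cleaner.cleaners

g2p = g2p_en.G2p()
text = "Hello world."
# text normalization, as done in local/prepare_token_file.py
text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
# convert to a list of phoneme tokens, e.g. ['HH', 'AH0', 'L', 'OW1', ...]
tokens = g2p(text)
print(tokens)
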
diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh new file mode 120000 index 0000000000..e4665e7de2 --- /dev/null +++ b/egs/ljspeech/TTS/shared/parse_options.sh @@ -0,0 +1 @@ +../../../librispeech/ASR/shared/parse_options.sh \ No newline at end of file From 8791a4efb090faca1648ae8d8b8c6ddcb6ee9a0f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Mon, 13 Nov 2023 21:56:18 +0800 Subject: [PATCH 08/16] convert text to tokens in data preparation stage --- .../TTS/local/compute_spectrogram_ljspeech.py | 4 +- egs/ljspeech/TTS/local/prepare_token_file.py | 30 +++------ .../TTS/local/prepare_tokens_ljspeech.py | 63 +++++++++++++++++++ egs/ljspeech/TTS/prepare.sh | 16 ++++- egs/ljspeech/TTS/vits/infer.py | 6 +- egs/ljspeech/TTS/vits/tokenizer.py | 27 ++++++++ egs/ljspeech/TTS/vits/train.py | 8 +-- egs/ljspeech/TTS/vits/tts_datamodule.py | 24 +++++-- 8 files changed, 139 insertions(+), 39 deletions(-) create mode 100755 egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py index edb22b276c..eacf0df57f 100755 --- a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -58,10 +58,10 @@ def compute_spectrogram_ljspeech(): partition = "all" recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet ) supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) config = SpectrogramConfig( diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py index 167b73f2e5..007bb299bf 100755 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -22,12 +22,9 @@ import argparse import logging -from collections import Counter from pathlib import Path from typing import Dict -import g2p_en -import tacotron_cleaner.cleaners from lhotse import load_manifest @@ -74,32 +71,23 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: def get_token2id(manifest_file: Path) -> Dict[str, int]: """Return a dict that maps token to IDs.""" extra_tokens = [ - ("", None), # 0 for blank - ("", None), # 1 for sos and eos symbols. - ("", None), # 2 for OOV + "", # 0 for blank + "", # 1 for sos and eos symbols. 
+ "" # 2 for OOV ] + all_tokens = set() + cut_set = load_manifest(manifest_file) - g2p = g2p_en.G2p() - counter = Counter() for cut in cut_set: # Each cut only contain one supervision assert len(cut.supervisions) == 1, len(cut.supervisions) - text = cut.supervisions[0].normalized_text - # Text normalization - text = tacotron_cleaner.cleaners.custom_english_cleaners(text) - # Convert to phonemes - tokens = g2p(text) - for t in tokens: - counter[t] += 1 - - # Sort by the number of occurrences in descending order - tokens_and_counts = sorted(counter.items(), key=lambda x: -x[1]) - - tokens_and_counts = extra_tokens + tokens_and_counts + for t in cut.tokens: + all_tokens.add(t) - token2id: Dict[str, int] = {token: i for i, (token, _) in enumerate(tokens_and_counts)} + all_tokens = extra_tokens + list(all_tokens) + token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)} return token2id diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py new file mode 100755 index 0000000000..f7fa7e2d26 --- /dev/null +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import g2p_en +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest + + +def prepare_tokens_ljspeech(): + output_dir = Path("data/spectrogram") + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + cut_set = load_manifest( + output_dir / f"{prefix}_cuts_{partition}.{suffix}" + ) + g2p = g2p_en.G2p() + + new_cuts = [] + for cut in cut_set: + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, len(cut.supervisions) + text = cut.supervisions[0].normalized_text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + cut.tokens = g2p(text) + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file( + output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_ljspeech() diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 396d91b597..8ee40896e1 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -69,7 +69,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Split the LJSpeech cuts into train, valid and test sets" + log "Stage 3: Prepare phoneme tokens for LJSpeech" + if [ ! 
-e data/spectrogram/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py + mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/spectrogram/ljspeech_cuts_all.jsonl.gz + touch data/spectrogram/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" if [ ! -e data/spectrogram/.ljspeech_split.done ]; then lhotse subset --last 600 \ data/spectrogram/ljspeech_cuts_all.jsonl.gz \ @@ -91,8 +101,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi fi -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Generate token file" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate token file" # We assume you have installed g2p_en and espnet_tts_frontend. # If not, please install them with: # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 4917a7ee99..a7c4a4c096 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -128,10 +128,10 @@ def _save_worker( futures = [] with ThreadPoolExecutor(max_workers=1) as executor: for batch_idx, batch in enumerate(dl): - batch_size = len(batch["text"]) + batch_size = len(batch["tokens"]) - text = batch["text"] - tokens = tokenizer.texts_to_token_ids(text) + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids(tokens) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py index 8a61511ef5..0678b26fe0 100644 --- a/egs/ljspeech/TTS/vits/tokenizer.py +++ b/egs/ljspeech/TTS/vits/tokenizer.py @@ -77,3 +77,30 @@ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True): token_ids_list.append(token_ids) return token_ids_list + + def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True): + """ + Args: + tokens_list: + A list of token list, each corresponding to one utterance. + intersperse_blank: + Whether to intersperse blanks in the token sequence. 
+ + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for tokens in tokens_list: + token_ids = [] + for t in tokens: + if t in self.token2id: + token_ids.append(self.token2id[t]) + else: + token_ids.append(self.oov_id) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.blank_id) + token_ids_list.append(token_ids) + + return token_ids_list diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 1a2c934fe2..eb43a4cc93 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -295,9 +295,9 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): features = batch["features"].to(device) audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) - text = batch["text"] + tokens = batch["tokens"] - tokens = tokenizer.texts_to_token_ids(text) + tokens = tokenizer.tokens_to_token_ids(tokens) tokens = k2.RaggedTensor(tokens) row_splits = tokens.shape.row_splits(1) tokens_lens = row_splits[1:] - row_splits[:-1] @@ -384,7 +384,7 @@ def save_bad_model(suffix: str = ""): for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - batch_size = len(batch["text"]) + batch_size = len(batch["tokens"]) audio, audio_lens, features, features_lens, tokens, tokens_lens = \ prepare_input(batch, tokenizer, device) @@ -554,7 +554,7 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["text"]) + batch_size = len(batch["tokens"]) audio, audio_lens, features, features_lens, tokens, tokens_lens = \ prepare_input(batch, tokenizer, device) diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index 40e9c19ddf..f27676670b 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -168,7 +168,9 @@ def train_dataloaders( """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -182,7 +184,9 @@ def train_dataloaders( use_fft_mag=True, ) train = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) @@ -236,13 +240,17 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: use_fft_mag=True, ) validate = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: validate = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -273,13 +281,17 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: use_fft_mag=True, ) test = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), return_cuts=self.args.return_cuts, ) else: test = SpeechSynthesisDataset( - return_tokens=False, + return_token_ids=False, + return_text=False, + return_tokens=True, 
feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) From 32931b78851b60ed4a8665c9ae7f40fa289de54f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 21 Nov 2023 16:33:54 +0800 Subject: [PATCH 09/16] fix tts_datamodule.py --- egs/ljspeech/TTS/vits/tts_datamodule.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py index f27676670b..0fcbb92c16 100644 --- a/egs/ljspeech/TTS/vits/tts_datamodule.py +++ b/egs/ljspeech/TTS/vits/tts_datamodule.py @@ -168,7 +168,6 @@ def train_dataloaders( """ logging.info("About to create train dataset") train = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), @@ -184,7 +183,6 @@ def train_dataloaders( use_fft_mag=True, ) train = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -240,7 +238,6 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: use_fft_mag=True, ) validate = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -248,7 +245,6 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: ) else: validate = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), @@ -281,7 +277,6 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: use_fft_mag=True, ) test = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), @@ -289,7 +284,6 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: ) else: test = SpeechSynthesisDataset( - return_token_ids=False, return_text=False, return_tokens=True, feature_input_strategy=eval(self.args.input_strategy)(), From 1ed6b4e14399312b9542a8002ce806916293d9fd Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 21 Nov 2023 16:59:56 +0800 Subject: [PATCH 10/16] minor fix --- egs/ljspeech/TTS/vits/infer.py | 1 + egs/ljspeech/TTS/vits/vits.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index a7c4a4c096..91a35e3602 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -225,6 +225,7 @@ def main(): tokenizer=tokenizer, ) + logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info("Done!") diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index aa26a012d8..d5e20a5787 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -573,6 +573,7 @@ def inference_batch( self, text: torch.Tensor, text_lengths: torch.Tensor, + sids: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, @@ -585,6 +586,7 @@ def inference_batch( Args: text (Tensor): Input text index tensor (B, T_text). text_lengths (Tensor): Input text index tensor (B,). + sids (Tensor): Speaker index tensor (B,). noise_scale (float): Noise scale value for flow. noise_scale_dur (float): Noise scale value for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. 
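For reference, a rough sketch (not part of the patch) of how a padded batch could be passed to VITS.inference_batch, following the shapes documented above; `tokenizer`, `model` and `token_lists` are assumed to already exist, and padding with the tokenizer's blank id is only an illustrative choice:

import torch
from torch.nn.utils.rnn import pad_sequence

token_ids = tokenizer.tokens_to_token_ids(token_lists)  # [utterance][token_id]
text_lengths = torch.tensor([len(ids) for ids in token_ids], dtype=torch.int64)
text = pad_sequence(
    [torch.tensor(ids, dtype=torch.int64) for ids in token_ids],
    batch_first=True,
    padding_value=tokenizer.blank_id,
)
# wav: (B, T_wav), att_w: (B, T_feats, T_text), duration: (B, T_text)
wav, att_w, duration = model.inference_batch(text=text, text_lengths=text_lengths)
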
@@ -599,6 +601,7 @@ def inference_batch( wav, att_w, dur = self.generator.inference( text=text, text_lengths=text_lengths, + sids=sids, noise_scale=noise_scale, noise_scale_dur=noise_scale_dur, alpha=alpha, From a983dcd469ac23bba82ad13ce764c994bf65d81c Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Thu, 23 Nov 2023 20:46:34 +0800 Subject: [PATCH 11/16] support onnx export and testing the exported onnx model --- egs/ljspeech/TTS/vits/export-onnx.py | 261 ++++++++++++++++++++++++++ egs/ljspeech/TTS/vits/generator.py | 2 + egs/ljspeech/TTS/vits/test_onnx.py | 123 ++++++++++++ egs/ljspeech/TTS/vits/text_encoder.py | 38 ++-- 4 files changed, 411 insertions(+), 13 deletions(-) create mode 100755 egs/ljspeech/TTS/vits/export-onnx.py create mode 100755 egs/ljspeech/TTS/vits/test_onnx.py diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py new file mode 100755 index 0000000000..154de4bf42 --- /dev/null +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script exports a VITS model from PyTorch to ONNX. + +Export the model to ONNX: +./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +It will generate two files inside vits/exp: + - vits-epoch-1000.onnx + - vits-epoch-1000.int8.onnx (quantizated model) + +See ./test_onnx.py for how to use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Tuple + +import onnx +import torch +import torch.nn as nn +from onnxruntime.quantization import QuantType, quantize_dynamic +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxModel(nn.Module): + """A wrapper for VITS generator.""" + + def __init__(self, model: nn.Module): + """ + Args: + model: + A VITS generator. 
+ frame_shift: + The frame shift in samples. + """ + super().__init__() + self.model = model + + def forward( + self, + tokens: torch.Tensor, + tokens_lens: torch.Tensor, + noise_scale: float = 0.667, + noise_scale_dur: float = 0.8, + alpha: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of VITS.inference_batch + + Args: + tokens: + Input text token indexes (1, T_text) + tokens_lens: + Number of tokens of shape (1,) + noise_scale (float): + Noise scale parameter for flow. + noise_scale_dur (float): + Noise scale parameter for duration predictor. + alpha (float): + Alpha parameter to control the speed of generated speech. + + Returns: + Return a tuple containing: + - audio, generated wavform tensor, (B, T_wav) + """ + audio, _, _ = self.model.inference( + text=tokens, + text_lengths=tokens_lens, + noise_scale=noise_scale, + noise_scale_dur=noise_scale_dur, + alpha=alpha, + ) + return audio + + +def export_model_onnx( + model: nn.Module, + model_filename: str, + opset_version: int = 11, +) -> None: + """Export the given generator model to ONNX format. + The exported model has one input: + + - tokens, a tensor of shape (1, T_text); dtype is torch.int64 + + and it has one output: + + - audio, a tensor of shape (1, T'); dtype is torch.float32 + + Args: + model: + The VITS generator. + model_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1], dtype=torch.float32) + noise_scale_dur = torch.tensor([1], dtype=torch.float32) + alpha = torch.tensor([1], dtype=torch.float32) + + torch.onnx.export( + model, + (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + model_filename, + verbose=False, + opset_version=opset_version, + input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"], + output_names=["audio"], + dynamic_axes={ + "tokens": {0: "N", 1: "T"}, + "tokens_lens": {0: "N"}, + "audio": {0: "N", 1: "T"}, + }, + ) + + meta_data = { + "model_type": "VITS", + "version": "1", + "model_author": "k2-fsa", + "comment": "VITS generator", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=model_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.blank_id + params.oov_id = tokenizer.oov_id + params.vocab_size = tokenizer.vocab_size + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model = model.generator + model.to("cpu") + model.eval() + + model = OnnxModel(model=model) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"generator parameters: {num_param}") + + suffix = f"epoch-{params.epoch}" + + opset_version = 13 + + logging.info("Exporting encoder") + model_filename = params.exp_dir / f"vits-{suffix}.onnx" + export_model_onnx( + model, + model_filename, + opset_version=opset_version, + ) + logging.info(f"Exported generator to {model_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 
quantization models") + + model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx" + quantize_dynamic( + model_input=model_filename, + model_output=model_filename_int8, + weight_type=QuantType.QUInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index 664d8064f5..efb0e254cf 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -403,6 +403,7 @@ def inference( """ # encoder x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths) + x_mask = x_mask.to(x.dtype) g = None if self.spks is not None: # (B, global_channels, 1) @@ -480,6 +481,7 @@ def inference( dur = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long() y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device) + y_mask = y_mask.to(x.dtype) attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) attn = self._generate_path(dur, attn_mask) diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py new file mode 100755 index 0000000000..8acca7c026 --- /dev/null +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +import onnxruntime as ort +import torch +import torchaudio + +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: noise_scale_dur.numpy(), + self.model.get_inputs()[4].name: alpha.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." 
+ tokens = tokenizer.texts_to_token_ids([text]) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + audio = model(tokens, tokens_lens) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py index 419fd6162e..9f337e45bb 100644 --- a/egs/ljspeech/TTS/vits/text_encoder.py +++ b/egs/ljspeech/TTS/vits/text_encoder.py @@ -30,7 +30,7 @@ import torch from torch import Tensor, nn -from icefall.utils import make_pad_mask +from icefall.utils import is_jit_tracing, make_pad_mask class TextEncoder(torch.nn.Module): @@ -440,18 +440,30 @@ def rel_shift(self, x: Tensor) -> Tensor: """ (batch_size, num_heads, seq_len, n) = x.shape - assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" - - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, seq_len, seq_len), - (batch_stride, head_stride, time_stride - n_stride, n_stride), - storage_offset=n_stride * (seq_len - 1), - ) + if not is_jit_tracing(): + assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1" + + if is_jit_tracing(): + rows = torch.arange(start=seq_len - 1, end=-1, step=-1) + cols = torch.arange(seq_len) + rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) + indexes = rows + cols + + x = x.reshape(-1, n) + x = torch.gather(x, dim=1, index=indexes) + x = x.reshape(batch_size, num_heads, seq_len, seq_len) + return x + else: + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, seq_len, seq_len), + (batch_stride, head_stride, time_stride - n_stride, n_stride), + storage_offset=n_stride * (seq_len - 1), + ) def forward( self, From 5ab142842e782a8e51e3731b3e2028ad89103787 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Tue, 28 Nov 2023 10:50:07 +0800 Subject: [PATCH 12/16] add doc --- docs/source/recipes/TTS/index.rst | 7 ++ docs/source/recipes/TTS/ljspeech/vits.rst | 106 ++++++++++++++++++++++ docs/source/recipes/index.rst | 3 +- 3 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 docs/source/recipes/TTS/index.rst create mode 100644 docs/source/recipes/TTS/ljspeech/vits.rst diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst new file mode 100644 index 0000000000..aa891c072e --- /dev/null +++ b/docs/source/recipes/TTS/index.rst @@ -0,0 +1,7 @@ +TTS +====== + +.. toctree:: + :maxdepth: 2 + + ljspeech/vits diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst new file mode 100644 index 0000000000..0f0d97a9e8 --- /dev/null +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -0,0 +1,106 @@ +VITS +=============== + +This tutorial shows you how to train an VITS model +with the `LJSpeech `_ dataset. + +.. note:: + + The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_ + + +Data preparation +---------------- + +.. 
code-block:: bash + + $ cd egs/ljspeech/TTS + $ ./prepare.sh + +To run stage 1 to stage 5, use + +.. code-block:: bash + + $ ./prepare.sh --stage 1 --stop_stage 5 + + +Build Monotonic Alignment Search +-------------------------------- + +.. code-block:: bash + + $ cd vits/monotonic_align + $ python setup.py build_ext --inplace + $ cd ../../ + + +Training +-------- + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0,1,2,3" + $ ./vits/train.py \ + --world-size 4 \ + --num-epochs 1000 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + You can adjust the hyper-parameters to control the size of the VITS model and + the training configurations. For more details, please run ``./vits/train.py --help``. + +.. note:: + + The training can take a long time (usually a couple of days). + +Training logs, checkpoints and tensorboard logs are saved in ``vits/exp``. + + +Inference +--------- + +The inference part uses checkpoints saved by the training part, so you have to run the +training part first. It will save the ground-truth and generated wavs to the directory +``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``. + +.. code-block:: bash + + $ export CUDA_VISIBLE_DEVICES="0" + $ ./vits/infer.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + --max-duration 500 + +.. note:: + + For more details, please run ``./vits/infer.py --help``. + + +Export models +------------- + +Currently we only support ONNX model exporting. It will generate two files in the given ``exp-dir``: +``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``. + +.. code-block:: bash + + $ ./vits/export-onnx.py \ + --epoch 1000 \ + --exp-dir vits/exp \ + --tokens data/tokens.txt + +You can test the exported ONNX model with: + +.. code-block:: bash + + $ ./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt + + diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst index 7265e1cf62..8df61f0d08 100644 --- a/docs/source/recipes/index.rst +++ b/docs/source/recipes/index.rst @@ -2,7 +2,7 @@ Recipes ======= This page contains various recipes in ``icefall``. -Currently, only speech recognition recipes are provided. +Currently, we provide recipes for speech recognition, language model, and speech synthesis. We may add recipes for other tasks as well in the future. @@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future. Non-streaming-ASR/index Streaming-ASR/index RNN-LM/index + TTS/index From 0030d1b766dca94d0fde7a80c4ba685f5f43518f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 29 Nov 2023 19:05:36 +0800 Subject: [PATCH 13/16] add README.md --- egs/ljspeech/TTS/vits/README.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 egs/ljspeech/TTS/vits/README.md diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md new file mode 100644 index 0000000000..45b5445160 --- /dev/null +++ b/egs/ljspeech/TTS/vits/README.md @@ -0,0 +1 @@ +See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. 
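Note (illustrative, not part of the recipe): the tutorial above drives the exported model through ./vits/test_onnx.py. For embedding the ONNX model in your own code, a minimal sketch modelled on test_onnx.py could look like the following; the checkpoint path and the example sentence are placeholders, and the input/output names are the ones declared by export-onnx.py (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha in; audio out):

    import onnxruntime as ort
    import torch
    import torchaudio
    from tokenizer import Tokenizer  # vits/tokenizer.py from this recipe

    tokenizer = Tokenizer("data/tokens.txt")
    session = ort.InferenceSession(
        "vits/exp/vits-epoch-1000.onnx",
        providers=["CPUExecutionProvider"],
    )

    tokens = torch.tensor(tokenizer.texts_to_token_ids(["A short test sentence."]))
    inputs = {
        "tokens": tokens.numpy(),                                              # (1, T_text), int64
        "tokens_lens": torch.tensor([tokens.shape[1]], dtype=torch.int64).numpy(),
        "noise_scale": torch.tensor([0.667], dtype=torch.float32).numpy(),
        "noise_scale_dur": torch.tensor([0.8], dtype=torch.float32).numpy(),
        "alpha": torch.tensor([1.0], dtype=torch.float32).numpy(),
    }
    audio = session.run(["audio"], inputs)[0]                                  # (1, T_wav), float32
    torchaudio.save("sample.wav", torch.from_numpy(audio), sample_rate=22050)  # LJSpeech is 22.05 kHz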
From 70343a81f439a6f75ee2648f3139c80865c5437f Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 29 Nov 2023 19:57:16 +0800 Subject: [PATCH 14/16] fix style --- .flake8 | 2 +- docs/source/recipes/TTS/ljspeech/vits.rst | 6 ++++++ egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py | 8 +++++++- egs/ljspeech/TTS/local/prepare_token_file.py | 2 +- egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py | 8 ++------ egs/ljspeech/TTS/vits/README.md | 2 ++ 6 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.flake8 b/.flake8 index 410cb54822..cf276d0ba4 100644 --- a/.flake8 +++ b/.flake8 @@ -15,7 +15,7 @@ per-file-ignores = egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203 egs/librispeech/ASR/zipformer/*.py: E501, E203 egs/librispeech/ASR/RESULTS.md: E999, - + egs/ljspeech/TTS/vits/*.py: E501, E203 # invalid escape sequence (cause by tex formular), W605 icefall/utils.py: E501, W605 diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 0f0d97a9e8..535d8999f4 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -103,4 +103,10 @@ You can test the exported ONNX model with: --model-filename vits/exp/vits-epoch-1000.onnx \ --tokens data/tokens.txt +Download pretrained models +-------------------------- +If you don't want to train from scratch, you can download the pretrained models +by visiting the following link: + + - ``_ diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py index eacf0df57f..97c9008fc8 100755 --- a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py @@ -29,7 +29,13 @@ from pathlib import Path import torch -from lhotse import CutSet, Spectrogram, SpectrogramConfig, LilcomChunkyWriter, load_manifest +from lhotse import ( + CutSet, + LilcomChunkyWriter, + Spectrogram, + SpectrogramConfig, + load_manifest, +) from lhotse.audio import RecordingSet from lhotse.supervision import SupervisionSet diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py index 007bb299bf..df976804ab 100755 --- a/egs/ljspeech/TTS/local/prepare_token_file.py +++ b/egs/ljspeech/TTS/local/prepare_token_file.py @@ -73,7 +73,7 @@ def get_token2id(manifest_file: Path) -> Dict[str, int]: extra_tokens = [ "", # 0 for blank "", # 1 for sos and eos symbols. 
- "" # 2 for OOV + "", # 2 for OOV ] all_tokens = set() diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index f7fa7e2d26..fcd0137a08 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -34,9 +34,7 @@ def prepare_tokens_ljspeech(): suffix = "jsonl.gz" partition = "all" - cut_set = load_manifest( - output_dir / f"{prefix}_cuts_{partition}.{suffix}" - ) + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") g2p = g2p_en.G2p() new_cuts = [] @@ -51,9 +49,7 @@ def prepare_tokens_ljspeech(): new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file( - output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" - ) + new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") if __name__ == "__main__": diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md index 45b5445160..1141326b96 100644 --- a/egs/ljspeech/TTS/vits/README.md +++ b/egs/ljspeech/TTS/vits/README.md @@ -1 +1,3 @@ See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials. + +Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29. From 0f721040489a60379b3d4f3a250a31974a008f2a Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 29 Nov 2023 20:02:58 +0800 Subject: [PATCH 15/16] modify pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c40143fb93..435256416c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,4 +14,5 @@ exclude = ''' | icefall\/diagnostics\.py | icefall\/profiler\.py | egs\/librispeech\/ASR\/zipformer + | egs\/ljspeech\/TTS\/vits ''' From 24587e3a723f20f0653fc46840b00beb5b3788d2 Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Wed, 29 Nov 2023 20:11:48 +0800 Subject: [PATCH 16/16] minor fix --- docs/source/recipes/TTS/ljspeech/vits.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst index 535d8999f4..385fd3c705 100644 --- a/docs/source/recipes/TTS/ljspeech/vits.rst +++ b/docs/source/recipes/TTS/ljspeech/vits.rst @@ -103,6 +103,7 @@ You can test the exported ONNX model with: --model-filename vits/exp/vits-epoch-1000.onnx \ --tokens data/tokens.txt + Download pretrained models --------------------------
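Note (illustrative, not part of the patch series): several of the changes above rely on the token conventions fixed by local/prepare_token_file.py, namely id 0 for the blank, id 1 for sos/eos and id 2 for OOV, together with Tokenizer.tokens_to_token_ids, which can intersperse the blank id between phoneme ids. A small sketch of that behaviour; the phoneme symbols and their ids are made up, only the roles of ids 0/1/2 come from the recipe:

    # Hypothetical subset of data/tokens.txt; real ids are assigned by prepare_token_file.py.
    token2id = {"HH": 7, "AH0": 12, "L": 25, "OW1": 33}
    blank_id, oov_id = 0, 2

    def intersperse(seq, item):
        # Mirrors the intersperse helper used by the Tokenizer:
        # [a, b] -> [item, a, item, b, item]
        out = [item] * (2 * len(seq) + 1)
        out[1::2] = seq
        return out

    tokens = ["HH", "AH0", "L", "OW1", "ZZZ"]        # "ZZZ" is not in the vocabulary
    ids = [token2id.get(t, oov_id) for t in tokens]  # -> [7, 12, 25, 33, 2]
    ids = intersperse(ids, blank_id)                 # -> [0, 7, 0, 12, 0, 25, 0, 33, 0, 2, 0]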