diff --git a/.flake8 b/.flake8
index 410cb54822..cf276d0ba4 100644
--- a/.flake8
+++ b/.flake8
@@ -15,7 +15,7 @@ per-file-ignores =
egs/librispeech/ASR/zipformer_mmi/*.py: E501, E203
egs/librispeech/ASR/zipformer/*.py: E501, E203
egs/librispeech/ASR/RESULTS.md: E999,
-
+ egs/ljspeech/TTS/vits/*.py: E501, E203
# invalid escape sequence (cause by tex formular), W605
icefall/utils.py: E501, W605
diff --git a/docs/source/recipes/TTS/index.rst b/docs/source/recipes/TTS/index.rst
new file mode 100644
index 0000000000..aa891c072e
--- /dev/null
+++ b/docs/source/recipes/TTS/index.rst
@@ -0,0 +1,7 @@
+TTS
+======
+
+.. toctree::
+ :maxdepth: 2
+
+ ljspeech/vits
diff --git a/docs/source/recipes/TTS/ljspeech/vits.rst b/docs/source/recipes/TTS/ljspeech/vits.rst
new file mode 100644
index 0000000000..385fd3c705
--- /dev/null
+++ b/docs/source/recipes/TTS/ljspeech/vits.rst
@@ -0,0 +1,113 @@
+VITS
+===============
+
+This tutorial shows you how to train a VITS model
+with the `LJSpeech `_ dataset.
+
+.. note::
+
+ The VITS paper: `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech `_
+
+
+Data preparation
+----------------
+
+.. code-block:: bash
+
+ $ cd egs/ljspeech/TTS
+ $ ./prepare.sh
+
+To run stage 1 to stage 5, use
+
+.. code-block:: bash
+
+ $ ./prepare.sh --stage 1 --stop_stage 5
+
+
+Build Monotonic Alignment Search
+--------------------------------
+
+.. code-block:: bash
+
+ $ cd vits/monotonic_align
+ $ python setup.py build_ext --inplace
+ $ cd ../../
+
+
+Training
+--------
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ $ ./vits/train.py \
+ --world-size 4 \
+ --num-epochs 1000 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ You can adjust the hyper-parameters to control the size of the VITS model and
+ the training configurations. For more details, please run ``./vits/train.py --help``.
+
+.. note::
+
+ The training can take a long time (usually a couple of days).
+
+Training logs, checkpoints, and TensorBoard logs are saved in ``vits/exp``.
+
+
+Inference
+---------
+
+Inference uses checkpoints saved during training, so you have to run the training
+stage first. It will save the ground-truth and generated wavs to the directory
+``vits/exp/infer/epoch-*/wav``, e.g., ``vits/exp/infer/epoch-1000/wav``.
+
+.. code-block:: bash
+
+ $ export CUDA_VISIBLE_DEVICES="0"
+ $ ./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt \
+ --max-duration 500
+
+.. note::
+
+ For more details, please run ``./vits/infer.py --help``.
+
+
+Export models
+-------------
+
+Currently, we only support exporting the model to ONNX. It will generate two files in the given ``exp-dir``:
+``vits-epoch-*.onnx`` and ``vits-epoch-*.int8.onnx``.
+
+.. code-block:: bash
+
+ $ ./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+You can test the exported ONNX model with:
+
+.. code-block:: bash
+
+ $ ./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
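+
+If you prefer to drive the exported model from your own Python code, the sketch below
+shows the general idea. It assumes the exported ONNX model exposes inputs named
+``tokens`` and ``tokens_lens`` and an output named ``audio`` (the names used in
+``export-onnx.py``); the token IDs shown are hypothetical placeholders and, in practice,
+must come from ``data/tokens.txt`` via the same text front-end used in data preparation.
+Treat it as a starting point rather than a replacement for ``test_onnx.py``.
+
+.. code-block:: python
+
+ import numpy as np
+ import onnxruntime as ort
+ import soundfile as sf  # only used here to save the generated audio
+
+ # Load the exported generator.
+ session = ort.InferenceSession("vits/exp/vits-epoch-1000.onnx")
+
+ # Hypothetical token IDs; obtain real ones from data/tokens.txt
+ # using the same g2p front-end as prepare_tokens_ljspeech.py.
+ token_ids = np.array([[31, 12, 45, 7, 3]], dtype=np.int64)  # (1, T_text)
+ token_lens = np.array([token_ids.shape[1]], dtype=np.int64)  # (1,)
+
+ (audio,) = session.run(
+     ["audio"],
+     {"tokens": token_ids, "tokens_lens": token_lens},
+ )
+
+ # This recipe uses 22.05 kHz audio.
+ sf.write("generated.wav", audio[0], 22050)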
+
+
+Download pretrained models
+--------------------------
+
+If you don't want to train from scratch, you can download the pretrained models
+by visiting the following link:
+
+ - ``_
diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index 7265e1cf62..8df61f0d08 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -2,7 +2,7 @@ Recipes
=======
This page contains various recipes in ``icefall``.
-Currently, only speech recognition recipes are provided.
+Currently, we provide recipes for speech recognition, language modeling, and speech synthesis.
We may add recipes for other tasks as well in the future.
@@ -16,3 +16,4 @@ We may add recipes for other tasks as well in the future.
Non-streaming-ASR/index
Streaming-ASR/index
RNN-LM/index
+ TTS/index
diff --git a/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
new file mode 100755
index 0000000000..97c9008fc8
--- /dev/null
+++ b/egs/ljspeech/TTS/local/compute_spectrogram_ljspeech.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the LJSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+ CutSet,
+ LilcomChunkyWriter,
+ Spectrogram,
+ SpectrogramConfig,
+ load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_ljspeech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/spectrogram")
+ num_jobs = min(4, os.cpu_count())
+
+ sampling_rate = 22050
+ frame_length = 1024 / sampling_rate # (in second)
+ frame_shift = 256 / sampling_rate # (in second)
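+ # i.e., roughly a 46.4 ms window with an 11.6 ms hop at 22.05 kHz.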
+ use_fft_mag = True
+
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ recordings = load_manifest(
+ src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet
+ )
+ supervisions = load_manifest(
+ src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
+ )
+
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ use_fft_mag=use_fft_mag,
+ )
+ extractor = Spectrogram(config)
+
+ with get_executor() as ex: # Initialize the executor only once.
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{cuts_filename} already exists - skipping.")
+ return
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=recordings, supervisions=supervisions
+ )
+
+ cut_set = cut_set.compute_and_store_features(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ # when an executor is specified, make more partitions
+ num_jobs=num_jobs if ex is None else 80,
+ executor=ex,
+ storage_type=LilcomChunkyWriter,
+ )
+ cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_spectrogram_ljspeech()
diff --git a/egs/ljspeech/TTS/local/display_manifest_statistics.py b/egs/ljspeech/TTS/local/display_manifest_statistics.py
new file mode 100755
index 0000000000..93f0044f0e
--- /dev/null
+++ b/egs/ljspeech/TTS/local/display_manifest_statistics.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+for filtering out short and long utterances during training.
+
+See the function `remove_short_and_long_utt()` in vits/train.py
+for usage.
+"""
+
+
+from lhotse import load_manifest_lazy
+
+
+def main():
+ path = "./data/spectrogram/ljspeech_cuts_all.jsonl.gz"
+ cuts = load_manifest_lazy(path)
+ cuts.describe()
+
+
+if __name__ == "__main__":
+ main()
+
+"""
+Cut statistics:
+ ╒═══════════════════════════╤══════════╕
+ │ Cuts count: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Total duration (hh:mm:ss) │ 23:55:18 │
+ ├───────────────────────────┼──────────┤
+ │ mean │ 6.6 │
+ ├───────────────────────────┼──────────┤
+ │ std │ 2.2 │
+ ├───────────────────────────┼──────────┤
+ │ min │ 1.1 │
+ ├───────────────────────────┼──────────┤
+ │ 25% │ 5.0 │
+ ├───────────────────────────┼──────────┤
+ │ 50% │ 6.8 │
+ ├───────────────────────────┼──────────┤
+ │ 75% │ 8.4 │
+ ├───────────────────────────┼──────────┤
+ │ 99% │ 10.0 │
+ ├───────────────────────────┼──────────┤
+ │ 99.5% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ 99.9% │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ max │ 10.1 │
+ ├───────────────────────────┼──────────┤
+ │ Recordings available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Features available: │ 13100 │
+ ├───────────────────────────┼──────────┤
+ │ Supervisions available: │ 13100 │
+ ╘═══════════════════════════╧══════════╛
+"""
diff --git a/egs/ljspeech/TTS/local/prepare_token_file.py b/egs/ljspeech/TTS/local/prepare_token_file.py
new file mode 100755
index 0000000000..df976804ab
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_token_file.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and generates a file that maps tokens to IDs.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict
+
+from lhotse import load_manifest
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--manifest-file",
+ type=Path,
+ default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
+ help="Path to the manifest file",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=Path,
+ default=Path("data/tokens.txt"),
+ help="Path to the tokens",
+ )
+
+ return parser.parse_args()
+
+
+def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
+ """Write a symbol to ID mapping to a file.
+
+ Note:
+ No need to implement `read_mapping` as it can be done
+ through :func:`k2.SymbolTable.from_file`.
+
+ Args:
+ filename:
+ Filename to save the mapping.
+ sym2id:
+ A dict mapping symbols to IDs.
+ Returns:
+ Return None.
+ """
+ with open(filename, "w", encoding="utf-8") as f:
+ for sym, i in sym2id.items():
+ f.write(f"{sym} {i}\n")
+
+
+def get_token2id(manifest_file: Path) -> Dict[str, int]:
+ """Return a dict that maps token to IDs."""
+ extra_tokens = [
+ "", # 0 for blank
+ "", # 1 for sos and eos symbols.
+ "", # 2 for OOV
+ ]
+ all_tokens = set()
+
+ cut_set = load_manifest(manifest_file)
+
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ for t in cut.tokens:
+ all_tokens.add(t)
+
+ all_tokens = extra_tokens + list(all_tokens)
+
+ token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
+ return token2id
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ args = get_args()
+ manifest_file = Path(args.manifest_file)
+ out_file = Path(args.tokens)
+
+ token2id = get_token2id(manifest_file)
+ write_mapping(out_file, token2id)
diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
new file mode 100755
index 0000000000..fcd0137a08
--- /dev/null
+++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file reads the texts in the given manifest and saves new cuts with phoneme tokens.
+"""
+
+import logging
+from pathlib import Path
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from lhotse import CutSet, load_manifest
+
+
+def prepare_tokens_ljspeech():
+ output_dir = Path("data/spectrogram")
+ prefix = "ljspeech"
+ suffix = "jsonl.gz"
+ partition = "all"
+
+ cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
+ g2p = g2p_en.G2p()
+
+ new_cuts = []
+ for cut in cut_set:
+ # Each cut only contains one supervision
+ assert len(cut.supervisions) == 1, len(cut.supervisions)
+ text = cut.supervisions[0].normalized_text
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ cut.tokens = g2p(text)
+ new_cuts.append(cut)
+
+ new_cut_set = CutSet.from_cuts(new_cuts)
+ new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ prepare_tokens_ljspeech()
diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py
new file mode 100755
index 0000000000..68159ae036
--- /dev/null
+++ b/egs/ljspeech/TTS/local/validate_manifest.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+
+We will add more checks later if needed.
+
+Usage example:
+
+ python3 ./local/validate_manifest.py \
+ ./data/spectrogram/ljspeech_cuts_all.jsonl.gz
+
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.dataset.speech_synthesis import validate_for_tts
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "manifest",
+ type=Path,
+ help="Path to the manifest file",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ args = get_args()
+
+ manifest = args.manifest
+ logging.info(f"Validating {manifest}")
+
+ assert manifest.is_file(), f"{manifest} does not exist"
+ cut_set = load_manifest_lazy(manifest)
+ assert isinstance(cut_set, CutSet), type(cut_set)
+
+ validate_for_tts(cut_set)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ main()
diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh
new file mode 100755
index 0000000000..8ee40896e1
--- /dev/null
+++ b/egs/ljspeech/TTS/prepare.sh
@@ -0,0 +1,117 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+nj=1
+stage=-1
+stop_stage=100
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # The directory $dl_dir/LJSpeech-1.1 will contain:
+ # - wavs, which contains the audio files
+ # - metadata.csv, which provides the transcript text for each audio clip
+
+ # If you have pre-downloaded it to /path/to/LJSpeech-1.1, you can create a symlink
+ #
+ # ln -sfv /path/to/LJSpeech-1.1 $dl_dir/LJSpeech-1.1
+ #
+ if [ ! -d $dl_dir/LJSpeech-1.1 ]; then
+ lhotse download ljspeech $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LJSpeech manifest"
+ # We assume that you have downloaded the LJSpeech corpus
+ # to $dl_dir/LJSpeech-1.1
+ mkdir -p data/manifests
+ if [ ! -e data/manifests/.ljspeech.done ]; then
+ lhotse prepare ljspeech $dl_dir/LJSpeech-1.1 data/manifests
+ touch data/manifests/.ljspeech.done
+ fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Compute spectrogram for LJSpeech"
+ mkdir -p data/spectrogram
+ if [ ! -e data/spectrogram/.ljspeech.done ]; then
+ ./local/compute_spectrogram_ljspeech.py
+ touch data/spectrogram/.ljspeech.done
+ fi
+
+ if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then
+ log "Validating data/spectrogram for LJSpeech"
+ python3 ./local/validate_manifest.py \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech-validated.done
+ fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare phoneme tokens for LJSpeech"
+ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
+ ./local/prepare_tokens_ljspeech.py
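+ # Overwrite the original cuts with the version that carries phoneme tokens.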
+ mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz
+ touch data/spectrogram/.ljspeech_with_token.done
+ fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Split the LJSpeech cuts into train, valid and test sets"
+ if [ ! -e data/spectrogram/.ljspeech_split.done ]; then
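+ # Hold out the last 600 cuts: the first 100 of them become the valid set and
+ # the remaining 500 become the test set; all other cuts form the train set.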
+ lhotse subset --last 600 \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+ lhotse subset --first 100 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_valid.jsonl.gz
+ lhotse subset --last 500 \
+ data/spectrogram/ljspeech_cuts_validtest.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_test.jsonl.gz
+
+ rm data/spectrogram/ljspeech_cuts_validtest.jsonl.gz
+
+ n=$(( $(gunzip -c data/spectrogram/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 ))
+ lhotse subset --first $n \
+ data/spectrogram/ljspeech_cuts_all.jsonl.gz \
+ data/spectrogram/ljspeech_cuts_train.jsonl.gz
+ touch data/spectrogram/.ljspeech_split.done
+ fi
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+ log "Stage 5: Generate token file"
+ # We assume you have installed g2p_en and espnet_tts_frontend.
+ # If not, please install them with:
+ # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+ # - espnet_tts_frontend: `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
+ if [ ! -e data/tokens.txt ]; then
+ ./local/prepare_token_file.py \
+ --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
+ --tokens data/tokens.txt
+ fi
+fi
+
+
diff --git a/egs/ljspeech/TTS/shared/parse_options.sh b/egs/ljspeech/TTS/shared/parse_options.sh
new file mode 120000
index 0000000000..e4665e7de2
--- /dev/null
+++ b/egs/ljspeech/TTS/shared/parse_options.sh
@@ -0,0 +1 @@
+../../../librispeech/ASR/shared/parse_options.sh
\ No newline at end of file
diff --git a/egs/ljspeech/TTS/vits/README.md b/egs/ljspeech/TTS/vits/README.md
new file mode 100644
index 0000000000..1141326b96
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/README.md
@@ -0,0 +1,3 @@
+See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
+
+Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.
diff --git a/egs/ljspeech/TTS/vits/duration_predictor.py b/egs/ljspeech/TTS/vits/duration_predictor.py
new file mode 100644
index 0000000000..c29a28479a
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/duration_predictor.py
@@ -0,0 +1,194 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/duration_predictor.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Stochastic duration predictor modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from flow import (
+ ConvFlow,
+ DilatedDepthSeparableConv,
+ ElementwiseAffineFlow,
+ FlipFlow,
+ LogFlow,
+)
+
+
+class StochasticDurationPredictor(torch.nn.Module):
+ """Stochastic duration predictor module.
+
+ This is a module of stochastic duration predictor described in `Conditional
+ Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2106.06103
+
+ """
+
+ def __init__(
+ self,
+ channels: int = 192,
+ kernel_size: int = 3,
+ dropout_rate: float = 0.5,
+ flows: int = 4,
+ dds_conv_layers: int = 3,
+ global_channels: int = -1,
+ ):
+ """Initialize StochasticDurationPredictor module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ dropout_rate (float): Dropout rate.
+ flows (int): Number of flows.
+ dds_conv_layers (int): Number of conv layers in DDS conv.
+ global_channels (int): Number of global conditioning channels.
+
+ """
+ super().__init__()
+
+ self.pre = torch.nn.Conv1d(channels, channels, 1)
+ self.dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.proj = torch.nn.Conv1d(channels, channels, 1)
+
+ self.log_flow = LogFlow()
+ self.flows = torch.nn.ModuleList()
+ self.flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ self.post_pre = torch.nn.Conv1d(1, channels, 1)
+ self.post_dds = DilatedDepthSeparableConv(
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ dropout_rate=dropout_rate,
+ )
+ self.post_proj = torch.nn.Conv1d(channels, channels, 1)
+ self.post_flows = torch.nn.ModuleList()
+ self.post_flows += [ElementwiseAffineFlow(2)]
+ for i in range(flows):
+ self.post_flows += [
+ ConvFlow(
+ 2,
+ channels,
+ kernel_size,
+ layers=dds_conv_layers,
+ )
+ ]
+ self.post_flows += [FlipFlow()]
+
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ w: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ noise_scale: float = 1.0,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T_text).
+ x_mask (Tensor): Mask tensor (B, 1, T_text).
+ w (Optional[Tensor]): Duration tensor (B, 1, T_text).
+ g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
+ inverse (bool): Whether to inverse the flow.
+ noise_scale (float): Noise scale value.
+
+ Returns:
+ Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
+ If inverse, log-duration tensor (B, 1, T_text).
+
+ """
+ x = x.detach() # stop gradient
+ x = self.pre(x)
+ if g is not None:
+ x = x + self.global_conv(g.detach()) # stop gradient
+ x = self.dds(x, x_mask)
+ x = self.proj(x) * x_mask
+
+ if not inverse:
+ assert w is not None, "w must be provided."
+ h_w = self.post_pre(w)
+ h_w = self.post_dds(h_w, x_mask)
+ h_w = self.post_proj(h_w) * x_mask
+ e_q = (
+ torch.randn(
+ w.size(0),
+ 2,
+ w.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * x_mask
+ )
+ z_q = e_q
+ logdet_tot_q = 0.0
+ for flow in self.post_flows:
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
+ logdet_tot_q += logdet_q
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
+ u = torch.sigmoid(z_u) * x_mask
+ z0 = (w - u) * x_mask
+ logdet_tot_q += torch.sum(
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
+ )
+ logq = (
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
+ - logdet_tot_q
+ )
+
+ logdet_tot = 0
+ z0, logdet = self.log_flow(z0, x_mask)
+ logdet_tot += logdet
+ z = torch.cat([z0, z1], 1)
+ for flow in self.flows:
+ z, logdet = flow(z, x_mask, g=x, inverse=inverse)
+ logdet_tot = logdet_tot + logdet
+ nll = (
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
+ - logdet_tot
+ )
+ return nll + logq # (B,)
+ else:
+ flows = list(reversed(self.flows))
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
+ z = (
+ torch.randn(
+ x.size(0),
+ 2,
+ x.size(2),
+ ).to(device=x.device, dtype=x.dtype)
+ * noise_scale
+ )
+ for flow in flows:
+ z = flow(z, x_mask, g=x, inverse=inverse)
+ z0, z1 = z.split(1, 1)
+ logw = z0
+ return logw
diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py
new file mode 100755
index 0000000000..154de4bf42
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/export-onnx.py
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script exports a VITS model from PyTorch to ONNX.
+
+Export the model to ONNX:
+./vits/export-onnx.py \
+ --epoch 1000 \
+ --exp-dir vits/exp \
+ --tokens data/tokens.txt
+
+It will generate two files inside vits/exp:
+ - vits-epoch-1000.onnx
+ - vits-epoch-1000.int8.onnx (quantized model)
+
+See ./test_onnx.py for how to use the exported ONNX models.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Dict, Tuple
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxruntime.quantization import QuantType, quantize_dynamic
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def add_meta_data(filename: str, meta_data: Dict[str, str]):
+ """Add meta data to an ONNX model. It is changed in-place.
+
+ Args:
+ filename:
+ Filename of the ONNX model to be changed.
+ meta_data:
+ Key-value pairs.
+ """
+ model = onnx.load(filename)
+ for key, value in meta_data.items():
+ meta = model.metadata_props.add()
+ meta.key = key
+ meta.value = value
+
+ onnx.save(model, filename)
+
+
+class OnnxModel(nn.Module):
+ """A wrapper for VITS generator."""
+
+ def __init__(self, model: nn.Module):
+ """
+ Args:
+ model:
+ A VITS generator.
+ """
+ super().__init__()
+ self.model = model
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ tokens_lens: torch.Tensor,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Please see the help information of VITS.inference_batch
+
+ Args:
+ tokens:
+ Input text token indexes (1, T_text)
+ tokens_lens:
+ Number of tokens of shape (1,)
+ noise_scale (float):
+ Noise scale parameter for flow.
+ noise_scale_dur (float):
+ Noise scale parameter for duration predictor.
+ alpha (float):
+ Alpha parameter to control the speed of generated speech.
+
+ Returns:
+ Return a tuple containing:
+ - audio, generated waveform tensor, (B, T_wav)
+ """
+ audio, _, _ = self.model.inference(
+ text=tokens,
+ text_lengths=tokens_lens,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ )
+ return audio
+
+
+def export_model_onnx(
+ model: nn.Module,
+ model_filename: str,
+ opset_version: int = 11,
+) -> None:
+ """Export the given generator model to ONNX format.
+ The exported model has two main inputs:
+
+ - tokens, a tensor of shape (1, T_text); dtype is torch.int64
+ - tokens_lens, a tensor of shape (1,); dtype is torch.int64
+
+ and it has one output:
+
+ - audio, a tensor of shape (1, T'); dtype is torch.float32
+
+ Args:
+ model:
+ The VITS generator.
+ model_filename:
+ The filename to save the exported ONNX model.
+ opset_version:
+ The opset version to use.
+ """
+ tokens = torch.randint(low=0, high=79, size=(1, 13), dtype=torch.int64)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64)
+ noise_scale = torch.tensor([1], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([1], dtype=torch.float32)
+ alpha = torch.tensor([1], dtype=torch.float32)
+
+ torch.onnx.export(
+ model,
+ (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha),
+ model_filename,
+ verbose=False,
+ opset_version=opset_version,
+ input_names=["tokens", "tokens_lens", "noise_scale", "noise_scale_dur", "alpha"],
+ output_names=["audio"],
+ dynamic_axes={
+ "tokens": {0: "N", 1: "T"},
+ "tokens_lens": {0: "N"},
+ "audio": {0: "N", 1: "T"},
+ },
+ )
+
+ meta_data = {
+ "model_type": "VITS",
+ "version": "1",
+ "model_author": "k2-fsa",
+ "comment": "VITS generator",
+ }
+ logging.info(f"meta_data: {meta_data}")
+
+ add_meta_data(filename=model_filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model = model.generator
+ model.to("cpu")
+ model.eval()
+
+ model = OnnxModel(model=model)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"generator parameters: {num_param}")
+
+ suffix = f"epoch-{params.epoch}"
+
+ opset_version = 13
+
+ logging.info("Exporting encoder")
+ model_filename = params.exp_dir / f"vits-{suffix}.onnx"
+ export_model_onnx(
+ model,
+ model_filename,
+ opset_version=opset_version,
+ )
+ logging.info(f"Exported generator to {model_filename}")
+
+ # Generate int8 quantization models
+ # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
+ logging.info("Generate int8 quantization models")
+
+ model_filename_int8 = params.exp_dir / f"vits-{suffix}.int8.onnx"
+ quantize_dynamic(
+ model_input=model_filename,
+ model_output=model_filename_int8,
+ weight_type=QuantType.QUInt8,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/flow.py b/egs/ljspeech/TTS/vits/flow.py
new file mode 100644
index 0000000000..206bd5e3e5
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/flow.py
@@ -0,0 +1,312 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/flow.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Basic Flow modules used in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+
+from transform import piecewise_rational_quadratic_transform
+
+
+class FlipFlow(torch.nn.Module):
+ """Flip flow module."""
+
+ def forward(
+ self, x: torch.Tensor, *args, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Flipped tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ x = torch.flip(x, [1])
+ if not inverse:
+ logdet = x.new_zeros(x.size(0))
+ return x, logdet
+ else:
+ return x
+
+
+class LogFlow(torch.nn.Module):
+ """Log flow module."""
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ inverse: bool = False,
+ eps: float = 1e-5,
+ **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ inverse (bool): Whether to inverse the flow.
+ eps (float): Epsilon for log.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = torch.log(torch.clamp_min(x, eps)) * x_mask
+ logdet = torch.sum(-y, [1, 2])
+ return y, logdet
+ else:
+ x = torch.exp(x) * x_mask
+ return x
+
+
+class ElementwiseAffineFlow(torch.nn.Module):
+ """Elementwise affine flow module."""
+
+ def __init__(self, channels: int):
+ """Initialize ElementwiseAffineFlow module.
+
+ Args:
+ channels (int): Number of channels.
+
+ """
+ super().__init__()
+ self.channels = channels
+ self.register_parameter("m", torch.nn.Parameter(torch.zeros(channels, 1)))
+ self.register_parameter("logs", torch.nn.Parameter(torch.zeros(channels, 1)))
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, inverse: bool = False, **kwargs
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ if not inverse:
+ y = self.m + torch.exp(self.logs) * x
+ y = y * x_mask
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
+ return y, logdet
+ else:
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
+ return x
+
+
+class Transpose(torch.nn.Module):
+ """Transpose module for torch.nn.Sequential()."""
+
+ def __init__(self, dim1: int, dim2: int):
+ """Initialize Transpose module."""
+ super().__init__()
+ self.dim1 = dim1
+ self.dim2 = dim2
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Transpose."""
+ return x.transpose(self.dim1, self.dim2)
+
+
+class DilatedDepthSeparableConv(torch.nn.Module):
+ """Dilated depth-separable conv module."""
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ layers: int,
+ dropout_rate: float = 0.0,
+ eps: float = 1e-5,
+ ):
+ """Initialize DilatedDepthSeparableConv module.
+
+ Args:
+ channels (int): Number of channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ dropout_rate (float): Dropout rate.
+ eps (float): Epsilon for layer norm.
+
+ """
+ super().__init__()
+
+ self.convs = torch.nn.ModuleList()
+ for i in range(layers):
+ dilation = kernel_size**i
+ padding = (kernel_size * dilation - dilation) // 2
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ groups=channels,
+ dilation=dilation,
+ padding=padding,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ 1,
+ ),
+ Transpose(1, 2),
+ torch.nn.LayerNorm(
+ channels,
+ eps=eps,
+ elementwise_affine=True,
+ ),
+ Transpose(1, 2),
+ torch.nn.GELU(),
+ torch.nn.Dropout(dropout_rate),
+ )
+ ]
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ if g is not None:
+ x = x + g
+ for f in self.convs:
+ y = f(x * x_mask)
+ x = x + y
+ return x * x_mask
+
+
+class ConvFlow(torch.nn.Module):
+ """Convolutional flow module."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_channels: int,
+ kernel_size: int,
+ layers: int,
+ bins: int = 10,
+ tail_bound: float = 5.0,
+ ):
+ """Initialize ConvFlow module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size.
+ layers (int): Number of layers.
+ bins (int): Number of bins.
+ tail_bound (float): Tail bound value.
+
+ """
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.hidden_channels = hidden_channels
+ self.bins = bins
+ self.tail_bound = tail_bound
+
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.dds_conv = DilatedDepthSeparableConv(
+ hidden_channels,
+ kernel_size,
+ layers,
+ dropout_rate=0.0,
+ )
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * (bins * 3 - 1),
+ 1,
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ xa, xb = x.split(x.size(1) // 2, 1)
+ h = self.input_conv(xa)
+ h = self.dds_conv(h, x_mask, g=g)
+ h = self.proj(h) * x_mask # (B, half_channels * (bins * 3 - 1), T)
+
+ b, c, t = xa.shape
+ # (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)
+
+ # TODO(kan-bayashi): Understand this calculation
+ denom = math.sqrt(self.hidden_channels)
+ unnorm_widths = h[..., : self.bins] / denom
+ unnorm_heights = h[..., self.bins : 2 * self.bins] / denom
+ unnorm_derivatives = h[..., 2 * self.bins :]
+ xb, logdet_abs = piecewise_rational_quadratic_transform(
+ xb,
+ unnorm_widths,
+ unnorm_heights,
+ unnorm_derivatives,
+ inverse=inverse,
+ tails="linear",
+ tail_bound=self.tail_bound,
+ )
+ x = torch.cat([xa, xb], 1) * x_mask
+ logdet = torch.sum(logdet_abs * x_mask, [1, 2])
+ if not inverse:
+ return x, logdet
+ else:
+ return x
diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py
new file mode 100644
index 0000000000..efb0e254cf
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/generator.py
@@ -0,0 +1,531 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/generator.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Generator module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+import math
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from icefall.utils import make_pad_mask
+
+from duration_predictor import StochasticDurationPredictor
+from hifigan import HiFiGANGenerator
+from posterior_encoder import PosteriorEncoder
+from residual_coupling import ResidualAffineCouplingBlock
+from text_encoder import TextEncoder
+from utils import get_random_segments
+
+
+class VITSGenerator(torch.nn.Module):
+ """Generator module in VITS, `Conditional Variational Autoencoder
+ with Adversarial Learning for End-to-End Text-to-Speech`.
+ """
+
+ def __init__(
+ self,
+ vocabs: int,
+ aux_channels: int = 513,
+ hidden_channels: int = 192,
+ spks: Optional[int] = None,
+ langs: Optional[int] = None,
+ spk_embed_dim: Optional[int] = None,
+ global_channels: int = -1,
+ segment_size: int = 32,
+ text_encoder_attention_heads: int = 2,
+ text_encoder_ffn_expand: int = 4,
+ text_encoder_cnn_module_kernel: int = 5,
+ text_encoder_blocks: int = 6,
+ text_encoder_dropout_rate: float = 0.1,
+ decoder_kernel_size: int = 7,
+ decoder_channels: int = 512,
+ decoder_upsample_scales: List[int] = [8, 8, 2, 2],
+ decoder_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ decoder_resblock_kernel_sizes: List[int] = [3, 7, 11],
+ decoder_resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_weight_norm_in_decoder: bool = True,
+ posterior_encoder_kernel_size: int = 5,
+ posterior_encoder_layers: int = 16,
+ posterior_encoder_stacks: int = 1,
+ posterior_encoder_base_dilation: int = 1,
+ posterior_encoder_dropout_rate: float = 0.0,
+ use_weight_norm_in_posterior_encoder: bool = True,
+ flow_flows: int = 4,
+ flow_kernel_size: int = 5,
+ flow_base_dilation: int = 1,
+ flow_layers: int = 4,
+ flow_dropout_rate: float = 0.0,
+ use_weight_norm_in_flow: bool = True,
+ use_only_mean_in_flow: bool = True,
+ stochastic_duration_predictor_kernel_size: int = 3,
+ stochastic_duration_predictor_dropout_rate: float = 0.5,
+ stochastic_duration_predictor_flows: int = 4,
+ stochastic_duration_predictor_dds_conv_layers: int = 3,
+ ):
+ """Initialize VITS generator module.
+
+ Args:
+ vocabs (int): Input vocabulary size.
+ aux_channels (int): Number of acoustic feature channels.
+ hidden_channels (int): Number of hidden channels.
+ spks (Optional[int]): Number of speakers. If set to > 1, assume that the
+ sids will be provided as the input and use sid embedding layer.
+ langs (Optional[int]): Number of languages. If set to > 1, assume that the
+ lids will be provided as the input and use lid embedding layer.
+ spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
+ assume that spembs will be provided as the input.
+ global_channels (int): Number of global conditioning channels.
+ segment_size (int): Segment size for decoder.
+ text_encoder_attention_heads (int): Number of heads in conformer block
+ of text encoder.
+ text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
+ of text encoder.
+ text_encoder_cnn_module_kernel (int): Convolution kernel size in text encoder.
+ text_encoder_blocks (int): Number of conformer blocks in text encoder.
+ text_encoder_dropout_rate (float): Dropout rate in conformer block of
+ text encoder.
+ decoder_kernel_size (int): Decoder kernel size.
+ decoder_channels (int): Number of decoder initial channels.
+ decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
+ decoder_upsample_kernel_sizes (List[int]): List of kernel size for
+ upsampling layers in decoder.
+ decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
+ in decoder.
+ decoder_resblock_dilations (List[List[int]]): List of list of dilations for
+ resblocks in decoder.
+ use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
+ decoder.
+ posterior_encoder_kernel_size (int): Posterior encoder kernel size.
+ posterior_encoder_layers (int): Number of layers of posterior encoder.
+ posterior_encoder_stacks (int): Number of stacks of posterior encoder.
+ posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
+ posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
+ use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
+ normalization in posterior encoder.
+ flow_flows (int): Number of flows in flow.
+ flow_kernel_size (int): Kernel size in flow.
+ flow_base_dilation (int): Base dilation in flow.
+ flow_layers (int): Number of layers in flow.
+ flow_dropout_rate (float): Dropout rate in flow
+ use_weight_norm_in_flow (bool): Whether to apply weight normalization in
+ flow.
+ use_only_mean_in_flow (bool): Whether to use only mean in flow.
+ stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dropout_rate (float): Dropout rate in
+ stochastic duration predictor.
+ stochastic_duration_predictor_flows (int): Number of flows in stochastic
+ duration predictor.
+ stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
+ layers in stochastic duration predictor.
+
+ """
+ super().__init__()
+ self.segment_size = segment_size
+ self.text_encoder = TextEncoder(
+ vocabs=vocabs,
+ d_model=hidden_channels,
+ num_heads=text_encoder_attention_heads,
+ dim_feedforward=hidden_channels * text_encoder_ffn_expand,
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
+ num_layers=text_encoder_blocks,
+ dropout=text_encoder_dropout_rate,
+ )
+ self.decoder = HiFiGANGenerator(
+ in_channels=hidden_channels,
+ out_channels=1,
+ channels=decoder_channels,
+ global_channels=global_channels,
+ kernel_size=decoder_kernel_size,
+ upsample_scales=decoder_upsample_scales,
+ upsample_kernel_sizes=decoder_upsample_kernel_sizes,
+ resblock_kernel_sizes=decoder_resblock_kernel_sizes,
+ resblock_dilations=decoder_resblock_dilations,
+ use_weight_norm=use_weight_norm_in_decoder,
+ )
+ self.posterior_encoder = PosteriorEncoder(
+ in_channels=aux_channels,
+ out_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=posterior_encoder_kernel_size,
+ layers=posterior_encoder_layers,
+ stacks=posterior_encoder_stacks,
+ base_dilation=posterior_encoder_base_dilation,
+ global_channels=global_channels,
+ dropout_rate=posterior_encoder_dropout_rate,
+ use_weight_norm=use_weight_norm_in_posterior_encoder,
+ )
+ self.flow = ResidualAffineCouplingBlock(
+ in_channels=hidden_channels,
+ hidden_channels=hidden_channels,
+ flows=flow_flows,
+ kernel_size=flow_kernel_size,
+ base_dilation=flow_base_dilation,
+ layers=flow_layers,
+ global_channels=global_channels,
+ dropout_rate=flow_dropout_rate,
+ use_weight_norm=use_weight_norm_in_flow,
+ use_only_mean=use_only_mean_in_flow,
+ )
+ # TODO(kan-bayashi): Add deterministic version as an option
+ self.duration_predictor = StochasticDurationPredictor(
+ channels=hidden_channels,
+ kernel_size=stochastic_duration_predictor_kernel_size,
+ dropout_rate=stochastic_duration_predictor_dropout_rate,
+ flows=stochastic_duration_predictor_flows,
+ dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
+ global_channels=global_channels,
+ )
+
+ self.upsample_factor = int(np.prod(decoder_upsample_scales))
+ self.spks = None
+ if spks is not None and spks > 1:
+ assert global_channels > 0
+ self.spks = spks
+ self.global_emb = torch.nn.Embedding(spks, global_channels)
+ self.spk_embed_dim = None
+ if spk_embed_dim is not None and spk_embed_dim > 0:
+ assert global_channels > 0
+ self.spk_embed_dim = spk_embed_dim
+ self.spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)
+ self.langs = None
+ if langs is not None and langs > 1:
+ assert global_channels > 0
+ self.langs = langs
+ self.lang_emb = torch.nn.Embedding(langs, global_channels)
+
+ # delayed import
+ from monotonic_align import maximum_path
+
+ self.maximum_path = maximum_path
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ Tuple[
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ torch.Tensor,
+ ],
+ ]:
+ """Calculate forward propagation.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, aux_channels, T_feats).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
+ Tensor: Duration negative log-likelihood (NLL) tensor (B,).
+ Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
+ Tensor: Segments start index tensor (B,).
+ Tensor: Text mask tensor (B, 1, T_text).
+ Tensor: Feature mask tensor (B, 1, T_feats).
+ tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ - Tensor: Posterior encoder hidden representation (B, H, T_feats).
+ - Tensor: Flow hidden representation (B, H, T_feats).
+ - Tensor: Expanded text encoder projected mean (B, H, T_feats).
+ - Tensor: Expanded text encoder projected scale (B, H, T_feats).
+ - Tensor: Posterior encoder projected mean (B, H, T_feats).
+ - Tensor: Posterior encoder projected scale (B, H, T_feats).
+
+ """
+ # forward text encoder
+ x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+
+ # calculate global conditioning
+ g = None
+ if self.spks is not None:
+ # speaker one-hot vector embedding: (B, global_channels, 1)
+ g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+ if self.spk_embed_dim is not None:
+ # pretrained speaker embedding, e.g., X-vector (B, global_channels, 1)
+ g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+ if self.langs is not None:
+ # language one-hot vector embedding: (B, global_channels, 1)
+ g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+
+ # forward posterior encoder
+ z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+ # forward flow
+ z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
+
+ # monotonic alignment search
+ with torch.no_grad():
+ # negative cross-entropy
+ s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
+ # (B, 1, T_text)
+ neg_x_ent_1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2),
+ s_p_sq_r,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_3 = torch.matmul(
+ z_p.transpose(1, 2),
+ (m_p * s_p_sq_r),
+ )
+ # (B, 1, T_text)
+ neg_x_ent_4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, T_text)
+ neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+ # (B, 1, T_feats, T_text)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ # monotonic attention weight: (B, 1, T_feats, T_text)
+ attn = (
+ self.maximum_path(
+ neg_x_ent,
+ attn_mask.squeeze(1),
+ )
+ .unsqueeze(1)
+ .detach()
+ )
+
+ # forward duration predictor
+ w = attn.sum(2) # (B, 1, T_text)
+ dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
+ dur_nll = dur_nll / torch.sum(x_mask)
+
+ # expand the length to match with the feature sequence
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
+
+ # get random segments
+ z_segments, z_start_idxs = get_random_segments(
+ z,
+ feats_lengths,
+ self.segment_size,
+ )
+
+ # forward decoder with random segments
+ wav = self.decoder(z_segments, g=g)
+
+ return (
+ wav,
+ dur_nll,
+ attn,
+ z_start_idxs,
+ x_mask,
+ y_mask,
+ (z, z_p, m_p, logs_p, m_q, logs_q),
+ )
+
+ def inference(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: Optional[torch.Tensor] = None,
+ feats_lengths: Optional[torch.Tensor] = None,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ dur: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference.
+
+ Args:
+ text (Tensor): Input text index tensor (B, T_text,).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+ dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
+ skip the prediction of durations (i.e., teacher forcing).
+ noise_scale (float): Noise scale parameter for flow.
+ noise_scale_dur (float): Noise scale parameter for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length of acoustic feature sequence.
+ use_teacher_forcing (bool): Whether to use teacher forcing.
+
+ Returns:
+ Tensor: Generated waveform tensor (B, T_wav).
+ Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
+ Tensor: Duration tensor (B, T_text).
+
+ """
+ # encoder
+ x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
+ x_mask = x_mask.to(x.dtype)
+ g = None
+ if self.spks is not None:
+ # (B, global_channels, 1)
+ g = self.global_emb(sids.view(-1)).unsqueeze(-1)
+ if self.spk_embed_dim is not None:
+ # (B, global_channels, 1)
+ g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+ if self.langs is not None:
+ # (B, global_channels, 1)
+ g_ = self.lang_emb(lids.view(-1)).unsqueeze(-1)
+ if g is None:
+ g = g_
+ else:
+ g = g + g_
+
+ if use_teacher_forcing:
+ # forward posterior encoder
+ z, m_q, logs_q, y_mask = self.posterior_encoder(feats, feats_lengths, g=g)
+
+ # forward flow
+ z_p = self.flow(z, y_mask, g=g) # (B, H, T_feats)
+
+ # monotonic alignment search
+ s_p_sq_r = torch.exp(-2 * logs_p) # (B, H, T_text)
+ # (B, 1, T_text)
+ neg_x_ent_1 = torch.sum(
+ -0.5 * math.log(2 * math.pi) - logs_p,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_2 = torch.matmul(
+ -0.5 * (z_p**2).transpose(1, 2),
+ s_p_sq_r,
+ )
+ # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
+ neg_x_ent_3 = torch.matmul(
+ z_p.transpose(1, 2),
+ (m_p * s_p_sq_r),
+ )
+ # (B, 1, T_text)
+ neg_x_ent_4 = torch.sum(
+ -0.5 * (m_p**2) * s_p_sq_r,
+ [1],
+ keepdim=True,
+ )
+ # (B, T_feats, T_text)
+ neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
+ # (B, 1, T_feats, T_text)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ # monotonic attention weight: (B, 1, T_feats, T_text)
+ attn = self.maximum_path(
+ neg_x_ent,
+ attn_mask.squeeze(1),
+ ).unsqueeze(1)
+ dur = attn.sum(2) # (B, 1, T_text)
+
+            # forward decoder with the full-length latent sequence
+ wav = self.decoder(z * y_mask, g=g)
+ else:
+ # duration
+ if dur is None:
+ logw = self.duration_predictor(
+ x,
+ x_mask,
+ g=g,
+ inverse=True,
+ noise_scale=noise_scale_dur,
+ )
+ w = torch.exp(logw) * x_mask * alpha
+ dur = torch.ceil(w)
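+            # total number of output frames per utterance is the sum of per-token durations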
+ y_lengths = torch.clamp_min(torch.sum(dur, [1, 2]), 1).long()
+ y_mask = (~make_pad_mask(y_lengths)).unsqueeze(1).to(text.device)
+ y_mask = y_mask.to(x.dtype)
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+ attn = self._generate_path(dur, attn_mask)
+
+ # expand the length to match with the feature sequence
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ m_p = torch.matmul(
+ attn.squeeze(1),
+ m_p.transpose(1, 2),
+ ).transpose(1, 2)
+ # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
+ logs_p = torch.matmul(
+ attn.squeeze(1),
+ logs_p.transpose(1, 2),
+ ).transpose(1, 2)
+
+ # decoder
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+ z = self.flow(z_p, y_mask, g=g, inverse=True)
+ wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
+
+ return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
+
+ def _generate_path(self, dur: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Generate path a.k.a. monotonic attention.
+
+ Args:
+ dur (Tensor): Duration tensor (B, 1, T_text).
+ mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
+
+ Returns:
+ Tensor: Path tensor (B, 1, T_feats, T_text).
+
+ """
+ b, _, t_y, t_x = mask.shape
+ cum_dur = torch.cumsum(dur, -1)
+ cum_dur_flat = cum_dur.view(b * t_x)
+ path = torch.arange(t_y, dtype=dur.dtype, device=dur.device)
+ path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
+ # path = path.view(b, t_x, t_y).to(dtype=mask.dtype)
+ path = path.view(b, t_x, t_y).to(dtype=torch.float)
+ # path will be like (t_x = 3, t_y = 5):
+ # [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
+ # [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
+ # [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
+ path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
+ # path = path.to(dtype=mask.dtype)
+ return path.unsqueeze(1).transpose(2, 3) * mask
diff --git a/egs/ljspeech/TTS/vits/hifigan.py b/egs/ljspeech/TTS/vits/hifigan.py
new file mode 100644
index 0000000000..589ac30f60
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/hifigan.py
@@ -0,0 +1,933 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/hifigan.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFi-GAN Modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import copy
+import logging
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class HiFiGANGenerator(torch.nn.Module):
+ """HiFiGAN generator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 80,
+ out_channels: int = 1,
+ channels: int = 512,
+ global_channels: int = -1,
+ kernel_size: int = 7,
+ upsample_scales: List[int] = [8, 8, 2, 2],
+ upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
+ resblock_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ use_additional_convs: bool = True,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ ):
+ """Initialize HiFiGANGenerator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ channels (int): Number of hidden representation channels.
+ global_channels (int): Number of global conditioning channels.
+ kernel_size (int): Kernel size of initial and final conv layer.
+ upsample_scales (List[int]): List of upsampling scales.
+ upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers.
+ resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks.
+ resblock_dilations (List[List[int]]): List of list of dilations for residual
+ blocks.
+ use_additional_convs (bool): Whether to use additional conv layers in
+ residual blocks.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+
+ """
+ super().__init__()
+
+ # check hyperparameters are valid
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ assert len(upsample_scales) == len(upsample_kernel_sizes)
+ assert len(resblock_dilations) == len(resblock_kernel_sizes)
+
+ # define modules
+ self.upsample_factor = int(np.prod(upsample_scales) * out_channels)
+ self.num_upsamples = len(upsample_kernel_sizes)
+ self.num_blocks = len(resblock_kernel_sizes)
+ self.input_conv = torch.nn.Conv1d(
+ in_channels,
+ channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.upsamples = torch.nn.ModuleList()
+ self.blocks = torch.nn.ModuleList()
+ for i in range(len(upsample_kernel_sizes)):
+ assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
+ self.upsamples += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.ConvTranspose1d(
+ channels // (2**i),
+ channels // (2 ** (i + 1)),
+ upsample_kernel_sizes[i],
+ upsample_scales[i],
+ padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
+ output_padding=upsample_scales[i] % 2,
+ ),
+ )
+ ]
+ for j in range(len(resblock_kernel_sizes)):
+ self.blocks += [
+ ResidualBlock(
+ kernel_size=resblock_kernel_sizes[j],
+ channels=channels // (2 ** (i + 1)),
+ dilations=resblock_dilations[j],
+ bias=bias,
+ use_additional_convs=use_additional_convs,
+ nonlinear_activation=nonlinear_activation,
+ nonlinear_activation_params=nonlinear_activation_params,
+ )
+ ]
+ self.output_conv = torch.nn.Sequential(
+ # NOTE(kan-bayashi): follow official implementation but why
+ # using different slope parameter here? (0.1 vs. 0.01)
+ torch.nn.LeakyReLU(),
+ torch.nn.Conv1d(
+ channels // (2 ** (i + 1)),
+ out_channels,
+ kernel_size,
+ 1,
+ padding=(kernel_size - 1) // 2,
+ ),
+ torch.nn.Tanh(),
+ )
+ if global_channels > 0:
+ self.global_conv = torch.nn.Conv1d(global_channels, channels, 1)
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # reset parameters
+ self.reset_parameters()
+
+ def forward(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ c (Tensor): Input tensor (B, in_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T).
+
+ """
+ c = self.input_conv(c)
+ if g is not None:
+ c = c + self.global_conv(g)
+ for i in range(self.num_upsamples):
+ c = self.upsamples[i](c)
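+            # multi-receptive-field fusion: run the parallel residual blocks for this
+            # resolution and average their outputs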
+ cs = 0.0 # initialize
+ for j in range(self.num_blocks):
+ cs += self.blocks[i * self.num_blocks + j](c)
+ c = cs / self.num_blocks
+ c = self.output_conv(c)
+
+ return c
+
+ def reset_parameters(self):
+ """Reset parameters.
+
+ This initialization follows the official implementation manner.
+ https://github.com/jik876/hifi-gan/blob/master/models.py
+
+ """
+
+ def _reset_parameters(m: torch.nn.Module):
+ if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
+ m.weight.data.normal_(0.0, 0.01)
+ logging.debug(f"Reset parameters in {m}.")
+
+ self.apply(_reset_parameters)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(
+ m, torch.nn.ConvTranspose1d
+ ):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def inference(
+ self, c: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Perform inference.
+
+ Args:
+ c (torch.Tensor): Input tensor (T, in_channels).
+ g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
+
+ Returns:
+            Tensor: Output tensor (T * upsample_factor, out_channels).
+
+ """
+ if g is not None:
+ g = g.unsqueeze(0)
+ c = self.forward(c.transpose(1, 0).unsqueeze(0), g=g)
+ return c.squeeze(0).transpose(1, 0)
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in HiFiGAN."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ channels: int = 512,
+ dilations: List[int] = [1, 3, 5],
+ bias: bool = True,
+ use_additional_convs: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer.
+ channels (int): Number of channels for convolution layer.
+ dilations (List[int]): List of dilation factors.
+ use_additional_convs (bool): Whether to use additional convolution layers.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+
+ """
+ super().__init__()
+ self.use_additional_convs = use_additional_convs
+ self.convs1 = torch.nn.ModuleList()
+ if use_additional_convs:
+ self.convs2 = torch.nn.ModuleList()
+ assert kernel_size % 2 == 1, "Kernel size must be odd number."
+ for dilation in dilations:
+ self.convs1 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation,
+ bias=bias,
+ padding=(kernel_size - 1) // 2 * dilation,
+ ),
+ )
+ ]
+ if use_additional_convs:
+ self.convs2 += [
+ torch.nn.Sequential(
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ bias=bias,
+ padding=(kernel_size - 1) // 2,
+ ),
+ )
+ ]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, channels, T).
+
+ Returns:
+ Tensor: Output tensor (B, channels, T).
+
+ """
+ for idx in range(len(self.convs1)):
+ xt = self.convs1[idx](x)
+ if self.use_additional_convs:
+ xt = self.convs2[idx](xt)
+ x = xt + x
+ return x
+
+
+class HiFiGANPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN period discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ period: int = 3,
+ kernel_sizes: List[int] = [5, 3],
+ channels: int = 32,
+ downsample_scales: List[int] = [3, 3, 3, 3, 1],
+ max_downsample_channels: int = 1024,
+ bias: bool = True,
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initialize HiFiGANPeriodDiscriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ period (int): Period.
+ kernel_sizes (list): Kernel sizes of initial conv layers and the final conv
+ layer.
+ channels (int): Number of initial channels.
+ downsample_scales (List[int]): List of downsampling scales.
+ max_downsample_channels (int): Number of maximum downsampling channels.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm.
+ If set to true, it will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ assert len(kernel_sizes) == 2
+ assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
+ assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."
+
+ self.period = period
+ self.convs = torch.nn.ModuleList()
+ in_chs = in_channels
+ out_chs = channels
+ for downsample_scale in downsample_scales:
+ self.convs += [
+ torch.nn.Sequential(
+ torch.nn.Conv2d(
+ in_chs,
+ out_chs,
+ (kernel_sizes[0], 1),
+ (downsample_scale, 1),
+ padding=((kernel_sizes[0] - 1) // 2, 0),
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Use downsample_scale + 1?
+ out_chs = min(out_chs * 4, max_downsample_channels)
+ self.output_conv = torch.nn.Conv2d(
+ out_chs,
+ out_channels,
+ (kernel_sizes[1] - 1, 1),
+ 1,
+ padding=((kernel_sizes[1] - 1) // 2, 0),
+ )
+
+ if use_weight_norm and use_spectral_norm:
+ raise ValueError("Either use use_weight_norm or use_spectral_norm.")
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+
+        Returns:
+            List: List of each layer's output tensors.
+
+ """
+ # transform 1d to 2d -> (B, C, T/P, P)
+ b, c, t = x.shape
+ if t % self.period != 0:
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t += n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ # forward conv
+ outs = []
+ for layer in self.convs:
+ x = layer(x)
+ outs += [x]
+ x = self.output_conv(x)
+ x = torch.flatten(x, 1, -1)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization module from all of the layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+
+class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFiGAN multi-period discriminator module."""
+
+ def __init__(
+ self,
+ periods: List[int] = [2, 3, 5, 7, 11],
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initialize HiFiGANMultiPeriodDiscriminator module.
+
+ Args:
+ periods (List[int]): List of periods.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+ for period in periods:
+ params = copy.deepcopy(discriminator_params)
+ params["period"] = period
+ self.discriminators += [HiFiGANPeriodDiscriminator(**params)]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+            List: List of each discriminator's outputs, each of which is a list of
+                that discriminator's layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
+
+ return outs
+
+
+class HiFiGANScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN scale discriminator module."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_sizes: List[int] = [15, 41, 5, 3],
+ channels: int = 128,
+ max_downsample_channels: int = 1024,
+ max_groups: int = 16,
+ bias: int = True,
+ downsample_scales: List[int] = [2, 2, 4, 4, 1],
+ nonlinear_activation: str = "LeakyReLU",
+ nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
+ use_weight_norm: bool = True,
+ use_spectral_norm: bool = False,
+ ):
+ """Initilize HiFiGAN scale discriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_sizes (List[int]): List of four kernel sizes. The first will be used
+ for the first conv layer, and the second is for downsampling part, and
+ the remaining two are for the last two output layers.
+ channels (int): Initial number of channels for conv layer.
+ max_downsample_channels (int): Maximum number of channels for downsampling
+ layers.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ downsample_scales (List[int]): List of downsampling scales.
+ nonlinear_activation (str): Activation function module name.
+ nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
+ function.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
+ will be applied to all of the conv layers.
+
+ """
+ super().__init__()
+ self.layers = torch.nn.ModuleList()
+
+ # check kernel size is valid
+ assert len(kernel_sizes) == 4
+ for ks in kernel_sizes:
+ assert ks % 2 == 1
+
+ # add first layer
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels,
+ channels,
+ # NOTE(kan-bayashi): Use always the same kernel size
+ kernel_sizes[0],
+ bias=bias,
+ padding=(kernel_sizes[0] - 1) // 2,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+
+ # add downsample layers
+ in_chs = channels
+ out_chs = channels
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = 4
+ for downsample_scale in downsample_scales:
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[1],
+ stride=downsample_scale,
+ padding=(kernel_sizes[1] - 1) // 2,
+ groups=groups,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(
+ **nonlinear_activation_params
+ ),
+ )
+ ]
+ in_chs = out_chs
+ # NOTE(kan-bayashi): Remove hard coding?
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ # NOTE(kan-bayashi): Remove hard coding?
+ groups = min(groups * 4, max_groups)
+
+ # add final layers
+ out_chs = min(in_chs * 2, max_downsample_channels)
+ self.layers += [
+ torch.nn.Sequential(
+ torch.nn.Conv1d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_sizes[2],
+ stride=1,
+ padding=(kernel_sizes[2] - 1) // 2,
+ bias=bias,
+ ),
+ getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+ )
+ ]
+ self.layers += [
+ torch.nn.Conv1d(
+ out_chs,
+ out_channels,
+ kernel_size=kernel_sizes[3],
+ stride=1,
+ padding=(kernel_sizes[3] - 1) // 2,
+ bias=bias,
+ ),
+ ]
+
+ if use_weight_norm and use_spectral_norm:
+ raise ValueError("Either use use_weight_norm or use_spectral_norm.")
+
+ # apply weight norm
+ self.use_weight_norm = use_weight_norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ # apply spectral norm
+ self.use_spectral_norm = use_spectral_norm
+ if use_spectral_norm:
+ self.apply_spectral_norm()
+
+ # backward compatibility
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ List[Tensor]: List of output tensors of each layer.
+
+ """
+ outs = []
+ for f in self.layers:
+ x = f(x)
+ outs += [x]
+
+ return outs
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def apply_spectral_norm(self):
+ """Apply spectral normalization module from all of the layers."""
+
+ def _apply_spectral_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d):
+ torch.nn.utils.spectral_norm(m)
+ logging.debug(f"Spectral norm is applied to {m}.")
+
+ self.apply(_apply_spectral_norm)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def remove_spectral_norm(self):
+ """Remove spectral normalization module from all of the layers."""
+
+ def _remove_spectral_norm(m):
+ try:
+ logging.debug(f"Spectral norm is removed from {m}.")
+ torch.nn.utils.remove_spectral_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_spectral_norm)
+
+ def _load_state_dict_pre_hook(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ """Fix the compatibility of weight / spectral normalization issue.
+
+ Some pretrained models are trained with configs that use weight / spectral
+ normalization, but actually, the norm is not applied. This causes the mismatch
+ of the parameters with configs. To solve this issue, when parameter mismatch
+ happens in loading pretrained model, we remove the norm from the current model.
+
+ See also:
+ - https://github.com/espnet/espnet/pull/5240
+ - https://github.com/espnet/espnet/pull/5249
+ - https://github.com/kan-bayashi/ParallelWaveGAN/pull/409
+
+ """
+ current_module_keys = [x for x in state_dict.keys() if x.startswith(prefix)]
+ if self.use_weight_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems weight norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_weight_norm()
+ self.use_weight_norm = False
+ for k in current_module_keys:
+ if k.endswith("weight_g") or k.endswith("weight_v"):
+ del state_dict[k]
+
+ if self.use_spectral_norm and any(
+ [k.endswith("weight") for k in current_module_keys]
+ ):
+ logging.warning(
+ "It seems spectral norm is not applied in the pretrained model but the"
+ " current model uses it. To keep the compatibility, we remove the norm"
+ " from the current model. This may cause unexpected behavior due to the"
+ " parameter mismatch in finetuning. To avoid this issue, please change"
+ " the following parameters in config to false:\n"
+ " - discriminator_params.follow_official_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_weight_norm\n"
+ " - discriminator_params.scale_discriminator_params.use_spectral_norm\n"
+ "\n"
+ "See also:\n"
+ " - https://github.com/espnet/espnet/pull/5240\n"
+ " - https://github.com/espnet/espnet/pull/5249"
+ )
+ self.remove_spectral_norm()
+ self.use_spectral_norm = False
+ for k in current_module_keys:
+ if (
+ k.endswith("weight_u")
+ or k.endswith("weight_v")
+ or k.endswith("weight_orig")
+ ):
+ del state_dict[k]
+
+
+class HiFiGANMultiScaleDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale discriminator module."""
+
+ def __init__(
+ self,
+ scales: int = 3,
+ downsample_pooling: str = "AvgPool1d",
+ # follow the official implementation setting
+ downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = False,
+ ):
+ """Initilize HiFiGAN multi-scale discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ downsample_pooling_params (Dict[str, Any]): Parameters for the above pooling
+ module.
+ discriminator_params (Dict[str, Any]): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+                official implementation. The first discriminator uses spectral norm
+ and the other discriminators use weight norm.
+
+ """
+ super().__init__()
+ self.discriminators = torch.nn.ModuleList()
+
+ # add discriminators
+ for i in range(scales):
+ params = copy.deepcopy(discriminator_params)
+ if follow_official_norm:
+ if i == 0:
+ params["use_weight_norm"] = False
+ params["use_spectral_norm"] = True
+ else:
+ params["use_weight_norm"] = True
+ params["use_spectral_norm"] = False
+ self.discriminators += [HiFiGANScaleDiscriminator(**params)]
+ self.pooling = None
+ if scales > 1:
+ self.pooling = getattr(torch.nn, downsample_pooling)(
+ **downsample_pooling_params
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+            List[List[torch.Tensor]]: List of each discriminator's outputs, each of
+                which is a list of that discriminator's layer output tensors.
+
+ """
+ outs = []
+ for f in self.discriminators:
+ outs += [f(x)]
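+            # downsample the input before passing it to the next scale discriminator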
+ if self.pooling is not None:
+ x = self.pooling(x)
+
+ return outs
+
+
+class HiFiGANMultiScaleMultiPeriodDiscriminator(torch.nn.Module):
+ """HiFi-GAN multi-scale + multi-period discriminator module."""
+
+ def __init__(
+ self,
+ # Multi-scale discriminator related
+ scales: int = 3,
+ scale_downsample_pooling: str = "AvgPool1d",
+ scale_downsample_pooling_params: Dict[str, Any] = {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ scale_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ },
+ follow_official_norm: bool = True,
+ # Multi-period discriminator related
+ periods: List[int] = [2, 3, 5, 7, 11],
+ period_discriminator_params: Dict[str, Any] = {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ ):
+ """Initilize HiFiGAN multi-scale + multi-period discriminator module.
+
+ Args:
+ scales (int): Number of multi-scales.
+ scale_downsample_pooling (str): Pooling module name for downsampling of the
+ inputs.
+ scale_downsample_pooling_params (dict): Parameters for the above pooling
+ module.
+ scale_discriminator_params (dict): Parameters for hifi-gan scale
+ discriminator module.
+ follow_official_norm (bool): Whether to follow the norm setting of the
+                official implementation. The first discriminator uses spectral norm and
+ the other discriminators use weight norm.
+ periods (list): List of periods.
+ period_discriminator_params (dict): Parameters for hifi-gan period
+ discriminator module. The period parameter will be overwritten.
+
+ """
+ super().__init__()
+ self.msd = HiFiGANMultiScaleDiscriminator(
+ scales=scales,
+ downsample_pooling=scale_downsample_pooling,
+ downsample_pooling_params=scale_downsample_pooling_params,
+ discriminator_params=scale_discriminator_params,
+ follow_official_norm=follow_official_norm,
+ )
+ self.mpd = HiFiGANMultiPeriodDiscriminator(
+ periods=periods,
+ discriminator_params=period_discriminator_params,
+ )
+
+ def forward(self, x: torch.Tensor) -> List[List[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+            List[List[Tensor]]: List of each discriminator's outputs, each of which
+                is a list of that discriminator's layer output tensors. The
+                multi-scale and multi-period outputs are concatenated.
+
+ """
+ msd_outs = self.msd(x)
+ mpd_outs = self.mpd(x)
+ return msd_outs + mpd_outs
diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py
new file mode 100755
index 0000000000..91a35e3602
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/infer.py
@@ -0,0 +1,233 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script performs model inference on the test set.
+
+Usage:
+./vits/infer.py \
+ --epoch 1000 \
+ --exp-dir ./vits/exp \
+ --max-duration 500
+"""
+
+
+import argparse
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import List
+
+import k2
+import torch
+import torch.nn as nn
+import torchaudio
+
+from train import get_model, get_params
+from tokenizer import Tokenizer
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+from tts_datamodule import LJSpeechTtsDataModule
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=1000,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+def infer_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ tokenizer: Tokenizer,
+) -> None:
+ """Decode dataset.
+ The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ tokenizer:
+ Used to convert text to phonemes.
+ """
+    # Background worker that saves audio files to disk.
+ def _save_worker(
+ batch_size: int,
+ cut_ids: List[str],
+ audio: torch.Tensor,
+ audio_pred: torch.Tensor,
+ audio_lens: List[int],
+ audio_lens_pred: List[int],
+ ):
+ for i in range(batch_size):
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_gt.wav"),
+ audio[i:i + 1, :audio_lens[i]],
+ sample_rate=params.sampling_rate,
+ )
+ torchaudio.save(
+ str(params.save_wav_dir / f"{cut_ids[i]}_pred.wav"),
+ audio_pred[i:i + 1, :audio_lens_pred[i]],
+ sample_rate=params.sampling_rate,
+ )
+
+ device = next(model.parameters()).device
+ num_cuts = 0
+ log_interval = 5
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ futures = []
+ with ThreadPoolExecutor(max_workers=1) as executor:
+ for batch_idx, batch in enumerate(dl):
+ batch_size = len(batch["tokens"])
+
+ tokens = batch["tokens"]
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ audio = batch["audio"]
+ audio_lens = batch["audio_lens"].tolist()
+ cut_ids = [cut.id for cut in batch["cut"]]
+
+ audio_pred, _, durations = model.inference_batch(text=tokens, text_lengths=tokens_lens)
+ audio_pred = audio_pred.detach().cpu()
+ # convert to samples
+ audio_lens_pred = (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist()
+
+ futures.append(
+ executor.submit(
+ _save_worker, batch_size, cut_ids, audio, audio_pred, audio_lens, audio_lens_pred
+ )
+ )
+
+ num_cuts += batch_size
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    # wait for all pending save jobs to finish
+ for f in futures:
+ f.result()
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.suffix = f"epoch-{params.epoch}"
+
+ params.res_dir = params.exp_dir / "infer" / params.suffix
+ params.save_wav_dir = params.res_dir / "wav"
+ params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+ setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+ logging.info("Infer started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(f"Device: {device}")
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+
+ model.to(device)
+ model.eval()
+
+ num_param_g = sum([p.numel() for p in model.generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in model.discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+    # we need cut ids to name the saved wav files.
+ args.return_cuts = True
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ test_cuts = ljspeech.test_cuts()
+ test_dl = ljspeech.test_dataloaders(test_cuts)
+
+ infer_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ )
+
+ logging.info(f"Wav files are saved to {params.save_wav_dir}")
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/loss.py b/egs/ljspeech/TTS/vits/loss.py
new file mode 100644
index 0000000000..21aaad6e75
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/loss.py
@@ -0,0 +1,336 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""HiFiGAN-related loss modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+from typing import List, Tuple, Union
+
+import torch
+import torch.distributions as D
+import torch.nn.functional as F
+
+from lhotse.features.kaldi import Wav2LogFilterBank
+
+
+class GeneratorAdversarialLoss(torch.nn.Module):
+ """Generator adversarial loss module."""
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize GeneratorAversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.criterion = self._mse_loss
+ else:
+ self.criterion = self._hinge_loss
+
+ def forward(
+ self,
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> torch.Tensor:
+ """Calcualate generator adversarial loss.
+
+ Args:
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+                outputs.
+
+ Returns:
+ Tensor: Generator adversarial loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ adv_loss = 0.0
+ for i, outputs_ in enumerate(outputs):
+ if isinstance(outputs_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_ = outputs_[-1]
+ adv_loss += self.criterion(outputs_)
+ if self.average_by_discriminators:
+ adv_loss /= i + 1
+ else:
+ adv_loss = self.criterion(outputs)
+
+ return adv_loss
+
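+    # LSGAN-style objective: the generator is trained to push the discriminator
+    # outputs toward 1; the hinge variant simply maximizes the raw discriminator score.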
+ def _mse_loss(self, x):
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _hinge_loss(self, x):
+ return -x.mean()
+
+
+class DiscriminatorAdversarialLoss(torch.nn.Module):
+ """Discriminator adversarial loss module."""
+
+ def __init__(
+ self,
+ average_by_discriminators: bool = True,
+ loss_type: str = "mse",
+ ):
+ """Initialize DiscriminatorAversarialLoss module.
+
+ Args:
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ loss_type (str): Loss type, "mse" or "hinge".
+
+ """
+ super().__init__()
+ self.average_by_discriminators = average_by_discriminators
+ assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
+ if loss_type == "mse":
+ self.fake_criterion = self._mse_fake_loss
+ self.real_criterion = self._mse_real_loss
+ else:
+ self.fake_criterion = self._hinge_fake_loss
+ self.real_criterion = self._hinge_real_loss
+
+ def forward(
+ self,
+ outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calcualate discriminator adversarial loss.
+
+ Args:
+ outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from generator.
+ outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator
+ outputs, list of discriminator outputs, or list of list of discriminator
+ outputs calculated from groundtruth.
+
+ Returns:
+ Tensor: Discriminator real loss value.
+ Tensor: Discriminator fake loss value.
+
+ """
+ if isinstance(outputs, (tuple, list)):
+ real_loss = 0.0
+ fake_loss = 0.0
+ for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)):
+ if isinstance(outputs_hat_, (tuple, list)):
+ # NOTE(kan-bayashi): case including feature maps
+ outputs_hat_ = outputs_hat_[-1]
+ outputs_ = outputs_[-1]
+ real_loss += self.real_criterion(outputs_)
+ fake_loss += self.fake_criterion(outputs_hat_)
+ if self.average_by_discriminators:
+ fake_loss /= i + 1
+ real_loss /= i + 1
+ else:
+ real_loss = self.real_criterion(outputs)
+ fake_loss = self.fake_criterion(outputs_hat)
+
+ return real_loss, fake_loss
+
+ def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_ones(x.size()))
+
+ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return F.mse_loss(x, x.new_zeros(x.size()))
+
+ def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
+
+ def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor:
+ return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
+
+
+class FeatureMatchLoss(torch.nn.Module):
+ """Feature matching loss module."""
+
+ def __init__(
+ self,
+ average_by_layers: bool = True,
+ average_by_discriminators: bool = True,
+ include_final_outputs: bool = False,
+ ):
+ """Initialize FeatureMatchLoss module.
+
+ Args:
+ average_by_layers (bool): Whether to average the loss by the number
+ of layers.
+ average_by_discriminators (bool): Whether to average the loss by
+ the number of discriminators.
+ include_final_outputs (bool): Whether to include the final output of
+ each discriminator for loss calculation.
+
+ """
+ super().__init__()
+ self.average_by_layers = average_by_layers
+ self.average_by_discriminators = average_by_discriminators
+ self.include_final_outputs = include_final_outputs
+
+ def forward(
+ self,
+ feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ feats: Union[List[List[torch.Tensor]], List[torch.Tensor]],
+ ) -> torch.Tensor:
+ """Calculate feature matching loss.
+
+ Args:
+ feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+ from generator's outputs.
+ feats (Union[List[List[Tensor]], List[Tensor]]): List of list of
+                discriminator outputs or list of discriminator outputs calculated
+                from groundtruth.
+
+ Returns:
+ Tensor: Feature matching loss value.
+
+ """
+ feat_match_loss = 0.0
+ for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
+ feat_match_loss_ = 0.0
+ if not self.include_final_outputs:
+ feats_hat_ = feats_hat_[:-1]
+ feats_ = feats_[:-1]
+ for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
+ feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
+ if self.average_by_layers:
+ feat_match_loss_ /= j + 1
+ feat_match_loss += feat_match_loss_
+ if self.average_by_discriminators:
+ feat_match_loss /= i + 1
+
+ return feat_match_loss
+
+
+class MelSpectrogramLoss(torch.nn.Module):
+ """Mel-spectrogram loss."""
+
+ def __init__(
+ self,
+ sampling_rate: int = 22050,
+ frame_length: int = 1024, # in samples
+ frame_shift: int = 256, # in samples
+ n_mels: int = 80,
+ use_fft_mag: bool = True,
+ ):
+ super().__init__()
+ self.wav_to_mel = Wav2LogFilterBank(
+ sampling_rate=sampling_rate,
+ frame_length=frame_length / sampling_rate, # in second
+ frame_shift=frame_shift / sampling_rate, # in second
+ use_fft_mag=use_fft_mag,
+ num_filters=n_mels,
+ )
+
+ def forward(
+ self,
+ y_hat: torch.Tensor,
+ y: torch.Tensor,
+ return_mel: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
+ """Calculate Mel-spectrogram loss.
+
+ Args:
+ y_hat (Tensor): Generated waveform tensor (B, 1, T).
+ y (Tensor): Groundtruth waveform tensor (B, 1, T).
+            return_mel (bool): Whether to also return the computed mel-spectrograms.
+
+        Returns:
+            Tensor: Mel-spectrogram loss value. If ``return_mel`` is True, a tuple
+                of the loss and the pair of generated and groundtruth
+                mel-spectrograms is returned.
+
+ """
+ mel_hat = self.wav_to_mel(y_hat.squeeze(1))
+ mel = self.wav_to_mel(y.squeeze(1))
+ mel_loss = F.l1_loss(mel_hat, mel)
+
+ if return_mel:
+ return mel_loss, (mel_hat, mel)
+
+ return mel_loss
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/loss.py
+
+"""VITS-related loss modules.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+
+class KLDivergenceLoss(torch.nn.Module):
+ """KL divergence loss."""
+
+ def forward(
+ self,
+ z_p: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ z_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss.
+
+ Args:
+ z_p (Tensor): Flow hidden representation (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+ z_mask (Tensor): Mask tensor (B, 1, T_feats).
+
+ Returns:
+ Tensor: KL divergence loss.
+
+ """
+ z_p = z_p.float()
+ logs_q = logs_q.float()
+ m_p = m_p.float()
+ logs_p = logs_p.float()
+ z_mask = z_mask.float()
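+        # Per-element KL(q || p) estimated with the posterior sample z_p: the entropy
+        # of q is taken analytically (giving the -0.5 constant once the shared
+        # log(2*pi) terms cancel), while log p is evaluated at the sampled z_p.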
+ kl = logs_p - logs_q - 0.5
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+ kl = torch.sum(kl * z_mask)
+ loss = kl / torch.sum(z_mask)
+
+ return loss
+
+
+class KLDivergenceLossWithoutFlow(torch.nn.Module):
+ """KL divergence loss without flow."""
+
+ def forward(
+ self,
+ m_q: torch.Tensor,
+ logs_q: torch.Tensor,
+ m_p: torch.Tensor,
+ logs_p: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculate KL divergence loss without flow.
+
+ Args:
+ m_q (Tensor): Posterior encoder projected mean (B, H, T_feats).
+ logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
+ m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
+ logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
+ """
+ posterior_norm = D.Normal(m_q, torch.exp(logs_q))
+ prior_norm = D.Normal(m_p, torch.exp(logs_p))
+ loss = D.kl_divergence(posterior_norm, prior_norm).mean()
+ return loss
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/__init__.py b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
new file mode 100644
index 0000000000..2b35654f51
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/__init__.py
@@ -0,0 +1,81 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/__init__.py
+
+"""Maximum path calculation module.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+import warnings
+
+import numpy as np
+import torch
+from numba import njit, prange
+
+try:
+ from .core import maximum_path_c
+
+ is_cython_avalable = True
+except ImportError:
+ is_cython_avalable = False
+ warnings.warn(
+ "Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
+ "If you want to use the cython version, please build it as follows: "
+ "`cd espnet2/gan_tts/vits/monotonic_align; python setup.py build_ext --inplace`"
+ )
+
+
+def maximum_path(neg_x_ent: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
+ """Calculate maximum path.
+
+ Args:
+ neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
+ attn_mask (Tensor): Attention mask (B, T_feats, T_text).
+
+ Returns:
+ Tensor: Maximum path tensor (B, T_feats, T_text).
+
+ """
+ device, dtype = neg_x_ent.device, neg_x_ent.dtype
+ neg_x_ent = neg_x_ent.cpu().numpy().astype(np.float32)
+ path = np.zeros(neg_x_ent.shape, dtype=np.int32)
+ t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
+ t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
+ if is_cython_avalable:
+ maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
+ else:
+ maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
+
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+
+@njit
+def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
+ """Calculate a single maximum path with numba."""
+ index = t_x - 1
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+@njit(parallel=True)
+def maximum_path_numba(paths, values, t_ys, t_xs):
+ """Calculate batch maximum path with numba."""
+ for i in prange(paths.shape[0]):
+ maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/core.pyx b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
new file mode 100644
index 0000000000..c02c2d02e7
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/core.pyx
@@ -0,0 +1,51 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/core.pyx
+
+"""Maximum path calculation module with cython optimization.
+
+This code is copied from https://github.com/jaywalnut310/vits, with the code format modified.
+
+"""
+
+cimport cython
+
+from cython.parallel import prange
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
+ cdef int x
+ cdef int y
+ cdef float v_prev
+ cdef float v_cur
+ cdef float tmp
+ cdef int index = t_x - 1
+
+ for y in range(t_y):
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+ if x == y:
+ v_cur = max_neg_val
+ else:
+ v_cur = value[y - 1, x]
+ if x == 0:
+ if y == 0:
+ v_prev = 0.0
+ else:
+ v_prev = max_neg_val
+ else:
+ v_prev = value[y - 1, x - 1]
+ value[y, x] += max(v_prev, v_cur)
+
+ for y in range(t_y - 1, -1, -1):
+ path[y, index] = 1
+ if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
+ index = index - 1
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
+ cdef int b = paths.shape[0]
+ cdef int i
+ for i in prange(b, nogil=True):
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
diff --git a/egs/ljspeech/TTS/vits/monotonic_align/setup.py b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
new file mode 100644
index 0000000000..33d75e1765
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/monotonic_align/setup.py
@@ -0,0 +1,31 @@
+# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/monotonic_align/setup.py
+"""Setup cython code."""
+
+from Cython.Build import cythonize
+from setuptools import Extension, setup
+from setuptools.command.build_ext import build_ext as _build_ext
+
+
+class build_ext(_build_ext):
+ """Overwrite build_ext."""
+
+ def finalize_options(self):
+ """Prevent numpy from thinking it is still in its setup process."""
+ _build_ext.finalize_options(self)
+ __builtins__.__NUMPY_SETUP__ = False
+ import numpy
+
+ self.include_dirs.append(numpy.get_include())
+
+
+exts = [
+ Extension(
+ name="core",
+ sources=["core.pyx"],
+ )
+]
+setup(
+ name="monotonic_align",
+ ext_modules=cythonize(exts, language_level=3),
+ cmdclass={"build_ext": build_ext},
+)
diff --git a/egs/ljspeech/TTS/vits/posterior_encoder.py b/egs/ljspeech/TTS/vits/posterior_encoder.py
new file mode 100644
index 0000000000..6b8a5be52f
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/posterior_encoder.py
@@ -0,0 +1,117 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/posterior_encoder.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Posterior encoder module in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple
+
+import torch
+
+from icefall.utils import make_pad_mask
+from wavenet import WaveNet, Conv1d
+
+
+class PosteriorEncoder(torch.nn.Module):
+ """Posterior encoder module in VITS.
+
+    This is a module of the posterior encoder described in `Conditional Variational
+ Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2006.04558
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 513,
+ out_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ layers: int = 16,
+ stacks: int = 1,
+ base_dilation: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ ):
+ """Initilialize PosteriorEncoder module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size in WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of repeat stacking of WaveNet.
+ base_dilation (int): Base dilation factor.
+ global_channels (int): Number of global conditioning channels.
+ dropout_rate (float): Dropout rate.
+ bias (bool): Whether to use bias parameters in conv.
+ use_weight_norm (bool): Whether to apply weight norm.
+
+ """
+ super().__init__()
+
+ # define modules
+ self.input_conv = Conv1d(in_channels, hidden_channels, 1)
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
+
+ def forward(
+ self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T_feats).
+ x_lengths (Tensor): Length tensor (B,).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
+ Tensor: Projected mean tensor (B, out_channels, T_feats).
+ Tensor: Projected scale tensor (B, out_channels, T_feats).
+ Tensor: Mask tensor for input tensor (B, 1, T_feats).
+
+ """
+ x_mask = (
+ (~make_pad_mask(x_lengths))
+ .unsqueeze(1)
+ .to(
+ dtype=x.dtype,
+ device=x.device,
+ )
+ )
+ x = self.input_conv(x) * x_mask
+ x = self.encoder(x, x_mask, g=g)
+ stats = self.proj(x) * x_mask
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
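+        # reparameterization trick: z ~ N(m, exp(logs)^2), sampled in a differentiable way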
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+
+ return z, m, logs, x_mask
diff --git a/egs/ljspeech/TTS/vits/residual_coupling.py b/egs/ljspeech/TTS/vits/residual_coupling.py
new file mode 100644
index 0000000000..2d6807cb7c
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/residual_coupling.py
@@ -0,0 +1,229 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/residual_coupling.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Residual affine coupling modules in VITS.
+
+This code is based on https://github.com/jaywalnut310/vits.
+
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from flow import FlipFlow
+from wavenet import WaveNet
+
+
+class ResidualAffineCouplingBlock(torch.nn.Module):
+ """Residual affine coupling block module.
+
+    This is a module of the residual affine coupling block, which is used as the "Flow" in
+ `Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`_.
+
+ .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
+ Text-to-Speech`: https://arxiv.org/abs/2006.04558
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ flows: int = 4,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 4,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initilize ResidualAffineCouplingBlock module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ flows (int): Number of flows.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+            bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ super().__init__()
+
+ self.flows = torch.nn.ModuleList()
+ for i in range(flows):
+ self.flows += [
+ ResidualAffineCouplingLayer(
+ in_channels=in_channels,
+ hidden_channels=hidden_channels,
+ kernel_size=kernel_size,
+ base_dilation=base_dilation,
+ layers=layers,
+ stacks=1,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ use_weight_norm=use_weight_norm,
+ bias=bias,
+ use_only_mean=use_only_mean,
+ )
+ ]
+ self.flows += [FlipFlow()]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+
+ """
+ if not inverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, inverse=inverse)
+ else:
+ for flow in reversed(self.flows):
+ x = flow(x, x_mask, g=g, inverse=inverse)
+ return x
+
+
+class ResidualAffineCouplingLayer(torch.nn.Module):
+ """Residual affine coupling layer."""
+
+ def __init__(
+ self,
+ in_channels: int = 192,
+ hidden_channels: int = 192,
+ kernel_size: int = 5,
+ base_dilation: int = 1,
+ layers: int = 5,
+ stacks: int = 1,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ use_weight_norm: bool = True,
+ bias: bool = True,
+ use_only_mean: bool = True,
+ ):
+ """Initialzie ResidualAffineCouplingLayer module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ hidden_channels (int): Number of hidden channels.
+ kernel_size (int): Kernel size for WaveNet.
+ base_dilation (int): Base dilation factor for WaveNet.
+ layers (int): Number of layers of WaveNet.
+ stacks (int): Number of stacks of WaveNet.
+ global_channels (int): Number of global channels.
+ dropout_rate (float): Dropout rate.
+ use_weight_norm (bool): Whether to use weight normalization in WaveNet.
+ bias (bool): Whether to use bias parameters in WaveNet.
+ use_only_mean (bool): Whether to estimate only mean.
+
+ """
+ assert in_channels % 2 == 0, "in_channels should be divisible by 2"
+ super().__init__()
+ self.half_channels = in_channels // 2
+ self.use_only_mean = use_only_mean
+
+ # define modules
+ self.input_conv = torch.nn.Conv1d(
+ self.half_channels,
+ hidden_channels,
+ 1,
+ )
+ self.encoder = WaveNet(
+ in_channels=-1,
+ out_channels=-1,
+ kernel_size=kernel_size,
+ layers=layers,
+ stacks=stacks,
+ base_dilation=base_dilation,
+ residual_channels=hidden_channels,
+ aux_channels=-1,
+ gate_channels=hidden_channels * 2,
+ skip_channels=hidden_channels,
+ global_channels=global_channels,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ use_weight_norm=use_weight_norm,
+ use_first_conv=False,
+ use_last_conv=False,
+ scale_residual=False,
+ scale_skip_connect=True,
+ )
+ if use_only_mean:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels,
+ 1,
+ )
+ else:
+ self.proj = torch.nn.Conv1d(
+ hidden_channels,
+ self.half_channels * 2,
+ 1,
+ )
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ g: Optional[torch.Tensor] = None,
+ inverse: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+ x_mask (Tensor): Mask tensor (B, 1, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+ inverse (bool): Whether to inverse the flow.
+
+ Returns:
+ Tensor: Output tensor (B, in_channels, T).
+ Tensor: Log-determinant tensor for NLL (B,) if not inverse.
+
+ """
+ xa, xb = x.split(x.size(1) // 2, dim=1)
+ h = self.input_conv(xa) * x_mask
+ h = self.encoder(h, x_mask, g=g)
+ stats = self.proj(h) * x_mask
+ if not self.use_only_mean:
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not inverse:
+ xb = m + xb * torch.exp(logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ xb = (xb - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([xa, xb], 1)
+ return x
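+
+
+# A small sanity-check sketch, not part of the original recipe. It relies only
+# on the block being an invertible flow: running it forward and then with
+# inverse=True should recover the input wherever the mask is 1 (the coupling
+# projections are zero-initialized, so this holds to within float precision).
+def _test_residual_affine_coupling_block():
+ flow = ResidualAffineCouplingBlock(in_channels=192, hidden_channels=192, flows=4)
+ x = torch.randn(2, 192, 50)
+ x_mask = torch.ones(2, 1, 50)
+ z = flow(x, x_mask)
+ x_hat = flow(z, x_mask, inverse=True)
+ assert torch.allclose(x, x_hat, atol=1e-5)
+ print(x.shape, z.shape, x_hat.shape)
+
+
+if __name__ == "__main__":
+ _test_residual_affine_coupling_block()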
diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py
new file mode 100755
index 0000000000..8acca7c026
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/test_onnx.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+#
+# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script is used to test the exported onnx model by vits/export-onnx.py
+
+Use the onnx model to generate a wav:
+./vits/test_onnx.py \
+ --model-filename vits/exp/vits-epoch-1000.onnx \
+ --tokens data/tokens.txt
+"""
+
+
+import argparse
+import logging
+import onnxruntime as ort
+import torch
+import torchaudio
+
+from tokenizer import Tokenizer
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--model-filename",
+ type=str,
+ required=True,
+ help="Path to the onnx model.",
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ return parser
+
+
+class OnnxModel:
+ def __init__(self, model_filename: str):
+ session_opts = ort.SessionOptions()
+ session_opts.inter_op_num_threads = 1
+ session_opts.intra_op_num_threads = 4
+
+ self.session_opts = session_opts
+
+ self.model = ort.InferenceSession(
+ model_filename,
+ sess_options=self.session_opts,
+ providers=["CPUExecutionProvider"],
+ )
+ logging.info(f"{self.model.get_modelmeta().custom_metadata_map}")
+
+ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ tokens:
+ A 2-D tensor of shape (1, T) containing the token IDs.
+ tokens_lens:
+ A 1-D tensor of shape (1,) containing the number of tokens in `tokens`.
+ Returns:
+ A 2-D tensor of shape (1, T') containing the generated audio samples.
+ """
+ noise_scale = torch.tensor([0.667], dtype=torch.float32)
+ noise_scale_dur = torch.tensor([0.8], dtype=torch.float32)
+ alpha = torch.tensor([1.0], dtype=torch.float32)
+
+ out = self.model.run(
+ [
+ self.model.get_outputs()[0].name,
+ ],
+ {
+ self.model.get_inputs()[0].name: tokens.numpy(),
+ self.model.get_inputs()[1].name: tokens_lens.numpy(),
+ self.model.get_inputs()[2].name: noise_scale.numpy(),
+ self.model.get_inputs()[3].name: noise_scale_dur.numpy(),
+ self.model.get_inputs()[4].name: alpha.numpy(),
+ },
+ )[0]
+ return torch.from_numpy(out)
+
+
+def main():
+ args = get_parser().parse_args()
+
+ tokenizer = Tokenizer(args.tokens)
+
+ logging.info("About to create onnx model")
+ model = OnnxModel(args.model_filename)
+
+ text = "I went there to see the land, the people and how their system works, end quote."
+ tokens = tokenizer.texts_to_token_ids([text])
+ tokens = torch.tensor(tokens) # (1, T)
+ tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1,)
+ audio = model(tokens, tokens_lens) # (1, T')
+
+ torchaudio.save("test_onnx.wav", audio, sample_rate=22050)
+ logging.info("Saved to test_onnx.wav")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/ljspeech/TTS/vits/text_encoder.py b/egs/ljspeech/TTS/vits/text_encoder.py
new file mode 100644
index 0000000000..9f337e45bb
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/text_encoder.py
@@ -0,0 +1,662 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Text encoder module in VITS.
+
+This code is based on
+ - https://github.com/jaywalnut310/vits
+ - https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/text_encoder.py
+ - https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/conformer.py
+"""
+
+import copy
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor, nn
+
+from icefall.utils import is_jit_tracing, make_pad_mask
+
+
+class TextEncoder(torch.nn.Module):
+ """Text encoder module in VITS.
+
+ This is a module of text encoder described in `Conditional Variational Autoencoder
+ with Adversarial Learning for End-to-End Text-to-Speech`.
+ """
+
+ def __init__(
+ self,
+ vocabs: int,
+ d_model: int = 192,
+ num_heads: int = 2,
+ dim_feedforward: int = 768,
+ cnn_module_kernel: int = 5,
+ num_layers: int = 6,
+ dropout: float = 0.1,
+ ):
+ """Initialize TextEncoder module.
+
+ Args:
+ vocabs (int): Vocabulary size.
+ d_model (int): attention dimension
+ num_heads (int): number of attention heads
+ dim_feedforward (int): feedforward dimension
+ cnn_module_kernel (int): convolution kernel size
+ num_layers (int): number of encoder layers
+ dropout (float): dropout rate
+ """
+ super().__init__()
+ self.d_model = d_model
+
+ # define modules
+ self.emb = torch.nn.Embedding(vocabs, d_model)
+ torch.nn.init.normal_(self.emb.weight, 0.0, d_model**-0.5)
+
+ # We use conformer as text encoder
+ self.encoder = Transformer(
+ d_model=d_model,
+ num_heads=num_heads,
+ dim_feedforward=dim_feedforward,
+ cnn_module_kernel=cnn_module_kernel,
+ num_layers=num_layers,
+ dropout=dropout,
+ )
+
+ self.proj = torch.nn.Conv1d(d_model, d_model * 2, 1)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input index tensor (B, T_text).
+ x_lengths (Tensor): Length tensor (B,).
+
+ Returns:
+ Tensor: Encoded hidden representation (B, attention_dim, T_text).
+ Tensor: Projected mean tensor (B, attention_dim, T_text).
+ Tensor: Projected scale tensor (B, attention_dim, T_text).
+ Tensor: Mask tensor for input tensor (B, 1, T_text).
+
+ """
+ # (B, T_text, embed_dim)
+ x = self.emb(x) * math.sqrt(self.d_model)
+
+ assert x.size(1) == x_lengths.max().item()
+
+ # (B, T_text)
+ pad_mask = make_pad_mask(x_lengths)
+
+ # the encoder assumes channels-last input (B, T_text, embed_dim)
+ x = self.encoder(x, key_padding_mask=pad_mask)
+
+ # convert to channels-first (B, embed_dim, T_text)
+ x = x.transpose(1, 2)
+ non_pad_mask = (~pad_mask).unsqueeze(1)
+ stats = self.proj(x) * non_pad_mask
+ m, logs = stats.split(stats.size(1) // 2, dim=1)
+
+ return x, m, logs, non_pad_mask
+
+
+class Transformer(nn.Module):
+ """
+ Args:
+ d_model (int): attention dimension
+ num_heads (int): number of attention heads
+ dim_feedforward (int): feedforward dimension
+ cnn_module_kernel (int): convolution kernel size
+ num_layers (int): number of encoder layers
+ dropout (float): dropout rate
+ """
+
+ def __init__(
+ self,
+ d_model: int = 192,
+ num_heads: int = 2,
+ dim_feedforward: int = 768,
+ cnn_module_kernel: int = 5,
+ num_layers: int = 6,
+ dropout: float = 0.1,
+ ) -> None:
+ super().__init__()
+
+ self.num_layers = num_layers
+ self.d_model = d_model
+
+ self.encoder_pos = RelPositionalEncoding(d_model, dropout)
+
+ encoder_layer = TransformerEncoderLayer(
+ d_model=d_model,
+ num_heads=num_heads,
+ dim_feedforward=dim_feedforward,
+ cnn_module_kernel=cnn_module_kernel,
+ dropout=dropout,
+ )
+ self.encoder = TransformerEncoder(encoder_layer, num_layers)
+ self.after_norm = nn.LayerNorm(d_model)
+
+ def forward(
+ self, x: Tensor, key_padding_mask: Tensor
+ ) -> Tensor:
+ """
+ Args:
+ x:
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+ key_padding_mask:
+ A boolean tensor of shape (batch_size, seq_len); True marks the padded
+ positions that should be ignored by the attention.
+ """
+ x, pos_emb = self.encoder_pos(x)
+ x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
+
+ x = self.encoder(
+ x, pos_emb, key_padding_mask=key_padding_mask
+ ) # (T, N, C)
+
+ x = self.after_norm(x)
+
+ x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
+ return x
+
+
+class TransformerEncoderLayer(nn.Module):
+ """
+ TransformerEncoderLayer is made up of self-attn and feedforward.
+
+ Args:
+ d_model: the number of expected features in the input.
+ num_heads: the number of heads in the multi-head attention models.
+ dim_feedforward: the dimension of the feed-forward network model.
+ dropout: the dropout value (default=0.1).
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ num_heads: int,
+ dim_feedforward: int,
+ cnn_module_kernel: int,
+ dropout: float = 0.1,
+ ) -> None:
+ super(TransformerEncoderLayer, self).__init__()
+
+ self.feed_forward_macaron = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.self_attn = RelPositionMultiheadAttention(d_model, num_heads, dropout=dropout)
+
+ self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
+
+ self.feed_forward = nn.Sequential(
+ nn.Linear(d_model, dim_feedforward),
+ Swish(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_feedforward, d_model),
+ )
+
+ self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module
+ self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
+ self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
+ self.norm_final = nn.LayerNorm(d_model) # for the final output of the block
+ self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
+
+ self.ff_scale = 0.5
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Pass the input through the transformer encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+ pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+ """
+ # macaron style feed-forward module
+ src = src + self.ff_scale * self.dropout(self.feed_forward_macaron(self.norm_ff_macaron(src)))
+
+ # multi-head self-attention module
+ src_attn = self.self_attn(
+ self.norm_mha(src),
+ pos_emb=pos_emb,
+ key_padding_mask=key_padding_mask,
+ )
+ src = src + self.dropout(src_attn)
+
+ # convolution module
+ src = src + self.dropout(self.conv_module(self.norm_conv(src)))
+
+ # feed-forward module
+ src = src + self.dropout(self.feed_forward(self.norm_ff(src)))
+
+ src = self.norm_final(src)
+
+ return src
+
+
+class TransformerEncoder(nn.Module):
+ r"""TransformerEncoder is a stack of N encoder layers
+
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer class.
+ num_layers: the number of sub-encoder-layers in the encoder.
+ """
+
+ def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+ super().__init__()
+
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder layer, of shape (seq_len, batch_size, embed_dim).
+ pos_emb: Positional embedding tensor, of shape (1, seq_len*2-1, pos_dim).
+ key_padding_mask: the mask for the src keys per batch, of shape (batch_size, seq_len)
+ """
+ output = src
+
+ for layer_index, mod in enumerate(self.layers):
+ output = mod(
+ output,
+ pos_emb,
+ key_padding_mask=key_padding_mask,
+ )
+
+ return output
+
+
+class RelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module.
+
+ See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
+
+ Args:
+ d_model: Embedding dimension.
+ dropout_rate: Dropout rate.
+ max_len: Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
+ """Construct an PositionalEncoding object."""
+ super(RelPositionalEncoding, self).__init__()
+
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: Tensor) -> None:
+ """Reset the positional encodings."""
+ x_size = x.size(1)
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x_size * 2 - 1:
+ # Note: TorchScript doesn't implement operator== for torch.Device
+ if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` means the position of the query vector and `j` means the
+ # position of the key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
+ pe_positive = torch.zeros(x_size, self.d_model)
+ pe_negative = torch.zeros(x_size, self.d_model)
+ position = torch.arange(0, x_size, dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ # Reverse the order of positive indices and concat both positive and
+ # negative indices. This is used to support the shifting trick
+ # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2
+ - x.size(1)
+ + 1 : self.pe.size(1) // 2 # noqa E203
+ + x.size(1),
+ ]
+ return self.dropout(x), self.dropout(pos_emb)
+
+
+class RelPositionMultiheadAttention(nn.Module):
+ r"""Multi-Head Attention layer with relative position encoding
+
+ See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+
+ Args:
+ embed_dim: total dimension of the model.
+ num_heads: parallel attention heads.
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ ) -> None:
+ super(RelPositionMultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+ # linear transformation for positional encoding.
+ self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
+
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ nn.init.xavier_uniform_(self.in_proj.weight)
+ nn.init.constant_(self.in_proj.bias, 0.0)
+ nn.init.constant_(self.out_proj.bias, 0.0)
+
+ nn.init.xavier_uniform_(self.pos_bias_u)
+ nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: Tensor) -> Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x: Input tensor (batch, head, seq_len, 2*seq_len-1).
+
+ Returns:
+ Tensor: tensor of shape (batch, head, seq_len, seq_len)
+ """
+ (batch_size, num_heads, seq_len, n) = x.shape
+
+ if not is_jit_tracing():
+ assert n == 2 * seq_len - 1, f"{n} == 2 * {seq_len} - 1"
+
+ if is_jit_tracing():
+ rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
+ cols = torch.arange(seq_len)
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+ indexes = rows + cols
+
+ x = x.reshape(-1, n)
+ x = torch.gather(x, dim=1, index=indexes)
+ x = x.reshape(batch_size, num_heads, seq_len, seq_len)
+ return x
+ else:
+ # Note: TorchScript requires explicit arg for stride()
+ batch_stride = x.stride(0)
+ head_stride = x.stride(1)
+ time_stride = x.stride(2)
+ n_stride = x.stride(3)
+ return x.as_strided(
+ (batch_size, num_heads, seq_len, seq_len),
+ (batch_stride, head_stride, time_stride - n_stride, n_stride),
+ storage_offset=n_stride * (seq_len - 1),
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ pos_emb: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Args:
+ x: Input tensor of shape (seq_len, batch_size, embed_dim)
+ pos_emb: Positional embedding tensor, (1, 2*seq_len-1, pos_dim)
+ key_padding_mask: if provided, specified padding elements in the key will
+ be ignored by the attention. This is an binary mask. When the value is True,
+ the corresponding value on the attention layer will be filled with -inf.
+ Its shape is (batch_size, seq_len).
+
+ Outputs:
+ A tensor of shape (seq_len, batch_size, embed_dim).
+ """
+ seq_len, batch_size, _ = x.shape
+ scaling = float(self.head_dim) ** -0.5
+
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
+
+ q = q.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ k = k.contiguous().view(seq_len, batch_size, self.num_heads, self.head_dim)
+ v = v.contiguous().view(seq_len, batch_size * self.num_heads, self.head_dim).transpose(0, 1)
+
+ q = q.transpose(0, 1) # (batch_size, seq_len, num_head, head_dim)
+
+ p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.num_heads, self.head_dim)
+ # (1, 2*seq_len, num_head, head_dim) -> (1, num_head, head_dim, 2*seq_len-1)
+ p = p.permute(0, 2, 3, 1)
+
+ # (batch_size, num_head, seq_len, head_dim)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
+ k = k.permute(1, 2, 3, 0) # (batch_size, num_head, head_dim, seq_len)
+ matrix_ac = torch.matmul(q_with_bias_u, k) # (batch_size, num_head, seq_len, seq_len)
+
+ # compute matrix b and matrix d
+ matrix_bd = torch.matmul(q_with_bias_v, p) # (batch_size, num_head, seq_len, 2*seq_len-1)
+ matrix_bd = self.rel_shift(matrix_bd) # (batch_size, num_head, seq_len, seq_len)
+
+ # (batch_size, num_head, seq_len, seq_len)
+ attn_output_weights = (matrix_ac + matrix_bd) * scaling
+ attn_output_weights = attn_output_weights.view(batch_size * self.num_heads, seq_len, seq_len)
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape == (batch_size, seq_len)
+ attn_output_weights = attn_output_weights.view(
+ batch_size, self.num_heads, seq_len, seq_len
+ )
+ attn_output_weights = attn_output_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ float("-inf"),
+ )
+ attn_output_weights = attn_output_weights.view(
+ batch_size * self.num_heads, seq_len, seq_len
+ )
+
+ attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
+ attn_output_weights = nn.functional.dropout(
+ attn_output_weights, p=self.dropout, training=self.training
+ )
+
+ # (batch_size * num_head, seq_len, head_dim)
+ attn_output = torch.bmm(attn_output_weights, v)
+ assert attn_output.shape == (batch_size * self.num_heads, seq_len, self.head_dim)
+
+ attn_output = (
+ attn_output.transpose(0, 1).contiguous().view(seq_len, batch_size, self.embed_dim)
+ )
+ # (seq_len, batch_size, embed_dim)
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
+
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ bias (bool): Whether to use bias in conv layers (default=True).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ bias: bool = True,
+ ) -> None:
+ """Construct an ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+ # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+
+ padding = (kernel_size - 1) // 2
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.LayerNorm(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = Swish()
+
+ def forward(
+ self,
+ x: Tensor,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ """Compute convolution module.
+
+ Args:
+ x: Input tensor (#time, batch, channels).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Returns:
+ Tensor: Output tensor (#time, batch, channels).
+
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
+
+ # GLU mechanism
+ x = self.pointwise_conv1(x) # (batch, 2*channels, time)
+ x = nn.functional.glu(x, dim=1) # (batch, channels, time)
+
+ # 1D Depthwise Conv
+ if src_key_padding_mask is not None:
+ x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
+ x = self.depthwise_conv(x)
+ # x is (batch, channels, time)
+ x = x.permute(0, 2, 1)
+ x = self.norm(x)
+ x = x.permute(0, 2, 1)
+
+ x = self.activation(x)
+
+ x = self.pointwise_conv2(x) # (batch, channel, time)
+
+ return x.permute(2, 0, 1)
+
+
+class Swish(nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Return Swich activation function."""
+ return x * torch.sigmoid(x)
+
+
+def _test_text_encoder():
+ vocabs = 500
+ d_model = 192
+ batch_size = 5
+ seq_len = 100
+
+ model = TextEncoder(vocabs=vocabs, d_model=d_model)
+ x, m, logs, mask = model(
+ x=torch.randint(low=0, high=vocabs, size=(batch_size, seq_len)),
+ x_lengths=torch.full((batch_size,), seq_len),
+ )
+ print(x.shape, m.shape, logs.shape, mask.shape)
+
+
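+# A self-check sketch, not part of the original file (run it manually if
+# desired). It verifies the strided rel_shift trick against naive indexing:
+# output[..., i, j] should equal input[..., i, (seq_len - 1) + (j - i)], i.e.
+# the score for relative position j - i, shifted by seq_len - 1.
+def _test_rel_shift():
+ batch_size, num_heads, seq_len = 2, 4, 7
+ attn = RelPositionMultiheadAttention(embed_dim=16, num_heads=num_heads)
+ x = torch.randn(batch_size, num_heads, seq_len, 2 * seq_len - 1)
+ y = attn.rel_shift(x)
+ expected = torch.empty(batch_size, num_heads, seq_len, seq_len)
+ for i in range(seq_len):
+ for j in range(seq_len):
+ expected[..., i, j] = x[..., i, (seq_len - 1) + (j - i)]
+ assert torch.allclose(y, expected)
+
+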
+if __name__ == "__main__":
+ _test_text_encoder()
diff --git a/egs/ljspeech/TTS/vits/tokenizer.py b/egs/ljspeech/TTS/vits/tokenizer.py
new file mode 100644
index 0000000000..0678b26fe0
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tokenizer.py
@@ -0,0 +1,106 @@
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import g2p_en
+import tacotron_cleaner.cleaners
+from utils import intersperse
+
+
+class Tokenizer(object):
+ def __init__(self, tokens: str):
+ """
+ Args:
+ tokens: the file that maps tokens to ids
+ """
+ # Parse token file
+ self.token2id: Dict[str, int] = {}
+ with open(tokens, "r", encoding="utf-8") as f:
+ for line in f.readlines():
+ info = line.rstrip().split()
+ if len(info) == 1:
+ # case of space
+ token = " "
+ id = int(info[0])
+ else:
+ token, id = info[0], int(info[1])
+ self.token2id[token] = id
+
+ self.blank_id = self.token2id["<blk>"]
+ self.oov_id = self.token2id["<unk>"]
+ self.vocab_size = len(self.token2id)
+
+ self.g2p = g2p_en.G2p()
+
+ def texts_to_token_ids(self, texts: List[str], intersperse_blank: bool = True):
+ """
+ Args:
+ texts:
+ A list of transcripts.
+ intersperse_blank:
+ Whether to intersperse blanks in the token sequence.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ token_ids_list = []
+
+ for text in texts:
+ # Text normalization
+ text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
+ # Convert to phonemes
+ tokens = self.g2p(text)
+ token_ids = []
+ for t in tokens:
+ if t in self.token2id:
+ token_ids.append(self.token2id[t])
+ else:
+ token_ids.append(self.oov_id)
+
+ if intersperse_blank:
+ token_ids = intersperse(token_ids, self.blank_id)
+
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
+
+ def tokens_to_token_ids(self, tokens_list: List[str], intersperse_blank: bool = True):
+ """
+ Args:
+ tokens_list:
+ A list of token list, each corresponding to one utterance.
+ intersperse_blank:
+ Whether to intersperse blanks in the token sequence.
+
+ Returns:
+ Return a list of token id list [utterance][token_id]
+ """
+ token_ids_list = []
+
+ for tokens in tokens_list:
+ token_ids = []
+ for t in tokens:
+ if t in self.token2id:
+ token_ids.append(self.token2id[t])
+ else:
+ token_ids.append(self.oov_id)
+
+ if intersperse_blank:
+ token_ids = intersperse(token_ids, self.blank_id)
+ token_ids_list.append(token_ids)
+
+ return token_ids_list
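+
+
+# A minimal usage sketch, not part of the original recipe. It writes a tiny
+# hypothetical token table to a temporary file instead of using data/tokens.txt
+# (which only exists after running prepare.sh). Constructing Tokenizer still
+# instantiates g2p_en.G2p, which may download NLTK resources on first use; the
+# expected IDs below assume the usual VITS-style intersperse helper in utils.py.
+def _test_tokenizer():
+ import tempfile
+
+ with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
+ for i, token in enumerate(["<blk>", "<unk>", "HH", "AH0", "L", "OW1"]):
+ f.write(f"{token} {i}\n")
+ token_file = f.name
+
+ tokenizer = Tokenizer(token_file)
+ # "HH AH0 L OW1" is the ARPAbet sequence for "hello"; "?" falls back to <unk>.
+ token_ids = tokenizer.tokens_to_token_ids([["HH", "AH0", "L", "OW1", "?"]])
+ # With intersperse_blank=True (the default), blank IDs (0) separate tokens,
+ # e.g. [[0, 2, 0, 3, 0, 4, 0, 5, 0, 1, 0]].
+ print(token_ids)
+
+
+if __name__ == "__main__":
+ _test_tokenizer()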
diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py
new file mode 100755
index 0000000000..eb43a4cc93
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/train.py
@@ -0,0 +1,893 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+import numpy as np
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from lhotse.cut import Cut
+from lhotse.utils import fix_random_seed
+from torch.optim import Optimizer
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import AttributeDict, setup_logger, str2bool
+
+from tokenizer import Tokenizer
+from tts_datamodule import LJSpeechTtsDataModule
+from utils import MetricsTracker, plot_feature, save_checkpoint
+from vits import VITS
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=1000,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="vits/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--tokens",
+ type=str,
+ default="data/tokens.txt",
+ help="""Path to vocabulary.""",
+ )
+
+ parser.add_argument(
+ "--lr", type=float, default=2.0e-4, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=20,
+ help="""Save checkpoint after processing this number of epochs"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.cur_epoch % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'.
+ Since it will take around 1000 epochs, we suggest using a large
+ save_every_n to save disk space.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+ - batch_idx_train: Used to write statistics to tensorboard. It
+ contains the number of batches trained so far across
+ epochs.
+
+ - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+ """
+ params = AttributeDict(
+ {
+ # training params
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": -1, # 0
+ "log_interval": 50,
+ "valid_interval": 200,
+ "env_info": get_env_info(),
+ "sampling_rate": 22050,
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length
+ "n_mels": 80,
+ "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss
+ "lambda_mel": 45.0, # loss scaling coefficient for Mel loss
+ "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss
+ "lambda_dur": 1.0, # loss scaling coefficient for duration loss
+ "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict, model: nn.Module
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(filename, model=model)
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ return saved_params
+
+
+def get_model(params: AttributeDict) -> nn.Module:
+ mel_loss_params = {
+ "n_mels": params.n_mels,
+ "frame_length": params.frame_length,
+ "frame_shift": params.frame_shift,
+ }
+ model = VITS(
+ vocab_size=params.vocab_size,
+ feature_dim=params.feature_dim,
+ sampling_rate=params.sampling_rate,
+ mel_loss_params=mel_loss_params,
+ lambda_adv=params.lambda_adv,
+ lambda_mel=params.lambda_mel,
+ lambda_feat_match=params.lambda_feat_match,
+ lambda_dur=params.lambda_dur,
+ lambda_kl=params.lambda_kl,
+ )
+ return model
+
+
+def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
+ """Parse batch data"""
+ audio = batch["audio"].to(device)
+ features = batch["features"].to(device)
+ audio_lens = batch["audio_lens"].to(device)
+ features_lens = batch["features_lens"].to(device)
+ tokens = batch["tokens"]
+
+ tokens = tokenizer.tokens_to_token_ids(tokens)
+ tokens = k2.RaggedTensor(tokens)
+ row_splits = tokens.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1]
+ tokens = tokens.to(device)
+ tokens_lens = tokens_lens.to(device)
+ # a tensor of shape (B, T)
+ tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
+
+ return audio, audio_lens, features, features_lens, tokens, tokens_lens
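+
+
+# The k2.RaggedTensor round trip in prepare_input() is the only non-obvious
+# step above, so here is a tiny illustrative sketch (not called anywhere in
+# this recipe) showing how a ragged list of token IDs becomes a padded tensor
+# plus per-utterance lengths. The example values are hypothetical.
+def _demo_ragged_padding():
+ ragged = k2.RaggedTensor([[5, 7, 9], [3, 4]])
+ row_splits = ragged.shape.row_splits(1)
+ tokens_lens = row_splits[1:] - row_splits[:-1] # per-utterance lengths: 3 and 2
+ # Pads to shape (2, 3); padding_value plays the role of tokenizer.blank_id.
+ tokens = ragged.pad(mode="constant", padding_value=0)
+ return tokens, tokens_lens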
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ optimizer_g: Optimizer,
+ optimizer_d: Optimizer,
+ scheduler_g: LRSchedulerType,
+ scheduler_d: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ tokenizer:
+ Used to convert text to phonemes.
+ optimizer_g:
+ The optimizer for generator.
+ optimizer_d:
+ The optimizer for discriminator.
+ scheduler_g:
+ The learning rate scheduler for generator, we call step() every epoch.
+ scheduler_d:
+ The learning rate scheduler for discriminator, we call step() every epoch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mixed precision training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations in one epoch
+ tot_loss = MetricsTracker()
+
+ saved_bad_model = False
+
+ def save_bad_model(suffix: str = ""):
+ save_checkpoint(
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
+ model=model,
+ params=params,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=0,
+ )
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ try:
+ with autocast(enabled=params.use_fp16):
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+ # update discriminator
+ optimizer_d.zero_grad()
+ scaler.scale(loss_d).backward()
+ scaler.step(optimizer_d)
+
+ with autocast(enabled=params.use_fp16):
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ return_sample=params.batch_idx_train % params.log_interval == 0,
+ )
+ for k, v in stats_g.items():
+ if "returned_sample" not in k:
+ loss_info[k] = v * batch_size
+ # update generator
+ optimizer_g.zero_grad()
+ scaler.scale(loss_g).backward()
+ scaler.step(optimizer_g)
+ scaler.update()
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+ except: # noqa
+ save_bad_model()
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+
+ if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ if not saved_bad_model:
+ save_bad_model(suffix="-first-warning")
+ saved_bad_model = True
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ save_bad_model()
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if params.batch_idx_train % params.log_interval == 0:
+ cur_lr_g = max(scheduler_g.get_last_lr())
+ cur_lr_d = max(scheduler_d.get_last_lr())
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+ f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, "
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
+ f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate_g", cur_lr_g, params.batch_idx_train
+ )
+ tb_writer.add_scalar(
+ "train/learning_rate_d", cur_lr_d, params.batch_idx_train
+ )
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(
+ tb_writer, "train/tot_", params.batch_idx_train
+ )
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+ if "returned_sample" in stats_g:
+ speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"]
+ tb_writer.add_audio(
+ "train/speech_hat_", speech_hat_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/speech_", speech_, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_image(
+ "train/mel_hat_", plot_feature(mel_hat_), params.batch_idx_train, dataformats='HWC'
+ )
+ tb_writer.add_image(
+ "train/mel_", plot_feature(mel_), params.batch_idx_train, dataformats='HWC'
+ )
+
+ if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info, (speech_hat, speech) = compute_validation_loss(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate
+ )
+ tb_writer.add_audio(
+ "train/valdi_speech", speech, params.batch_idx_train, params.sampling_rate
+ )
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ tokenizer: Tokenizer,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+ rank: int = 0,
+) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]:
+ """Run the validation process."""
+ model.eval()
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+
+ # used to summarize the stats over iterations
+ tot_loss = MetricsTracker()
+ returned_sample = None
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(valid_dl):
+ batch_size = len(batch["tokens"])
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+
+ loss_info = MetricsTracker()
+ loss_info['samples'] = batch_size
+
+ # forward discriminator
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ assert loss_d.requires_grad is False
+ for k, v in stats_d.items():
+ loss_info[k] = v * batch_size
+
+ # forward generator
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ assert loss_g.requires_grad is False
+ for k, v in stats_g.items():
+ loss_info[k] = v * batch_size
+
+ # summary stats
+ tot_loss = tot_loss + loss_info
+
+ # infer for first batch:
+ if batch_idx == 0 and rank == 0:
+ inner_model = model.module if isinstance(model, DDP) else model
+ audio_pred, _, duration = inner_model.inference(
+ text=tokens[0, :tokens_lens[0].item()]
+ )
+ audio_pred = audio_pred.data.cpu().numpy()
+ audio_len_pred = (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item()
+ assert audio_len_pred == len(audio_pred), (audio_len_pred, len(audio_pred))
+ audio_gt = audio[0, :audio_lens[0].item()].data.cpu().numpy()
+ returned_sample = (audio_pred, audio_gt)
+
+ if world_size > 1:
+ tot_loss.reduce(device)
+
+ loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss, returned_sample
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ tokenizer: Tokenizer,
+ optimizer_g: torch.optim.Optimizer,
+ optimizer_d: torch.optim.Optimizer,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ audio, audio_lens, features, features_lens, tokens, tokens_lens = \
+ prepare_input(batch, tokenizer, device)
+ try:
+ # for discriminator
+ with autocast(enabled=params.use_fp16):
+ loss_d, stats_d = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=False,
+ )
+ optimizer_d.zero_grad()
+ loss_d.backward()
+ # for generator
+ with autocast(enabled=params.use_fp16):
+ loss_g, stats_g = model(
+ text=tokens,
+ text_lengths=tokens_lens,
+ feats=features,
+ feats_lengths=features_lens,
+ speech=audio,
+ speech_lengths=audio_lens,
+ forward_generator=True,
+ )
+ optimizer_g.zero_grad()
+ loss_g.backward()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ tokenizer = Tokenizer(params.tokens)
+ params.blank_id = tokenizer.blank_id
+ params.oov_id = tokenizer.oov_id
+ params.vocab_size = tokenizer.vocab_size
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_model(params)
+ generator = model.generator
+ discriminator = model.discriminator
+
+ num_param_g = sum([p.numel() for p in generator.parameters()])
+ logging.info(f"Number of parameters in generator: {num_param_g}")
+ num_param_d = sum([p.numel() for p in discriminator.parameters()])
+ logging.info(f"Number of parameters in discriminator: {num_param_d}")
+ logging.info(f"Total number of parameters: {num_param_g + num_param_d}")
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer_g = torch.optim.AdamW(
+ generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+ optimizer_d = torch.optim.AdamW(
+ discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9
+ )
+
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875)
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875)
+
+ if checkpoints is not None:
+ # load state_dict for optimizers
+ if "optimizer_g" in checkpoints:
+ logging.info("Loading optimizer_g state dict")
+ optimizer_g.load_state_dict(checkpoints["optimizer_g"])
+ if "optimizer_d" in checkpoints:
+ logging.info("Loading optimizer_d state dict")
+ optimizer_d.load_state_dict(checkpoints["optimizer_d"])
+
+ # load state_dict for schedulers
+ if "scheduler_g" in checkpoints:
+ logging.info("Loading scheduler_g state dict")
+ scheduler_g.load_state_dict(checkpoints["scheduler_g"])
+ if "scheduler_d" in checkpoints:
+ logging.info("Loading scheduler_d state dict")
+ scheduler_d.load_state_dict(checkpoints["scheduler_d"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ ljspeech = LJSpeechTtsDataModule(args)
+
+ train_cuts = ljspeech.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
+ return False
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+ train_dl = ljspeech.train_dataloaders(train_cuts)
+
+ valid_cuts = ljspeech.valid_cuts()
+ valid_dl = ljspeech.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ params=params,
+ )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ logging.info(f"Start epoch {epoch}")
+
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ params.cur_epoch = epoch
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ tokenizer=tokenizer,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if epoch % params.save_every_n == 0 or epoch == params.num_epochs:
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint(
+ filename=filename,
+ params=params,
+ model=model,
+ optimizer_g=optimizer_g,
+ optimizer_d=optimizer_d,
+ scheduler_g=scheduler_g,
+ scheduler_d=scheduler_d,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ if rank == 0:
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+ # step per epoch
+ scheduler_g.step()
+ scheduler_d.step()
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def main():
+ parser = get_parser()
+ LJSpeechTtsDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/ljspeech/TTS/vits/transform.py b/egs/ljspeech/TTS/vits/transform.py
new file mode 100644
index 0000000000..c20d13130a
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/transform.py
@@ -0,0 +1,218 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/transform.py
+
+"""Flow-related transformation.
+
+This code is derived from https://github.com/bayesiains/nflows.
+
+"""
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+DEFAULT_MIN_BIN_WIDTH = 1e-3
+DEFAULT_MIN_BIN_HEIGHT = 1e-3
+DEFAULT_MIN_DERIVATIVE = 1e-3
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def piecewise_rational_quadratic_transform(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails=None,
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
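+    # Without tails, the spline is defined on the unit interval [0, 1]; with
+    # "linear" tails, it acts as the identity outside [-tail_bound, tail_bound]
+    # and applies the spline inside that interval.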
+ if tails is None:
+ spline_fn = rational_quadratic_spline
+ spline_kwargs = {}
+ else:
+ spline_fn = unconstrained_rational_quadratic_spline
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
+
+ outputs, logabsdet = spline_fn(
+ inputs=inputs,
+ unnormalized_widths=unnormalized_widths,
+ unnormalized_heights=unnormalized_heights,
+ unnormalized_derivatives=unnormalized_derivatives,
+ inverse=inverse,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ **spline_kwargs
+ )
+ return outputs, logabsdet
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def unconstrained_rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ tails="linear",
+ tail_bound=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
+ outside_interval_mask = ~inside_interval_mask
+
+ outputs = torch.zeros_like(inputs)
+ logabsdet = torch.zeros_like(inputs)
+
+ if tails == "linear":
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
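+        # The constant below is the inverse softplus of (1 - min_derivative), so
+        # that after the softplus in rational_quadratic_spline the boundary
+        # derivatives are exactly 1, matching the identity (linear) tails.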
+ constant = np.log(np.exp(1 - min_derivative) - 1)
+ unnormalized_derivatives[..., 0] = constant
+ unnormalized_derivatives[..., -1] = constant
+
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
+ logabsdet[outside_interval_mask] = 0
+ else:
+ raise RuntimeError("{} tails are not implemented.".format(tails))
+
+ (
+ outputs[inside_interval_mask],
+ logabsdet[inside_interval_mask],
+ ) = rational_quadratic_spline(
+ inputs=inputs[inside_interval_mask],
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
+ inverse=inverse,
+ left=-tail_bound,
+ right=tail_bound,
+ bottom=-tail_bound,
+ top=tail_bound,
+ min_bin_width=min_bin_width,
+ min_bin_height=min_bin_height,
+ min_derivative=min_derivative,
+ )
+
+ return outputs, logabsdet
+
+
+# TODO(kan-bayashi): Documentation and type hint
+def rational_quadratic_spline(
+ inputs,
+ unnormalized_widths,
+ unnormalized_heights,
+ unnormalized_derivatives,
+ inverse=False,
+ left=0.0,
+ right=1.0,
+ bottom=0.0,
+ top=1.0,
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
+):
+ if torch.min(inputs) < left or torch.max(inputs) > right:
+ raise ValueError("Input to a transform is not within its domain")
+
+ num_bins = unnormalized_widths.shape[-1]
+
+ if min_bin_width * num_bins > 1.0:
+ raise ValueError("Minimal bin width too large for the number of bins")
+ if min_bin_height * num_bins > 1.0:
+ raise ValueError("Minimal bin height too large for the number of bins")
+
+ widths = F.softmax(unnormalized_widths, dim=-1)
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
+ cumwidths = torch.cumsum(widths, dim=-1)
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
+ cumwidths = (right - left) * cumwidths + left
+ cumwidths[..., 0] = left
+ cumwidths[..., -1] = right
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
+
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
+
+ heights = F.softmax(unnormalized_heights, dim=-1)
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
+ cumheights = torch.cumsum(heights, dim=-1)
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
+ cumheights = (top - bottom) * cumheights + bottom
+ cumheights[..., 0] = bottom
+ cumheights[..., -1] = top
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
+
+ if inverse:
+ bin_idx = _searchsorted(cumheights, inputs)[..., None]
+ else:
+ bin_idx = _searchsorted(cumwidths, inputs)[..., None]
+
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
+
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
+ delta = heights / widths
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
+
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
+
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
+
+ if inverse:
+ a = (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ ) + input_heights * (input_delta - input_derivatives)
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
+ )
+ c = -input_delta * (inputs - input_cumheights)
+
+ discriminant = b.pow(2) - 4 * a * c
+ assert (discriminant >= 0).all()
+
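+        # Solve a * theta^2 + b * theta + c = 0 for theta with the numerically
+        # stable form 2c / (-b - sqrt(b^2 - 4ac)).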
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
+ outputs = root * input_bin_widths + input_cumwidths
+
+ theta_one_minus_theta = root * (1 - root)
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * root.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - root).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, -logabsdet
+ else:
+ theta = (inputs - input_cumwidths) / input_bin_widths
+ theta_one_minus_theta = theta * (1 - theta)
+
+ numerator = input_heights * (
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
+ )
+ denominator = input_delta + (
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
+ * theta_one_minus_theta
+ )
+ outputs = input_cumheights + numerator / denominator
+
+ derivative_numerator = input_delta.pow(2) * (
+ input_derivatives_plus_one * theta.pow(2)
+ + 2 * input_delta * theta_one_minus_theta
+ + input_derivatives * (1 - theta).pow(2)
+ )
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
+
+ return outputs, logabsdet
+
+
+def _searchsorted(bin_locations, inputs, eps=1e-6):
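+    # e.g., bin_locations = [0.0, 0.3, 0.7, 1.0] and inputs = 0.5 give bin
+    # index 1, i.e., the [0.3, 0.7) bin.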
+ bin_locations[..., -1] += eps
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
diff --git a/egs/ljspeech/TTS/vits/tts_datamodule.py b/egs/ljspeech/TTS/vits/tts_datamodule.py
new file mode 100644
index 0000000000..0fcbb92c16
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/tts_datamodule.py
@@ -0,0 +1,325 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutConcatenate,
+ CutMix,
+ DynamicBucketingSampler,
+ SpeechSynthesisDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
+ AudioSamples,
+ OnTheFlyFeatures,
+)
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LJSpeechTtsDataModule:
+ """
+    DataModule for TTS experiments.
+    It assumes there is always one train and one valid dataloader,
+    but there can be multiple test dataloaders (e.g., multiple
+    evaluation sets).
+
+    It contains all the common data pipeline modules used in TTS
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in TTS tasks.
+    """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="TTS data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/spectrogram"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help="When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available.",
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=False,
+ help="When enabled, each batch will have the "
+ "field: batch['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--input-strategy",
+ type=str,
+ default="PrecomputedFeatures",
+ help="AudioSamples or PrecomputedFeatures",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ logging.info("About to create train dataset")
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
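+            # 1024-sample window and 256-sample hop, expressed in seconds
+            # for the 22050 Hz LJSpeech audio.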
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ train = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SimpleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ logging.info("About to create dev dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ validate = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create valid dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.info("About to create test dataset")
+ if self.args.on_the_fly_feats:
+ sampling_rate = 22050
+ config = SpectrogramConfig(
+ sampling_rate=sampling_rate,
+                frame_length=1024 / sampling_rate,  # (in seconds)
+                frame_shift=256 / sampling_rate,  # (in seconds)
+ use_fft_mag=True,
+ )
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)),
+ return_cuts=self.args.return_cuts,
+ )
+ else:
+ test = SpeechSynthesisDataset(
+ return_text=False,
+ return_tokens=True,
+ feature_input_strategy=eval(self.args.input_strategy)(),
+ return_cuts=self.args.return_cuts,
+ )
+ test_sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration,
+ shuffle=False,
+ )
+ logging.info("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=test_sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def train_cuts(self) -> CutSet:
+ logging.info("About to get train cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz"
+ )
+
+ @lru_cache()
+ def valid_cuts(self) -> CutSet:
+ logging.info("About to get validation cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz"
+ )
+
+ @lru_cache()
+ def test_cuts(self) -> CutSet:
+ logging.info("About to get test cuts")
+ return load_manifest_lazy(
+ self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz"
+ )
diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py
new file mode 100644
index 0000000000..2a3dae9007
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/utils.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, List, Optional, Tuple, Union
+import collections
+import logging
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from lhotse.dataset.sampling.base import CutSampler
+from pathlib import Path
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.tensorboard import SummaryWriter
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_random_segments(
+ x: torch.Tensor,
+ x_lengths: torch.Tensor,
+ segment_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get random segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ x_lengths (Tensor): Length tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+ Tensor: Start index tensor (B,).
+
+ """
+ b, c, t = x.size()
+ max_start_idx = x_lengths - segment_size
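+    # clamp to zero so that utterances shorter than the segment size start at 0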
+ max_start_idx[max_start_idx < 0] = 0
+ start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
+ dtype=torch.long,
+ )
+ segments = get_segments(x, start_idxs, segment_size)
+
+ return segments, start_idxs
+
+
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
+def get_segments(
+ x: torch.Tensor,
+ start_idxs: torch.Tensor,
+ segment_size: int,
+) -> torch.Tensor:
+ """Get segments.
+
+ Args:
+ x (Tensor): Input tensor (B, C, T).
+ start_idxs (Tensor): Start index tensor (B,).
+ segment_size (int): Segment size.
+
+ Returns:
+ Tensor: Segmented tensor (B, C, segment_size).
+
+ """
+ b, c, t = x.size()
+ segments = x.new_zeros(b, c, segment_size)
+ for i, start_idx in enumerate(start_idxs):
+ segments[i] = x[i, :, start_idx : start_idx + segment_size]
+ return segments
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/commons.py
+def intersperse(sequence, item=0):
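+    # e.g., intersperse([1, 2, 3], 0) -> [0, 1, 0, 2, 0, 3, 0]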
+ result = [item] * (len(sequence) * 2 + 1)
+ result[1::2] = sequence
+ return result
+
+
+# from https://github.com/jaywalnut310/vits/blob/main/utils.py
+MATPLOTLIB_FLAG = False
+
+
+def plot_feature(spectrogram):
+ global MATPLOTLIB_FLAG
+ if not MATPLOTLIB_FLAG:
+ import matplotlib
+ matplotlib.use("Agg")
+ MATPLOTLIB_FLAG = True
+ mpl_logger = logging.getLogger('matplotlib')
+ mpl_logger.setLevel(logging.WARNING)
+ import matplotlib.pylab as plt
+ import numpy as np
+
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+ return data
+
+
+class MetricsTracker(collections.defaultdict):
+ def __init__(self):
+ # Passing the type 'int' to the base-class constructor
+ # makes undefined items default to int() which is zero.
+ # This class will play a role as metrics tracker.
+ # It can record many metrics, including but not limited to loss.
+ super(MetricsTracker, self).__init__(int)
+
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v
+ for k, v in other.items():
+ ans[k] = ans[k] + v
+ return ans
+
+ def __mul__(self, alpha: float) -> "MetricsTracker":
+ ans = MetricsTracker()
+ for k, v in self.items():
+ ans[k] = v * alpha
+ return ans
+
+ def __str__(self) -> str:
+ ans = ""
+ for k, v in self.norm_items():
+ norm_value = "%.4g" % v
+ ans += str(k) + "=" + str(norm_value) + ", "
+ samples = "%.2f" % self["samples"]
+ ans += "over " + str(samples) + " samples."
+ return ans
+
+ def norm_items(self) -> List[Tuple[str, float]]:
+ """
+ Returns a list of pairs, like:
+ [('loss_1', 0.1), ('loss_2', 0.07)]
+ """
+ samples = self["samples"] if "samples" in self else 1
+ ans = []
+ for k, v in self.items():
+ if k == "samples":
+ continue
+ norm_value = float(v) / samples
+ ans.append((k, norm_value))
+ return ans
+
+ def reduce(self, device):
+ """
+        Reduce using torch.distributed so that all processes
+        get the total.
+ """
+ keys = sorted(self.keys())
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
+ for k, v in zip(keys, s.cpu().tolist()):
+ self[k] = v
+
+ def write_summary(
+ self,
+ tb_writer: SummaryWriter,
+ prefix: str,
+ batch_idx: int,
+ ) -> None:
+ """Add logging information to a TensorBoard writer.
+
+ Args:
+ tb_writer: a TensorBoard writer
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
+ or "train/current_"
+ batch_idx: The current batch index, used as the x-axis of the plot.
+ """
+ for k, v in self.norm_items():
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+# checkpoint saving and loading
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def save_checkpoint(
+ filename: Path,
+ model: Union[nn.Module, DDP],
+ params: Optional[Dict[str, Any]] = None,
+ optimizer_g: Optional[Optimizer] = None,
+ optimizer_d: Optional[Optimizer] = None,
+ scheduler_g: Optional[LRSchedulerType] = None,
+ scheduler_d: Optional[LRSchedulerType] = None,
+ scaler: Optional[GradScaler] = None,
+ sampler: Optional[CutSampler] = None,
+ rank: int = 0,
+) -> None:
+ """Save training information to a file.
+
+ Args:
+ filename:
+ The checkpoint filename.
+ model:
+ The model to be saved. We only save its `state_dict()`.
+ params:
+ User defined parameters, e.g., epoch, loss.
+ optimizer_g:
+ The optimizer for generator used in the training.
+ Its `state_dict` will be saved.
+ optimizer_d:
+ The optimizer for discriminator used in the training.
+ Its `state_dict` will be saved.
+ scheduler_g:
+ The learning rate scheduler for generator used in the training.
+ Its `state_dict` will be saved.
+ scheduler_d:
+ The learning rate scheduler for discriminator used in the training.
+ Its `state_dict` will be saved.
+      scaler:
+        The GradScaler to be saved. We only save its `state_dict()`.
+      sampler:
+        The sampler used in the training. Its `state_dict()` will be saved.
+ rank:
+ Used in DDP. We save checkpoint only for the node whose rank is 0.
+ Returns:
+ Return None.
+ """
+ if rank != 0:
+ return
+
+ logging.info(f"Saving checkpoint to {filename}")
+
+ if isinstance(model, DDP):
+ model = model.module
+
+ checkpoint = {
+ "model": model.state_dict(),
+ "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
+ "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
+ "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
+ "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
+ "sampler": sampler.state_dict() if sampler is not None else None,
+ }
+
+ if params:
+ for k, v in params.items():
+ assert k not in checkpoint
+ checkpoint[k] = v
+
+ torch.save(checkpoint, filename)
diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py
new file mode 100644
index 0000000000..d5e20a5787
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/vits.py
@@ -0,0 +1,610 @@
+# based on https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""VITS module for GAN-TTS task."""
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from hifigan import (
+ HiFiGANMultiPeriodDiscriminator,
+ HiFiGANMultiScaleDiscriminator,
+ HiFiGANMultiScaleMultiPeriodDiscriminator,
+ HiFiGANPeriodDiscriminator,
+ HiFiGANScaleDiscriminator,
+)
+from loss import (
+ DiscriminatorAdversarialLoss,
+ FeatureMatchLoss,
+ GeneratorAdversarialLoss,
+ KLDivergenceLoss,
+ MelSpectrogramLoss,
+)
+from utils import get_segments
+from generator import VITSGenerator
+
+
+AVAILABLE_GENERATORS = {
+ "vits_generator": VITSGenerator,
+}
+AVAILABLE_DISCRIMINATORS = {
+ "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
+ "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
+ "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
+ "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
+ "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, # NOQA
+}
+
+
+class VITS(nn.Module):
+ """Implement VITS, `Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`
+ """
+
+ def __init__(
+ self,
+ # generator related
+ vocab_size: int,
+ feature_dim: int = 513,
+ sampling_rate: int = 22050,
+ generator_type: str = "vits_generator",
+ generator_params: Dict[str, Any] = {
+ "hidden_channels": 192,
+ "spks": None,
+ "langs": None,
+ "spk_embed_dim": None,
+ "global_channels": -1,
+ "segment_size": 32,
+ "text_encoder_attention_heads": 2,
+ "text_encoder_ffn_expand": 4,
+ "text_encoder_cnn_module_kernel": 5,
+ "text_encoder_blocks": 6,
+ "text_encoder_dropout_rate": 0.1,
+ "decoder_kernel_size": 7,
+ "decoder_channels": 512,
+ "decoder_upsample_scales": [8, 8, 2, 2],
+ "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
+ "decoder_resblock_kernel_sizes": [3, 7, 11],
+ "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "use_weight_norm_in_decoder": True,
+ "posterior_encoder_kernel_size": 5,
+ "posterior_encoder_layers": 16,
+ "posterior_encoder_stacks": 1,
+ "posterior_encoder_base_dilation": 1,
+ "posterior_encoder_dropout_rate": 0.0,
+ "use_weight_norm_in_posterior_encoder": True,
+ "flow_flows": 4,
+ "flow_kernel_size": 5,
+ "flow_base_dilation": 1,
+ "flow_layers": 4,
+ "flow_dropout_rate": 0.0,
+ "use_weight_norm_in_flow": True,
+ "use_only_mean_in_flow": True,
+ "stochastic_duration_predictor_kernel_size": 3,
+ "stochastic_duration_predictor_dropout_rate": 0.5,
+ "stochastic_duration_predictor_flows": 4,
+ "stochastic_duration_predictor_dds_conv_layers": 3,
+ },
+ # discriminator related
+ discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
+ discriminator_params: Dict[str, Any] = {
+ "scales": 1,
+ "scale_downsample_pooling": "AvgPool1d",
+ "scale_downsample_pooling_params": {
+ "kernel_size": 4,
+ "stride": 2,
+ "padding": 2,
+ },
+ "scale_discriminator_params": {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [15, 41, 5, 3],
+ "channels": 128,
+ "max_downsample_channels": 1024,
+ "max_groups": 16,
+ "bias": True,
+ "downsample_scales": [2, 2, 4, 4, 1],
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ "follow_official_norm": False,
+ "periods": [2, 3, 5, 7, 11],
+ "period_discriminator_params": {
+ "in_channels": 1,
+ "out_channels": 1,
+ "kernel_sizes": [5, 3],
+ "channels": 32,
+ "downsample_scales": [3, 3, 3, 3, 1],
+ "max_downsample_channels": 1024,
+ "bias": True,
+ "nonlinear_activation": "LeakyReLU",
+ "nonlinear_activation_params": {"negative_slope": 0.1},
+ "use_weight_norm": True,
+ "use_spectral_norm": False,
+ },
+ },
+ # loss related
+ generator_adv_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "loss_type": "mse",
+ },
+ discriminator_adv_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "loss_type": "mse",
+ },
+ feat_match_loss_params: Dict[str, Any] = {
+ "average_by_discriminators": False,
+ "average_by_layers": False,
+ "include_final_outputs": True,
+ },
+ mel_loss_params: Dict[str, Any] = {
+ "frame_shift": 256,
+ "frame_length": 1024,
+ "n_mels": 80,
+ },
+ lambda_adv: float = 1.0,
+ lambda_mel: float = 45.0,
+ lambda_feat_match: float = 2.0,
+ lambda_dur: float = 1.0,
+ lambda_kl: float = 1.0,
+ cache_generator_outputs: bool = True,
+ ):
+ """Initialize VITS module.
+
+ Args:
+            vocab_size (int): Input vocabulary size.
+            feature_dim (int): Acoustic feature dimension. The actual output
+                channels are 1 since VITS is an end-to-end text-to-waveform
+                model; this value indicates the acoustic feature dimension for
+                compatibility.
+            sampling_rate (int): Sampling rate. It is not used for training,
+                but it is referred to when saving waveforms during inference.
+ generator_type (str): Generator type.
+ generator_params (Dict[str, Any]): Parameter dict for generator.
+ discriminator_type (str): Discriminator type.
+ discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
+ generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
+ adversarial loss.
+ discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
+ discriminator adversarial loss.
+ feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
+ mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
+ lambda_adv (float): Loss scaling coefficient for adversarial loss.
+ lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
+ lambda_feat_match (float): Loss scaling coefficient for feat match loss.
+ lambda_dur (float): Loss scaling coefficient for duration loss.
+ lambda_kl (float): Loss scaling coefficient for KL divergence loss.
+ cache_generator_outputs (bool): Whether to cache generator outputs.
+
+ """
+ super().__init__()
+
+ # define modules
+        generator_class = AVAILABLE_GENERATORS[generator_type]
+ if generator_type == "vits_generator":
+ # NOTE(kan-bayashi): Update parameters for the compatibility.
+            # The idim and odim are automatically decided from the input data,
+            # where idim represents the vocabulary size and odim represents
+            # the input acoustic feature dimension.
+ generator_params.update(vocabs=vocab_size, aux_channels=feature_dim)
+ self.generator = generator_class(
+ **generator_params,
+ )
+ discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
+ self.discriminator = discriminator_class(
+ **discriminator_params,
+ )
+ self.generator_adv_loss = GeneratorAdversarialLoss(
+ **generator_adv_loss_params,
+ )
+ self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
+ **discriminator_adv_loss_params,
+ )
+ self.feat_match_loss = FeatureMatchLoss(
+ **feat_match_loss_params,
+ )
+ mel_loss_params.update(sampling_rate=sampling_rate)
+ self.mel_loss = MelSpectrogramLoss(
+ **mel_loss_params,
+ )
+ self.kl_loss = KLDivergenceLoss()
+
+ # coefficients
+ self.lambda_adv = lambda_adv
+ self.lambda_mel = lambda_mel
+ self.lambda_kl = lambda_kl
+ self.lambda_feat_match = lambda_feat_match
+ self.lambda_dur = lambda_dur
+
+ # cache
+ self.cache_generator_outputs = cache_generator_outputs
+ self._cache = None
+
+ # store sampling rate for saving wav file
+ # (not used for the training)
+ self.sampling_rate = sampling_rate
+
+ # store parameters for test compatibility
+ self.spks = self.generator.spks
+ self.langs = self.generator.langs
+ self.spk_embed_dim = self.generator.spk_embed_dim
+
+ def forward(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ return_sample: bool = False,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ forward_generator: bool = True,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform generator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
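+            return_sample (bool): Whether to include generated and ground-truth
+                waveform and mel-spectrogram samples in the returned stats.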
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+ forward_generator (bool): Whether to forward generator.
+
+ Returns:
+ - loss (Tensor): Loss scalar tensor.
+ - stats (Dict[str, float]): Statistics to be monitored.
+ """
+ if forward_generator:
+ return self._forward_generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ speech=speech,
+ speech_lengths=speech_lengths,
+ return_sample=return_sample,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+            return self._forward_discriminator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ speech=speech,
+ speech_lengths=speech_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+
+ def _forward_generator(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ return_sample: bool = False,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform generator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
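+            return_sample (bool): Whether to include generated and ground-truth
+                waveform and mel-spectrogram samples in the returned stats.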
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ * loss (Tensor): Loss scalar tensor.
+ * stats (Dict[str, float]): Statistics to be monitored.
+ """
+ # setup
+ feats = feats.transpose(1, 2)
+ speech = speech.unsqueeze(1)
+
+ # calculate generator outputs
+ reuse_cache = True
+ if not self.cache_generator_outputs or self._cache is None:
+ reuse_cache = False
+ outs = self.generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+ outs = self._cache
+
+ # store cache
+ if self.training and self.cache_generator_outputs and not reuse_cache:
+ self._cache = outs
+
+ # parse outputs
+ speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
+ _, z_p, m_p, logs_p, _, logs_q = outs_
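+        # start_idxs and segment_size are in frame units; multiply by the
+        # upsample factor to slice the matching raw-waveform segment.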
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * self.generator.upsample_factor,
+ segment_size=self.generator.segment_size * self.generator.upsample_factor,
+ )
+
+ # calculate discriminator outputs
+ p_hat = self.discriminator(speech_hat_)
+ with torch.no_grad():
+ # do not store discriminator gradient in generator turn
+ p = self.discriminator(speech_)
+
+ # calculate losses
+ with autocast(enabled=False):
+ if not return_sample:
+ mel_loss = self.mel_loss(speech_hat_, speech_)
+ else:
+ mel_loss, (mel_hat_, mel_) = self.mel_loss(
+ speech_hat_, speech_, return_mel=True
+ )
+ kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
+ dur_loss = torch.sum(dur_nll.float())
+ adv_loss = self.generator_adv_loss(p_hat)
+ feat_match_loss = self.feat_match_loss(p_hat, p)
+
+ mel_loss = mel_loss * self.lambda_mel
+ kl_loss = kl_loss * self.lambda_kl
+ dur_loss = dur_loss * self.lambda_dur
+ adv_loss = adv_loss * self.lambda_adv
+ feat_match_loss = feat_match_loss * self.lambda_feat_match
+ loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+
+ stats = dict(
+ generator_loss=loss.item(),
+ generator_mel_loss=mel_loss.item(),
+ generator_kl_loss=kl_loss.item(),
+ generator_dur_loss=dur_loss.item(),
+ generator_adv_loss=adv_loss.item(),
+ generator_feat_match_loss=feat_match_loss.item(),
+ )
+
+ if return_sample:
+ stats["returned_sample"] = (
+ speech_hat_[0].data.cpu().numpy(),
+ speech_[0].data.cpu().numpy(),
+ mel_hat_[0].data.cpu().numpy(),
+ mel_[0].data.cpu().numpy(),
+ )
+
+ # reset cache
+ if reuse_cache or not self.training:
+ self._cache = None
+
+ return loss, stats
+
+    def _forward_discriminator(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ feats: torch.Tensor,
+ feats_lengths: torch.Tensor,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Perform discriminator forward.
+
+ Args:
+ text (Tensor): Text index tensor (B, T_text).
+ text_lengths (Tensor): Text length tensor (B,).
+ feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+ feats_lengths (Tensor): Feature length tensor (B,).
+ speech (Tensor): Speech waveform tensor (B, T_wav).
+ speech_lengths (Tensor): Speech length tensor (B,).
+ sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+ spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+ lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+
+ Returns:
+ * loss (Tensor): Loss scalar tensor.
+ * stats (Dict[str, float]): Statistics to be monitored.
+ """
+ # setup
+ feats = feats.transpose(1, 2)
+ speech = speech.unsqueeze(1)
+
+ # calculate generator outputs
+ reuse_cache = True
+ if not self.cache_generator_outputs or self._cache is None:
+ reuse_cache = False
+ outs = self.generator(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ )
+ else:
+ outs = self._cache
+
+ # store cache
+ if self.cache_generator_outputs and not reuse_cache:
+ self._cache = outs
+
+ # parse outputs
+ speech_hat_, _, _, start_idxs, *_ = outs
+ speech_ = get_segments(
+ x=speech,
+ start_idxs=start_idxs * self.generator.upsample_factor,
+ segment_size=self.generator.segment_size * self.generator.upsample_factor,
+ )
+
+ # calculate discriminator outputs
+ p_hat = self.discriminator(speech_hat_.detach())
+ p = self.discriminator(speech_)
+
+ # calculate losses
+ with autocast(enabled=False):
+ real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
+ loss = real_loss + fake_loss
+
+ stats = dict(
+ discriminator_loss=loss.item(),
+ discriminator_real_loss=real_loss.item(),
+ discriminator_fake_loss=fake_loss.item(),
+ )
+
+ # reset cache
+ if reuse_cache or not self.training:
+ self._cache = None
+
+ return loss, stats
+
+ def inference(
+ self,
+ text: torch.Tensor,
+ feats: Optional[torch.Tensor] = None,
+ sids: Optional[torch.Tensor] = None,
+ spembs: Optional[torch.Tensor] = None,
+ lids: Optional[torch.Tensor] = None,
+ durations: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference for single sample.
+
+ Args:
+ text (Tensor): Input text index tensor (T_text,).
+ feats (Tensor): Feature tensor (T_feats, aux_channels).
+ sids (Tensor): Speaker index tensor (1,).
+ spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
+ lids (Tensor): Language index tensor (1,).
+ durations (Tensor): Ground-truth duration tensor (T_text,).
+ noise_scale (float): Noise scale value for flow.
+ noise_scale_dur (float): Noise scale value for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length.
+ use_teacher_forcing (bool): Whether to use teacher forcing.
+
+ Returns:
+ * wav (Tensor): Generated waveform tensor (T_wav,).
+ * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
+ * duration (Tensor): Predicted duration tensor (T_text,).
+ """
+ # setup
+ text = text[None]
+ text_lengths = torch.tensor(
+ [text.size(1)],
+ dtype=torch.long,
+ device=text.device,
+ )
+ if sids is not None:
+ sids = sids.view(1)
+ if lids is not None:
+ lids = lids.view(1)
+ if durations is not None:
+ durations = durations.view(1, 1, -1)
+
+ # inference
+ if use_teacher_forcing:
+ assert feats is not None
+ feats = feats[None].transpose(1, 2)
+ feats_lengths = torch.tensor(
+ [feats.size(2)],
+ dtype=torch.long,
+ device=feats.device,
+ )
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ feats=feats,
+ feats_lengths=feats_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ max_len=max_len,
+ use_teacher_forcing=use_teacher_forcing,
+ )
+ else:
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ sids=sids,
+ spembs=spembs,
+ lids=lids,
+ dur=durations,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ max_len=max_len,
+ )
+ return wav.view(-1), att_w[0], dur[0]
+
+ def inference_batch(
+ self,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ sids: Optional[torch.Tensor] = None,
+ durations: Optional[torch.Tensor] = None,
+ noise_scale: float = 0.667,
+ noise_scale_dur: float = 0.8,
+ alpha: float = 1.0,
+ max_len: Optional[int] = None,
+ use_teacher_forcing: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference for one batch.
+
+ Args:
+ text (Tensor): Input text index tensor (B, T_text).
+ text_lengths (Tensor): Input text index tensor (B,).
+ sids (Tensor): Speaker index tensor (B,).
+ noise_scale (float): Noise scale value for flow.
+ noise_scale_dur (float): Noise scale value for duration predictor.
+ alpha (float): Alpha parameter to control the speed of generated speech.
+ max_len (Optional[int]): Maximum length.
+
+ Returns:
+ * wav (Tensor): Generated waveform tensor (B, T_wav).
+ * att_w (Tensor): Monotonic attention weight tensor (B, T_feats, T_text).
+ * duration (Tensor): Predicted duration tensor (B, T_text).
+ """
+ # inference
+ wav, att_w, dur = self.generator.inference(
+ text=text,
+ text_lengths=text_lengths,
+ sids=sids,
+ noise_scale=noise_scale,
+ noise_scale_dur=noise_scale_dur,
+ alpha=alpha,
+ max_len=max_len,
+ )
+ return wav, att_w, dur
diff --git a/egs/ljspeech/TTS/vits/wavenet.py b/egs/ljspeech/TTS/vits/wavenet.py
new file mode 100644
index 0000000000..fbe1be52b0
--- /dev/null
+++ b/egs/ljspeech/TTS/vits/wavenet.py
@@ -0,0 +1,349 @@
+# from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/wavenet/wavenet.py
+
+# Copyright 2021 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""WaveNet modules.
+
+This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
+
+"""
+
+import math
+import logging
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+class WaveNet(torch.nn.Module):
+ """WaveNet with global conditioning."""
+
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: int = 3,
+ layers: int = 30,
+ stacks: int = 3,
+ base_dilation: int = 2,
+ residual_channels: int = 64,
+ aux_channels: int = -1,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ bias: bool = True,
+ use_weight_norm: bool = True,
+ use_first_conv: bool = False,
+ use_last_conv: bool = False,
+ scale_residual: bool = False,
+ scale_skip_connect: bool = False,
+ ):
+ """Initialize WaveNet module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Kernel size of dilated convolution.
+ layers (int): Number of residual block layers.
+ stacks (int): Number of stacks i.e., dilation cycles.
+ base_dilation (int): Base dilation factor.
+ residual_channels (int): Number of channels in residual conv.
+ gate_channels (int): Number of channels in gated conv.
+ skip_channels (int): Number of channels in skip conv.
+ aux_channels (int): Number of channels for local conditioning feature.
+ global_channels (int): Number of channels for global conditioning feature.
+ dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
+ bias (bool): Whether to use bias parameter in conv layer.
+ use_weight_norm (bool): Whether to use weight norm. If set to true, it will
+ be applied to all of the conv layers.
+ use_first_conv (bool): Whether to use the first conv layers.
+ use_last_conv (bool): Whether to use the last conv layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+ scale_skip_connect (bool): Whether to scale the skip connection outputs.
+
+ """
+ super().__init__()
+ self.layers = layers
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+ self.base_dilation = base_dilation
+ self.use_first_conv = use_first_conv
+ self.use_last_conv = use_last_conv
+ self.scale_skip_connect = scale_skip_connect
+
+ # check the number of layers and stacks
+ assert layers % stacks == 0
+ layers_per_stack = layers // stacks
+
+ # define first convolution
+ if self.use_first_conv:
+ self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+ # define residual blocks
+ self.conv_layers = torch.nn.ModuleList()
+ for layer in range(layers):
+ dilation = base_dilation ** (layer % layers_per_stack)
+ conv = ResidualBlock(
+ kernel_size=kernel_size,
+ residual_channels=residual_channels,
+ gate_channels=gate_channels,
+ skip_channels=skip_channels,
+ aux_channels=aux_channels,
+ global_channels=global_channels,
+ dilation=dilation,
+ dropout_rate=dropout_rate,
+ bias=bias,
+ scale_residual=scale_residual,
+ )
+ self.conv_layers += [conv]
+
+ # define output layers
+ if self.use_last_conv:
+ self.last_conv = torch.nn.Sequential(
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, out_channels, bias=True),
+ )
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
+ (B, residual_channels, T).
+ x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T) if use_last_conv else
+ (B, residual_channels, T).
+
+ """
+ # encode to hidden representation
+ if self.use_first_conv:
+ x = self.first_conv(x)
+
+ # residual block
+ skips = 0.0
+ for f in self.conv_layers:
+ x, h = f(x, x_mask=x_mask, c=c, g=g)
+ skips = skips + h
+ x = skips
+ if self.scale_skip_connect:
+ x = x * math.sqrt(1.0 / len(self.conv_layers))
+
+ # apply final layers
+ if self.use_last_conv:
+ x = self.last_conv(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+
+ def _remove_weight_norm(m: torch.nn.Module):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+        """Apply weight normalization module to all of the layers."""
+
+ def _apply_weight_norm(m: torch.nn.Module):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ @staticmethod
+ def _get_receptive_field_size(
+ layers: int,
+ stacks: int,
+ kernel_size: int,
+ base_dilation: int,
+ ) -> int:
+ assert layers % stacks == 0
+ layers_per_cycle = layers // stacks
+ dilations = [base_dilation ** (i % layers_per_cycle) for i in range(layers)]
+ return (kernel_size - 1) * sum(dilations) + 1
+
+ @property
+ def receptive_field_size(self) -> int:
+ """Return receptive field size."""
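+        # e.g., the defaults (layers=30, stacks=3, kernel_size=3,
+        # base_dilation=2) give (3 - 1) * 3 * (2**10 - 1) + 1 = 6139 samples.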
+ return self._get_receptive_field_size(
+ self.layers, self.stacks, self.kernel_size, self.base_dilation
+ )
+
+
+class Conv1d(torch.nn.Conv1d):
+ """Conv1d module with customized initialization."""
+
+ def __init__(self, *args, **kwargs):
+ """Initialize Conv1d module."""
+ super().__init__(*args, **kwargs)
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+ if self.bias is not None:
+ torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+ """1x1 Conv1d with customized initialization."""
+
+ def __init__(self, in_channels: int, out_channels: int, bias: bool):
+ """Initialize 1x1 Conv1d module."""
+ super().__init__(
+ in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+ )
+
+
+class ResidualBlock(torch.nn.Module):
+ """Residual block module in WaveNet."""
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ residual_channels: int = 64,
+ gate_channels: int = 128,
+ skip_channels: int = 64,
+ aux_channels: int = 80,
+ global_channels: int = -1,
+ dropout_rate: float = 0.0,
+ dilation: int = 1,
+ bias: bool = True,
+ scale_residual: bool = False,
+ ):
+ """Initialize ResidualBlock module.
+
+ Args:
+ kernel_size (int): Kernel size of dilation convolution layer.
+ residual_channels (int): Number of channels for residual connection.
+ skip_channels (int): Number of channels for skip connection.
+ aux_channels (int): Number of local conditioning channels.
+ dropout (float): Dropout probability.
+ dilation (int): Dilation factor.
+ bias (bool): Whether to add bias parameter in convolution layers.
+ scale_residual (bool): Whether to scale the residual outputs.
+
+ """
+ super().__init__()
+ self.dropout_rate = dropout_rate
+ self.residual_channels = residual_channels
+ self.skip_channels = skip_channels
+ self.scale_residual = scale_residual
+
+ # check
+        assert (kernel_size - 1) % 2 == 0, "Even-number kernel sizes are not supported."
+ assert gate_channels % 2 == 0
+
+ # dilation conv
+ padding = (kernel_size - 1) // 2 * dilation
+ self.conv = Conv1d(
+ residual_channels,
+ gate_channels,
+ kernel_size,
+ padding=padding,
+ dilation=dilation,
+ bias=bias,
+ )
+
+ # local conditioning
+ if aux_channels > 0:
+ self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_aux = None
+
+ # global conditioning
+ if global_channels > 0:
+ self.conv1x1_glo = Conv1d1x1(global_channels, gate_channels, bias=False)
+ else:
+ self.conv1x1_glo = None
+
+ # conv output is split into two groups
+ gate_out_channels = gate_channels // 2
+
+ # NOTE(kan-bayashi): concat two convs into a single conv for the efficiency
+ # (integrate res 1x1 + skip 1x1 convs)
+ self.conv1x1_out = Conv1d1x1(
+ gate_out_channels, residual_channels + skip_channels, bias=bias
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: Optional[torch.Tensor] = None,
+ c: Optional[torch.Tensor] = None,
+ g: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, residual_channels, T).
+ x_mask Optional[torch.Tensor]: Mask tensor (B, 1, T).
+ c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
+ g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
+
+ Returns:
+ Tensor: Output tensor for residual connection (B, residual_channels, T).
+ Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+ """
+ residual = x
+ x = F.dropout(x, p=self.dropout_rate, training=self.training)
+ x = self.conv(x)
+
+        # split into two parts for gated activation
+ splitdim = 1
+ xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+ # local conditioning
+ if c is not None:
+ c = self.conv1x1_aux(c)
+ ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ca, xb + cb
+
+ # global conditioning
+ if g is not None:
+ g = self.conv1x1_glo(g)
+ ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
+ xa, xb = xa + ga, xb + gb
+
+ x = torch.tanh(xa) * torch.sigmoid(xb)
+
+ # residual + skip 1x1 conv
+ x = self.conv1x1_out(x)
+ if x_mask is not None:
+ x = x * x_mask
+
+ # split integrated conv results
+ x, s = x.split([self.residual_channels, self.skip_channels], dim=1)
+
+ # for residual connection
+ x = x + residual
+ if self.scale_residual:
+ x = x * math.sqrt(0.5)
+
+ return x, s
diff --git a/pyproject.toml b/pyproject.toml
index c40143fb93..435256416c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,4 +14,5 @@ exclude = '''
| icefall\/diagnostics\.py
| icefall\/profiler\.py
| egs\/librispeech\/ASR\/zipformer
+ | egs\/ljspeech\/TTS\/vits
'''