From 6e6b022e413a49cc2cf1c14995db39656e0ad85b Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Fri, 6 Dec 2024 16:14:51 +0800
Subject: [PATCH] performed end-to-end testing of the VALL-E recipe (#1818)

* added the missing ``visualize`` function

* minor fixes
---
 ...te_neural_codec_and_prepare_text_tokens.py | 22 +++--
 egs/wenetspeech4tts/TTS/valle/infer.py        |  2 +-
 .../TTS/valle/requirements.txt                |  2 +
 egs/wenetspeech4tts/TTS/valle/train.py        |  9 +-
 egs/wenetspeech4tts/TTS/valle/valle.py        | 85 +++++++++++++++++++
 5 files changed, 109 insertions(+), 11 deletions(-)
 create mode 100644 egs/wenetspeech4tts/TTS/valle/requirements.txt

diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
index 5494bf3400..7de2c6202e 100755
--- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
+++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
@@ -516,9 +516,19 @@ def main():
     for idx, part in enumerate(cut_sets):
         if args.audio_extractor:
             if args.audio_extractor == "Encodec":
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )
 
             if args.prefix.lower() in [
                 "ljspeech",
@@ -587,9 +597,11 @@ def main():
                     ].normalized_text, "normalized_text is None"
 
         # Save each part with an index if split > 1
-        cuts_filename = (
-            f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
-        )
+        if split > 1:
+            cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
+        else:
+            cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"
+
         part.to_file(f"{args.output_dir}/{cuts_filename}")
         logging.info(f"Saved {cuts_filename}")
 
diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py
index fd7ba9f216..44a251c561 100644
--- a/egs/wenetspeech4tts/TTS/valle/infer.py
+++ b/egs/wenetspeech4tts/TTS/valle/infer.py
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="exp/vallf_nano_full/checkpoint-100000.pt",
+        default="./valle/exp/checkpoint-100000.pt",
         help="Path to the saved checkpoint.",
     )
 
diff --git a/egs/wenetspeech4tts/TTS/valle/requirements.txt b/egs/wenetspeech4tts/TTS/valle/requirements.txt
new file mode 100644
index 0000000000..06958dbeaf
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/valle/requirements.txt
@@ -0,0 +1,2 @@
+phonemizer==3.2.1
+git+https://github.com/facebookresearch/encodec.git
\ No newline at end of file
diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py
index fde209511c..e9ec548f33 100755
--- a/egs/wenetspeech4tts/TTS/valle/train.py
+++ b/egs/wenetspeech4tts/TTS/valle/train.py
@@ -4,6 +4,7 @@
 #                                                  Mingshuang Luo)
 # Copyright    2023                       (authors: Feiteng Li)
 # Copyright    2024                       (authors: Yuekai Zhang)
+# Copyright    2024  Tsinghua University  (authors: Zengrui Jin)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -48,10 +49,8 @@ import argparse
 import copy
 import logging
-import os
 import random
 import warnings
-from contextlib import nullcontext
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -216,7 +215,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="exp/valle_dev",
+        default="./valle/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -686,9 +685,9 @@ def compute_validation_loss(
         output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
         output_dir.mkdir(parents=True, exist_ok=True)
         if isinstance(model, DDP):
-            model.module.visualize(predicts, batch, output_dir=output_dir)
+            model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
         else:
-            model.visualize(predicts, batch, output_dir=output_dir)
+            model.visualize(predicts, batch, tokenizer, output_dir=output_dir)
 
     return tot_loss
 
diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py
index b2eb8ae69d..4bfa2b577b 100644
--- a/egs/wenetspeech4tts/TTS/valle/valle.py
+++ b/egs/wenetspeech4tts/TTS/valle/valle.py
@@ -19,8 +19,11 @@
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
+from tokenizer import TextTokenCollater
 from torch import Tensor
 from torch.nn import Linear, Module
 from torch.nn import functional as F
@@ -1658,6 +1661,88 @@ def continual(
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)
 
+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        tokenizer: TextTokenCollater,
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        audio_features = batch["features"].to("cpu").detach().numpy()
+        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
+
+        tokens = batch["tokens"]
+        text_tokens, text_tokens_lens = tokenizer(tokens)
+        assert text_tokens.ndim == 2
+
+        texts = batch["text"]
+        utt_ids = [cut.id for cut in batch["cut"]]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+
+        vmin, vmax = 0, 1024  # Encodec codes
+        if decoder_outputs.dtype.is_floating_point:  # check before the float32 cast below
+            vmin, vmax = -6, 0  # Fbank features
+        decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()
+
 
 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
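
Two of the hunks above are naming fixes in compute_neural_codec_and_prepare_text_tokens.py: the old inline conditionals emitted a dangling separator whenever the cut set was not split ("..._encodec_dev_" with a trailing underscore, and "cuts_dev..jsonl.gz" with a doubled dot). A minimal sketch of the patched scheme, using a hypothetical helper name and example values that are not from the recipe:

from pathlib import Path


def part_storage_path(
    output_dir: str, prefix: str, extractor: str, partition: str, idx: int, split: int
) -> str:
    # Hypothetical helper mirroring the patched logic: the part index and
    # its underscore separator are appended only when the cut set was
    # actually split into multiple parts.
    name = f"{prefix}_{extractor}_{partition}"
    if split > 1:
        name = f"{name}_{idx}"
    return str(Path(output_dir) / name)


# split == 1: no trailing separator anymore
assert part_storage_path("data", "wenetspeech4tts", "encodec", "dev", 0, 1).endswith(
    "wenetspeech4tts_encodec_dev"
)
# split > 1: the part index is kept, as before
assert part_storage_path("data", "wenetspeech4tts", "encodec", "dev", 3, 4).endswith(
    "wenetspeech4tts_encodec_dev_3"
)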
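
In compute_validation_loss, the patch threads the tokenizer through to the new visualize method while keeping the DDP unwrapping: DistributedDataParallel only proxies forward(), so custom methods such as visualize must be reached via .module. A standalone sketch of that pattern (the unwrap helper is illustrative, not part of the recipe):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: nn.Module) -> nn.Module:
    # DDP wraps the user's module and only forwards __call__/forward;
    # methods like ``visualize`` live on the wrapped module.
    return model.module if isinstance(model, DDP) else model


# Equivalent to the patched call sites in train.py:
#   unwrap(model).visualize(predicts, batch, tokenizer, output_dir=output_dir)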