Commit
First working version
csukuangfj committed Dec 26, 2024
1 parent e4d40ba commit e4f08c7
Showing 7 changed files with 357 additions and 3 deletions.
4 changes: 4 additions & 0 deletions egs/baker_zh/TTS/.gitignore
@@ -1 +1,5 @@
 path.sh
+*.wav
+generator_v1
+generator_v2
+generator_v3
4 changes: 3 additions & 1 deletion egs/baker_zh/TTS/local/convert_text_to_tokens.py
@@ -30,6 +30,8 @@
("’", "'"),
(":", ":"),
("、", ","),
("B", "逼"),
("P", "批"),
]
]
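The two added pairs rewrite the Latin letters "B" and "P" to Chinese homophone characters before pinyin conversion, so that the converter sees characters it can transcribe. A minimal sketch of the effect, assuming pypinyin is installed:

from pypinyin import Style, lazy_pinyin

# After the replacement, both letters map to valid pinyin tokens
# instead of passing through the converter unchanged:
print(lazy_pinyin(["逼", "批"], style=Style.TONE3))  # ['bi1', 'pi1']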

@@ -108,7 +110,7 @@ def main():
         text_list = split_text(text)
         tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True)
 
-        c.supervisions[0].tokens = tokens
+        c.tokens = tokens
 
     cuts.to_file(args.out_file)
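The tokens are now stored on the cut itself (c.tokens) rather than on its supervision; presumably this matches how the TTS datamodule later exposes them as batch["tokens"] in train.py and infer.py.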

16 changes: 16 additions & 0 deletions egs/baker_zh/TTS/local/generate_tokens.py
@@ -43,6 +43,22 @@ def generate_token_list() -> List[str]:
         t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
         token_set.add(t)
 
+    no_digit = set()
+    for t in token_set:
+        if t[-1] not in "1234":
+            no_digit.add(t)
+        else:
+            no_digit.add(t[:-1])
+
+    no_digit.add("dei")
+    no_digit.add("tou")
+    no_digit.add("dia")
+
+    for t in no_digit:
+        token_set.add(t)
+        for i in range(1, 5):
+            token_set.add(f"{t}{i}")
+
     ans = list(token_set)
     ans.sort()
Empty file.
332 changes: 332 additions & 0 deletions egs/baker_zh/TTS/matcha/infer.py
@@ -0,0 +1,332 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import datetime as dt
import json
import logging
from pathlib import Path

import soundfile as sf
import torch
import torch.nn as nn
from hifigan.config import v1, v2, v3
from hifigan.denoiser import Denoiser
from hifigan.models import Generator as HiFiGAN
from tokenizer import Tokenizer
from train import get_model, get_params
from tts_datamodule import BakerZhTtsDataModule

from icefall.checkpoint import load_checkpoint
from icefall.utils import AttributeDict, setup_logger
from local.convert_text_to_tokens import split_text
from pypinyin import lazy_pinyin, Style


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=4000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--vocoder",
        type=Path,
        default="./generator_v1",
        help="Path to the vocoder",
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
        help="Path to the tokens file",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        help="Path to the file containing CMVN statistics",
    )

    # The following arguments are used for inference on a single text
    parser.add_argument(
        "--input-text",
        type=str,
        required=False,
        help="The text to generate speech for",
    )

    parser.add_argument(
        "--output-wav",
        type=str,
        required=False,
        help="The filename of the wave to save the generated speech",
    )

    parser.add_argument(
        "--sampling-rate",
        type=int,
        default=22050,
        help="The sampling rate of the generated speech (default: 22050 for baker_zh)",
    )

    return parser


def load_vocoder(checkpoint_path: Path) -> nn.Module:
    checkpoint_path = str(checkpoint_path)
    if checkpoint_path.endswith("v1"):
        h = AttributeDict(v1)
    elif checkpoint_path.endswith("v2"):
        h = AttributeDict(v2)
    elif checkpoint_path.endswith("v3"):
        h = AttributeDict(v3)
    else:
        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")

    hifigan = HiFiGAN(h).to("cpu")
    hifigan.load_state_dict(
        torch.load(checkpoint_path, map_location="cpu")["generator"]
    )
    _ = hifigan.eval()
    hifigan.remove_weight_norm()
    return hifigan
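Note that the HiFiGAN config (v1, v2, or v3) is selected purely from the filename suffix, so the vocoder checkpoint must be named accordingly; a hypothetical call:

# The checkpoint file name must end in v1, v2, or v3:
vocoder = load_vocoder(Path("./generator_v1"))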


def to_waveform(
    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
) -> torch.Tensor:
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
    return audio.squeeze()


def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
    text = split_text(text)
    tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True)

    x = tokenizer.texts_to_token_ids([tokens])
    x = torch.tensor(x, dtype=torch.long, device=device)
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
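For reference, a minimal sketch of the pinyin stage above; with tone_sandhi=True, pypinyin applies third-tone sandhi, so 你好 (ni3 hao3) becomes ni2 hao3:

from pypinyin import Style, lazy_pinyin

print(lazy_pinyin(["你好"], style=Style.TONE3, tone_sandhi=True))
# ['ni2', 'hao3']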


def synthesize(
    model: nn.Module,
    tokenizer: Tokenizer,
    n_timesteps: int,
    text: str,
    length_scale: float,
    temperature: float,
    device: str = "cpu",
    spks=None,
) -> dict:
    text_processed = process_text(text=text, tokenizer=tokenizer, device=device)
    start_t = dt.datetime.now()
    output = model.synthesise(
        text_processed["x"],
        text_processed["x_lengths"],
        n_timesteps=n_timesteps,
        temperature=temperature,
        spks=spks,
        length_scale=length_scale,
    )
    # merge everything into one dict
    output.update({"start_t": start_t, **text_processed})
    return output
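A hedged usage sketch for a single utterance; model, tokenizer, vocoder, and denoiser are assumed to be loaded as in main() below, and the n_timesteps, length_scale, and temperature values mirror those set there:

out = synthesize(
    model=model,
    tokenizer=tokenizer,
    n_timesteps=2,
    text="你好",
    length_scale=1.0,
    temperature=0.667,
)
audio = to_waveform(out["mel"], vocoder, denoiser)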


def infer_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    vocoder: nn.Module,
    denoiser: nn.Module,
    tokenizer: Tokenizer,
) -> None:
    """Decode the dataset.
    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      vocoder:
        The vocoder that converts mel spectrograms to waveforms.
      denoiser:
        Used to remove vocoder artifacts from the generated audio.
      tokenizer:
        Used to convert text to tokens.
    """

    device = next(model.parameters()).device
    num_cuts = 0
    log_interval = 5

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    for batch_idx, batch in enumerate(dl):
        batch_size = len(batch["tokens"])

        texts = [c.supervisions[0].normalized_text for c in batch["cut"]]

        audio = batch["audio"]
        audio_lens = batch["audio_lens"].tolist()
        cut_ids = [cut.id for cut in batch["cut"]]

        for i in range(batch_size):
            output = synthesize(
                model=model,
                tokenizer=tokenizer,
                n_timesteps=params.n_timesteps,
                text=texts[i],
                length_scale=params.length_scale,
                temperature=params.temperature,
                device=device,
            )
            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

            sf.write(
                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
                data=output["waveform"],
                samplerate=params.data_args.sampling_rate,
                subtype="PCM_16",
            )
            sf.write(
                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
                data=audio[i].numpy(),
                samplerate=params.data_args.sampling_rate,
                subtype="PCM_16",
            )

        num_cuts += batch_size

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")


@torch.inference_mode()
def main():
    parser = get_parser()
    BakerZhTtsDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    params.suffix = f"epoch-{params.epoch}"

    params.res_dir = params.exp_dir / "infer" / params.suffix
    params.save_wav_dir = params.res_dir / "wav"
    params.save_wav_dir.mkdir(parents=True, exist_ok=True)

    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
    logging.info("Infer started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    tokenizer = Tokenizer(params.tokens)
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)
        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.data_args.data_statistics.mel_std = stats["fbank_std"]

        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
        params.model_args.data_statistics.mel_std = stats["fbank_std"]

    # Number of ODE Solver steps
    params.n_timesteps = 2

    # Changes to the speaking rate
    params.length_scale = 1.0

    # Sampling temperature
    params.temperature = 0.667
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)

    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()

    # we need cut ids to organize tts results.
    args.return_cuts = True
    baker_zh = BakerZhTtsDataModule(args)

    test_cuts = baker_zh.test_cuts()
    test_dl = baker_zh.test_dataloaders(test_cuts)

    if not Path(params.vocoder).is_file():
        raise ValueError(f"{params.vocoder} does not exist")

    vocoder = load_vocoder(params.vocoder)
    vocoder.to(device)

    denoiser = Denoiser(vocoder, mode="zeros")
    denoiser.to(device)

    if params.input_text is not None and params.output_wav is not None:
        logging.info("Synthesizing a single text")
        output = synthesize(
            model=model,
            tokenizer=tokenizer,
            n_timesteps=params.n_timesteps,
            text=params.input_text,
            length_scale=params.length_scale,
            temperature=params.temperature,
            device=device,
        )
        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

        sf.write(
            file=params.output_wav,
            data=output["waveform"],
            samplerate=params.sampling_rate,
            subtype="PCM_16",
        )
    else:
        logging.info("Decoding the test set")
        infer_dataset(
            dl=test_dl,
            params=params,
            model=model,
            vocoder=vocoder,
            denoiser=denoiser,
            tokenizer=tokenizer,
        )


if __name__ == "__main__":
    main()
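Usage note: when both --input-text and --output-wav are given, the script synthesizes that single utterance; otherwise it decodes the whole test set and writes paired <cut_id>_pred.wav / <cut_id>_gt.wav files under <exp-dir>/infer/epoch-<N>/wav.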
2 changes: 1 addition & 1 deletion egs/baker_zh/TTS/matcha/tokenizer.py
@@ -80,7 +80,7 @@ def texts_to_token_ids(
         token_ids = []
         for t in tokens_list:
             if t not in self.token2id:
-                logging.warning(f"Skip OOV {t}")
+                logging.warning(f"Skip OOV {t} {sentence}")
                 continue
 
             if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id:
2 changes: 1 addition & 1 deletion egs/baker_zh/TTS/matcha/train.py
@@ -315,7 +315,7 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, param
     features_lens = batch["features_lens"].to(device)
     tokens = batch["tokens"]
 
-    tokens = tokenizer.tokens_to_token_ids(tokens, intersperse_blank=True)
+    tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True)
     tokens = k2.RaggedTensor(tokens)
    row_splits = tokens.shape.row_splits(1)
     tokens_lens = row_splits[1:] - row_splits[:-1]
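The row-splits arithmetic above computes per-utterance token counts from the ragged token tensor; a minimal sketch, assuming k2 is installed:

import k2

tokens = k2.RaggedTensor([[1, 2, 3], [4, 5]])
row_splits = tokens.shape.row_splits(1)          # tensor([0, 3, 5])
tokens_lens = row_splits[1:] - row_splits[:-1]   # tensor([3, 2])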
