Skip to content

Commit

Permalink
generate lexicon and export onnx models
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 27, 2024
1 parent e4f08c7 commit 6478902
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 0 deletions.
200 changes: 200 additions & 0 deletions egs/baker_zh/TTS/matcha/export_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script exports a Matcha-TTS model to ONNX.
Note that the model outputs fbank. You need to use a vocoder to convert
it to audio. See also ./export_onnx_hifigan.py
"""

import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict

import onnx
import torch
from tokenizer import Tokenizer
from train import get_model, get_params

from icefall.checkpoint import load_checkpoint


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--epoch",
type=int,
default=2000,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
""",
)

parser.add_argument(
"--exp-dir",
type=Path,
default="matcha/exp-new-3",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)

parser.add_argument(
"--tokens",
type=Path,
default="data/tokens.txt",
)

parser.add_argument(
"--cmvn",
type=str,
default="data/fbank/cmvn.json",
help="""Path to vocabulary.""",
)

return parser


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)

while len(model.metadata_props):
model.metadata_props.pop()

for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)

onnx.save(model, filename)


class ModelWrapper(torch.nn.Module):
def __init__(self, model, num_steps: int = 5):
super().__init__()
self.model = model
self.num_steps = num_steps

def forward(
self,
x: torch.Tensor,
x_lengths: torch.Tensor,
temperature: torch.Tensor,
length_scale: torch.Tensor,
) -> torch.Tensor:
"""
Args: :
x: (batch_size, num_tokens), torch.int64
x_lengths: (batch_size,), torch.int64
temperature: (1,), torch.float32
length_scale (1,), torch.float32
Returns:
audio: (batch_size, num_samples)
"""
mel = self.model.synthesise(
x=x,
x_lengths=x_lengths,
n_timesteps=self.num_steps,
temperature=temperature,
length_scale=length_scale,
)["mel"]
# mel: (batch_size, feat_dim, num_frames)

return mel


@torch.inference_mode()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()

params.update(vars(args))

tokenizer = Tokenizer(params.tokens)
params.pad_id = tokenizer.pad_id
params.vocab_size = tokenizer.vocab_size
params.model_args.n_vocab = params.vocab_size

with open(params.cmvn) as f:
stats = json.load(f)
params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
params.data_args.data_statistics.mel_std = stats["fbank_std"]

params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
params.model_args.data_statistics.mel_std = stats["fbank_std"]
logging.info(params)

logging.info("About to create model")
model = get_model(params)
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

for num_steps in [2, 3, 4, 5, 6]:
logging.info(f"num_steps: {num_steps}")
wrapper = ModelWrapper(model, num_steps=num_steps)
wrapper.eval()

# Use a large value so the rotary position embedding in the text
# encoder has a large initial length
x = torch.ones(1, 1000, dtype=torch.int64)
x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
temperature = torch.tensor([1.0])
length_scale = torch.tensor([1.0])

opset_version = 14
filename = f"model-steps-{num_steps}.onnx"
torch.onnx.export(
wrapper,
(x, x_lengths, temperature, length_scale),
filename,
opset_version=opset_version,
input_names=["x", "x_length", "noise_scale", "length_scale"],
output_names=["mel"],
dynamic_axes={
"x": {0: "N", 1: "L"},
"x_length": {0: "N"},
"mel": {0: "N", 2: "L"},
},
)

meta_data = {
"model_type": "matcha-tts",
"language": "Chinese",
"has_espeak": 0,
"n_speakers": 1,
"jieba": 1,
"sample_rate": 22050,
"version": 1,
"pad_id": params.pad_id,
"model_author": "icefall",
"maintainer": "k2-fsa",
"dataset": "baker-zh",
"use_eos_bos": 1,
"dataset_url": "https://www.data-baker.com/open_source.html",
"dataset_comment": "The dataset is for non-commercial use only.",
"num_ode_steps": num_steps,
}
add_meta_data(filename=filename, meta_data=meta_data)
print(meta_data)


if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

logging.basicConfig(format=formatter, level=logging.INFO)
main()
1 change: 1 addition & 0 deletions egs/baker_zh/TTS/matcha/export_onnx_hifigan.py
42 changes: 42 additions & 0 deletions egs/baker_zh/TTS/matcha/generate_lexicon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

import jieba
from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict
from tokenizer import Tokenizer

load_phrases_dict(
{
"行长": [["hang2"], ["zhang3"]],
"银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
}
)


def main():
filename = "lexicon.txt"
tokens = "./data/tokens.txt"
tokenizer = Tokenizer(tokens)

word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict

i = 0
with open(filename, "w", encoding="utf-8") as f:
for key in word_dict:
if not (0x4E00 <= key <= 0x9FFF):
continue

w = chr(key)
tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]

f.write(f"{w} {tokens}\n")

for key in phrases:
tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True)
tokens = " ".join(tokens)

f.write(f"{key} {tokens}\n")


if __name__ == "__main__":
main()

0 comments on commit 6478902

Please sign in to comment.