-
Notifications
You must be signed in to change notification settings - Fork 309
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
generate lexicon and export onnx models
- Loading branch information
1 parent
e4f08c7
commit 6478902
Showing
3 changed files
with
243 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) | ||
|
||
""" | ||
This script exports a Matcha-TTS model to ONNX. | ||
Note that the model outputs fbank. You need to use a vocoder to convert | ||
it to audio. See also ./export_onnx_hifigan.py | ||
""" | ||
|
||
import argparse | ||
import json | ||
import logging | ||
from pathlib import Path | ||
from typing import Any, Dict | ||
|
||
import onnx | ||
import torch | ||
from tokenizer import Tokenizer | ||
from train import get_model, get_params | ||
|
||
from icefall.checkpoint import load_checkpoint | ||
|
||
|
||
def get_parser():
    """Return the command-line argument parser for this export script."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=2000,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=Path,
        default="matcha/exp-new-3",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--tokens",
        type=Path,
        default="data/tokens.txt",
        # Previously had no help text.
        help="""Path to the tokens.txt file.""",
    )

    parser.add_argument(
        "--cmvn",
        type=str,
        default="data/fbank/cmvn.json",
        # Fixed: the old help text wrongly said "Path to vocabulary."
        # even though main() reads fbank_mean/fbank_std from this file.
        help="""Path to the JSON file containing fbank mean/std statistics.""",
    )

    return parser
|
||
|
||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Attach key/value metadata to an ONNX model file (changed in place).

    Args:
      filename:
        Path of the ONNX model file to update.
      meta_data:
        Metadata entries to store; every value is converted to str.
    """
    model = onnx.load(filename)

    # Clear any metadata entries the model already carries so the
    # result contains exactly the pairs given in meta_data.
    while model.metadata_props:
        model.metadata_props.pop()

    for k, v in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = k
        entry.value = str(v)

    onnx.save(model, filename)
|
||
|
||
class ModelWrapper(torch.nn.Module):
    """Wrap a Matcha-TTS acoustic model so ONNX export sees a plain
    tensor-in / tensor-out forward() with a fixed number of ODE steps.
    """

    def __init__(self, model, num_steps: int = 5):
        super().__init__()
        self.model = model
        # Number of ODE solver timesteps passed to synthesise().
        self.num_steps = num_steps

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        temperature: torch.Tensor,
        length_scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x: (batch_size, num_tokens), torch.int64
          x_lengths: (batch_size,), torch.int64
          temperature: (1,), torch.float32
          length_scale: (1,), torch.float32
        Returns:
          mel: (batch_size, feat_dim, num_frames)
        """
        # Fixed: the old docstring claimed this returns
        # "audio: (batch_size, num_samples)", but the model outputs fbank
        # features; a vocoder is needed to obtain audio (see module docstring).
        mel = self.model.synthesise(
            x=x,
            x_lengths=x_lengths,
            n_timesteps=self.num_steps,
            temperature=temperature,
            length_scale=length_scale,
        )["mel"]
        # mel: (batch_size, feat_dim, num_frames)

        return mel
|
||
|
||
@torch.inference_mode()
def main():
    """Export the Matcha-TTS acoustic model to ONNX, once per ODE step count.

    Writes model-steps-{2..6}.onnx to the current directory and embeds
    metadata (language, sample rate, pad id, etc.) into each file.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    tokenizer = Tokenizer(params.tokens)
    params.pad_id = tokenizer.pad_id
    params.vocab_size = tokenizer.vocab_size
    params.model_args.n_vocab = params.vocab_size

    with open(params.cmvn) as f:
        stats = json.load(f)

    # Feed the fbank normalization statistics to both the data and model args.
    params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.data_args.data_statistics.mel_std = stats["fbank_std"]
    params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
    params.model_args.data_statistics.mel_std = stats["fbank_std"]
    logging.info(params)

    logging.info("About to create model")
    model = get_model(params)
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

    for num_steps in [2, 3, 4, 5, 6]:
        logging.info(f"num_steps: {num_steps}")
        wrapped = ModelWrapper(model, num_steps=num_steps)
        wrapped.eval()

        # Use a large value so the rotary position embedding in the text
        # encoder has a large initial length
        x = torch.ones(1, 1000, dtype=torch.int64)
        x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64)
        temperature = torch.tensor([1.0])
        length_scale = torch.tensor([1.0])

        filename = f"model-steps-{num_steps}.onnx"
        torch.onnx.export(
            wrapped,
            (x, x_lengths, temperature, length_scale),
            filename,
            opset_version=14,
            input_names=["x", "x_length", "noise_scale", "length_scale"],
            output_names=["mel"],
            # Batch (N) and token/frame length (L) are dynamic.
            dynamic_axes={
                "x": {0: "N", 1: "L"},
                "x_length": {0: "N"},
                "mel": {0: "N", 2: "L"},
            },
        )

        meta_data = {
            "model_type": "matcha-tts",
            "language": "Chinese",
            "has_espeak": 0,
            "n_speakers": 1,
            "jieba": 1,
            "sample_rate": 22050,
            "version": 1,
            "pad_id": params.pad_id,
            "model_author": "icefall",
            "maintainer": "k2-fsa",
            "dataset": "baker-zh",
            "use_eos_bos": 1,
            "dataset_url": "https://www.data-baker.com/open_source.html",
            "dataset_comment": "The dataset is for non-commercial use only.",
            "num_ode_steps": num_steps,
        }
        add_meta_data(filename=filename, meta_data=meta_data)
        print(meta_data)
|
||
|
||
if __name__ == "__main__":
    # Timestamped log lines including the source file and line number.
    fmt = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../ljspeech/TTS/matcha/export_onnx_hifigan.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import jieba | ||
from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict | ||
from tokenizer import Tokenizer | ||
|
||
# Register custom pronunciations with pypinyin for phrases containing the
# heteronym 行, which must be read "hang2" (not its default reading) in
# these bank-related words.
load_phrases_dict(
    {
        "行长": [["hang2"], ["zhang3"]],
        "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]],
    }
)
|
||
|
||
def main():
    """Write lexicon.txt mapping each CJK character and each pypinyin phrase
    to its pinyin (TONE3 style, with tone sandhi applied).
    """
    filename = "lexicon.txt"
    # Renamed from `tokens`, which was later reused for pinyin results.
    tokens_file = "./data/tokens.txt"
    # NOTE(review): the tokenizer is constructed but never used below;
    # presumably kept so a missing/invalid tokens file fails fast — confirm.
    tokenizer = Tokenizer(tokens_file)

    word_dict = pinyin_dict.pinyin_dict
    phrases = phrases_dict.phrases_dict

    # Removed dead code: a counter `i = 0` was assigned and never used.
    with open(filename, "w", encoding="utf-8") as f:
        # Single characters: keep only CJK Unified Ideographs
        # (code points U+4E00..U+9FFF).
        for key in word_dict:
            if not (0x4E00 <= key <= 0x9FFF):
                continue

            w = chr(key)
            pron = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0]
            f.write(f"{w} {pron}\n")

        # Multi-character phrases known to pypinyin.
        for key in phrases:
            pron = " ".join(lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True))
            f.write(f"{key} {pron}\n")
|
||
|
||
# Script entry point: writes lexicon.txt to the current directory.
if __name__ == "__main__":
    main()