-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgen_gta.py
45 lines (35 loc) · 1.54 KB
/
gen_gta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import os
import time
import hparams
import numpy as np
from tqdm import tqdm
from model import Tacotron2
from utils import process_text
from text import text_to_sequence
# Compute device shared by the functions below: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def get_model(checkpoint_path="tacotron2_statedict.pt"):
    """Build a Tacotron2 model and load its weights for inference.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``
            whose weights are stored under the ``'state_dict'`` key.

    Returns:
        The ``Tacotron2`` instance, moved to ``device`` and switched to
        evaluation mode.
    """
    model = Tacotron2(hparams).to(device)
    # map_location remaps the stored tensors onto the device in use. Without
    # it, a checkpoint saved from CUDA tensors cannot be loaded on a CPU-only
    # machine, even though this script explicitly falls back to the CPU.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model
def generator(model):
    """Run ``model.gta`` on every training utterance and save each result.

    Transcripts are read with ``process_text`` from ``data/train.txt``. The
    i-th transcript (counting from 1) is paired with the ground-truth mel
    file ``ljspeech-mel-%05d.npy`` found under ``hparams.mel_ground_truth``,
    and the model output is written under the same file name inside the
    ``gta`` directory, which is created if missing. The elapsed time is
    printed once every utterance has been processed.

    Args:
        model: Object exposing ``gta(character, mel_target, length)``; the
            first element along the leading axis of its result is saved.
    """
    os.makedirs("gta", exist_ok=True)

    with torch.no_grad():
        lines = process_text(os.path.join("data", "train.txt"))
        started = time.perf_counter()

        # assumes process_text returns a list-like sequence of lines — TODO confirm
        for number, line in enumerate(tqdm(lines), start=1):
            # Input and output share one file name, derived from the
            # 1-based position of the transcript.
            mel_filename = "ljspeech-mel-%05d.npy" % number
            reference_mel = np.load(
                os.path.join(hparams.mel_ground_truth, mel_filename))

            # The last character of each line is dropped before encoding
            # (presumably a trailing newline — verify against process_text).
            sequence = np.array(
                text_to_sequence(line[:-1], hparams.text_cleaners))

            # torch.stack over a one-element list adds a leading axis of
            # size 1 to each input.
            character = torch.stack([torch.from_numpy(sequence)])
            character = character.long().to(device)
            length = torch.Tensor([character.size(1)]).long().to(device)
            mel_target = torch.stack([torch.from_numpy(reference_mel.T)])
            mel_target = mel_target.float().to(device)

            mel_gta = model.gta(character, mel_target, length)
            np.save(os.path.join("gta", mel_filename),
                    mel_gta.cpu()[0].numpy())

        finished = time.perf_counter()
        print("cost {:.2f}s to generate gta data.".format(finished - started))
if __name__ == "__main__":
    # Script entry point: load the default checkpoint and write the GTA
    # outputs for the whole training list.
    generator(get_model())