
Commit

minor updates
JinZr committed Oct 22, 2024
1 parent ca3b495 commit 3c3db1a
Showing 3 changed files with 20 additions and 15 deletions.
egs/libritts/TTS/vits/infer.py (4 changes: 2 additions & 2 deletions)
@@ -152,7 +152,7 @@ def _save_worker(
  audio_lens = batch["audio_lens"].tolist()
  cut_ids = [cut.id for cut in batch["cut"]]
  sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
- speakers = (
+ spembs = (
      torch.Tensor(np.array([speaker_map.read(sid) for sid in sids]))
      .squeeze(1)
      .to(device)
@@ -161,7 +161,7 @@ def _save_worker(
  audio_pred, _, durations = model.inference_batch(
      text=tokens,
      text_lengths=tokens_lens,
-     spembs=speakers,
+     spembs=spembs,
  )
  audio_pred = audio_pred.detach().cpu()
  # convert to samples
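For context on the infer.py hunks: the tensor now consistently named spembs is a batch of speaker embeddings looked up per cut. Below is a minimal, hypothetical sketch of that assembly; it only assumes that speaker_map exposes a read(sid) method returning a NumPy array of shape (1, embed_dim), which is what the .squeeze(1) above implies, and the helper name build_spembs is illustrative.

import numpy as np
import torch


def build_spembs(speaker_map, cut_ids, device="cpu"):
    # The first two underscore-separated fields of each cut id form the speaker key,
    # exactly as in the hunk above.
    sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids]
    # Each read(sid) is assumed to return a (1, embed_dim) array; stacking yields
    # (B, 1, embed_dim) and squeeze(1) drops the singleton dimension.
    embs = np.array([speaker_map.read(sid) for sid in sids])
    return torch.Tensor(embs).squeeze(1).to(device)  # shape: (B, embed_dim)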
egs/libritts/TTS/vits/train.py (24 changes: 12 additions & 12 deletions)
@@ -344,7 +344,7 @@ def parse_sids(batch: dict) -> List[str]:
  audio_lens = batch["audio_lens"].to(device)
  features_lens = batch["features_lens"].to(device)
  tokens = batch["tokens"]
- speakers = (
+ spembs = (
      torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)]))
      .squeeze(1)
      .to(device)
@@ -361,7 +361,7 @@ def parse_sids(batch: dict) -> List[str]:
  # a tensor of shape (B, T)
  tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)

- return audio, audio_lens, features, features_lens, tokens, tokens_lens, speakers
+ return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs


  def train_one_epoch(
@@ -449,7 +449,7 @@ def save_bad_model(suffix: str = ""):
      features_lens,
      tokens,
      tokens_lens,
-     speakers,
+     spembs,
  ) = prepare_input(batch, tokenizer, device, train_speaker_map)

  loss_info = MetricsTracker()
@@ -465,7 +465,7 @@ def save_bad_model(suffix: str = ""):
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=False,
  )
  for k, v in stats_d.items():
@@ -484,7 +484,7 @@ def save_bad_model(suffix: str = ""):
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=True,
      return_sample=params.batch_idx_train % params.log_interval == 0,
  )
@@ -651,7 +651,7 @@ def compute_validation_loss(
      features_lens,
      tokens,
      tokens_lens,
-     speakers,
+     spembs,
  ) = prepare_input(batch, tokenizer, device, dev_speaker_map)

  loss_info = MetricsTracker()
@@ -665,7 +665,7 @@ def compute_validation_loss(
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=False,
  )
  assert loss_d.requires_grad is False
@@ -680,7 +680,7 @@ def compute_validation_loss(
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=True,
  )
  assert loss_g.requires_grad is False
@@ -695,7 +695,7 @@ def compute_validation_loss(
  inner_model = model.module if isinstance(model, DDP) else model
  audio_pred, _, duration = inner_model.inference(
      text=tokens[0, : tokens_lens[0].item()],
-     spembs=speakers[0],
+     spembs=spembs[0],
  )
  audio_pred = audio_pred.data.cpu().numpy()
  audio_len_pred = (
@@ -744,7 +744,7 @@ def scan_pessimistic_batches_for_oom(
      features_lens,
      tokens,
      tokens_lens,
-     speakers,
+     spembs,
  ) = prepare_input(batch, tokenizer, device, train_speaker_map)
  try:
      # for discriminator
@@ -756,7 +756,7 @@ def scan_pessimistic_batches_for_oom(
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=False,
  )
  optimizer_d.zero_grad()
@@ -770,7 +770,7 @@ def scan_pessimistic_batches_for_oom(
      feats_lengths=features_lens,
      speech=audio,
      speech_lengths=audio_lens,
-     spembs=speakers,
+     spembs=spembs,
      forward_generator=True,
  )
  optimizer_g.zero_grad()
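All train.py hunks apply the same rename consistently: the speaker-embedding tensor returned by prepare_input is called spembs and is passed unchanged to the discriminator forward, the generator forward, and the validation-time inference call. A hedged sketch of one training step is below; the model is treated as an opaque callable returning a (loss, stats) pair, and keyword names not visible in the hunks (text, text_lengths, feats) are assumptions.

def train_step(model, batch, tokenizer, device, speaker_map, optimizer_d, optimizer_g):
    # prepare_input returns the batch tensors plus the (B, embed_dim) spembs tensor.
    (audio, audio_lens, features, features_lens,
     tokens, tokens_lens, spembs) = prepare_input(batch, tokenizer, device, speaker_map)

    common = dict(
        text=tokens, text_lengths=tokens_lens,
        feats=features, feats_lengths=features_lens,
        speech=audio, speech_lengths=audio_lens,
        spembs=spembs,
    )

    # Discriminator pass, then generator pass; both receive the same spembs.
    loss_d, stats_d = model(**common, forward_generator=False)
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()

    loss_g, stats_g = model(**common, forward_generator=True)
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()
    return stats_d, stats_g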
egs/ljspeech/TTS/vits/generator.py (7 changes: 6 additions & 1 deletion)
@@ -409,7 +409,12 @@ def inference(
  g = self.global_emb(sids.view(-1)).unsqueeze(-1)
  if self.spk_embed_dim is not None:
      # (B, global_channels, 1)
-     g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+     if spembs.ndim == 2:
+         g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
+     elif spembs.ndim == 1:
+         g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
+     else:
+         raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
      if g is None:
          g = g_
      else:
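The new branch in generator.py accounts for spembs reaching inference in two shapes: a batched (B, spk_embed_dim) tensor (as assembled for batch synthesis in infer.py above) or a single 1D (spk_embed_dim,) vector (as in the validation call in train.py, which passes spembs[0]). A self-contained sketch of the projection logic follows; the dimensions are illustrative and a plain Linear stands in for self.spemb_proj.

import torch
import torch.nn.functional as F

spk_embed_dim, global_channels = 512, 256  # illustrative sizes
spemb_proj = torch.nn.Linear(spk_embed_dim, global_channels)


def project_spembs(spembs: torch.Tensor) -> torch.Tensor:
    if spembs.ndim == 2:  # batched: (B, spk_embed_dim)
        g_ = spemb_proj(F.normalize(spembs)).unsqueeze(-1)
    elif spembs.ndim == 1:  # single utterance: (spk_embed_dim,)
        g_ = spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
    else:
        raise ValueError("spembs should be 1D or 2D (batch mode) tensor.")
    return g_  # (B, global_channels, 1)


print(project_spembs(torch.randn(512)).shape)     # torch.Size([1, 256, 1])
print(project_spembs(torch.randn(4, 512)).shape)  # torch.Size([4, 256, 1])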
