Commit

removed the erroneous 'continual' implementation (#1865)
JinZr authored Jan 16, 2025
1 parent 8ab0352 commit 79074ef
Showing 2 changed files with 14 additions and 126 deletions.
43 changes: 14 additions & 29 deletions egs/wenetspeech4tts/TTS/valle/infer.py
@@ -118,13 +118,6 @@ def get_args():
         help="The temperature of AR Decoder top_k sampling.",
     )

-    parser.add_argument(
-        "--continual",
-        type=str2bool,
-        default=False,
-        help="Do continual task.",
-    )
-
     parser.add_argument(
         "--repetition-aware-sampling",
         type=str2bool,
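Note: `str2bool`, used by the retained `--repetition-aware-sampling` flag above, is the usual icefall argparse helper (presumably imported from icefall.utils in this recipe). A minimal sketch, which may differ in detail from the actual implementation:

import argparse

def str2bool(v):
    # Accept common boolean spellings so flags can be passed as --flag True/False.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")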
@@ -262,29 +255,21 @@ def main():
         )

         # synthesis
-        if args.continual:
-            assert text == ""
-            encoded_frames = model.continual(
-                text_tokens.to(device),
-                text_tokens_lens.to(device),
-                audio_prompts,
-            )
-        else:
-            enroll_x_lens = None
-            if text_prompts:
-                _, enroll_x_lens = text_collater(
-                    [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
-                )
-            encoded_frames = model.inference(
-                text_tokens.to(device),
-                text_tokens_lens.to(device),
-                audio_prompts,
-                enroll_x_lens=enroll_x_lens,
-                top_k=args.top_k,
-                temperature=args.temperature,
-                top_p=args.top_p,
-                ras=args.repetition_aware_sampling,
+        enroll_x_lens = None
+        if text_prompts:
+            _, enroll_x_lens = text_collater(
+                [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
             )
+        encoded_frames = model.inference(
+            text_tokens.to(device),
+            text_tokens_lens.to(device),
+            audio_prompts,
+            enroll_x_lens=enroll_x_lens,
+            top_k=args.top_k,
+            temperature=args.temperature,
+            top_p=args.top_p,
+            ras=args.repetition_aware_sampling,
+        )

         if audio_prompts != []:
             samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
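Note: with the flag gone, argparse rejects `--continual` outright, so any wrapper script still passing it now fails before synthesis starts. A self-contained sketch of that failure mode (toy parser, not the full get_args() from infer.py):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--top-k", type=int)  # stand-in for the real options

try:
    # argparse raises SystemExit(2) on unknown flags, printing
    # "error: unrecognized arguments: --continual True" to stderr.
    parser.parse_args(["--top-k", "20", "--continual", "True"])
except SystemExit:
    print("--continual is no longer accepted")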
97 changes: 0 additions & 97 deletions egs/wenetspeech4tts/TTS/valle/valle.py
@@ -1564,103 +1564,6 @@ def inference(
         assert len(codes) == self.num_quantizers
         return torch.stack(codes, dim=-1)

-    def continual(
-        self,
-        x: torch.Tensor,
-        x_lens: torch.Tensor,
-        y: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Args:
-          x:
-            A 2-D tensor of shape (1, S).
-          x_lens:
-            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
-            before padding.
-          y:
-            A 3-D tensor of shape (1, T, 8).
-        Returns:
-          Return the predicted audio code matrix.
-        """
-        assert x.ndim == 2, x.shape
-        assert x_lens.ndim == 1, x_lens.shape
-        assert y.ndim == 3, y.shape
-        assert y.shape[0] == 1, y.shape
-
-        assert torch.all(x_lens > 0)
-        assert self.num_quantizers == 8
-
-        # NOTE: x has been padded in TextTokenCollater
-        text = x
-        x = self.ar_text_embedding(text)
-        x = self.ar_text_prenet(x)
-        x = self.ar_text_position(x)
-
-        text_len = x_lens.max()
-
-        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
-
-        # AR Decoder
-        prompts = y[:, :prefix_len]
-
-        codes = [y[:, prefix_len:, 0]]
-        # Non-AR Decoders
-        x = self.nar_text_embedding(text)
-        x = self.nar_text_prenet(x)
-        x = self.nar_text_position(x)
-
-        y_emb = self.nar_audio_embeddings[0](y[..., 0])
-
-        if self.prefix_mode == 0:
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_position(y_emb)
-                y_pos = self.nar_audio_prenet(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-        else:
-            for j in range(1, 8):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
-
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_prenet(y_emb)
-                y_pos = self.nar_audio_position(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-
-        assert len(codes) == 8
-        return torch.stack(codes, dim=-1)
-
     def visualize(
         self,
         predicts: Tuple[torch.Tensor],
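Note: the deleted method implemented prompt continuation in the usual VALL-E style: keep the first prefix_len codec frames (half the utterance, capped at 3 s of EnCodec frames at 75 frames/s) as the prompt, copy their codebook-0 codes, and fill quantizer stages 1-7 by greedy argmax, adding each stage's embedding into a running sum. A stripped-down, self-contained sketch of that pattern with toy modules (not the icefall classes, and no transformer decoder):

import torch
import torch.nn as nn

vocab, dim, T = 1024, 16, 600  # toy sizes; T = 8 s of EnCodec frames at 75 Hz
prefix_len = min(int(T * 0.5), 3 * 75)  # as in the removed method: 225 here

# Toy stand-ins for the seven NAR stages (codebooks 1..7 on top of codebook 0).
predict = nn.ModuleList(nn.Linear(dim, vocab) for _ in range(7))
embed = nn.ModuleList(nn.Embedding(vocab, dim) for _ in range(7))

y = torch.randint(vocab, (1, T, 8))  # codec codes of prompt + target, (1, T, 8)
y_emb = torch.randn(1, T, dim)       # running embedding sum, seeded by codebook 0
codes = [y[:, prefix_len:, 0]]       # stage 0 is taken as given, not predicted

with torch.no_grad():
    for i in range(7):
        logits = predict[i](y_emb[:, prefix_len:])  # stand-in for the NAR decoder
        samples = torch.argmax(logits, dim=-1)      # greedy per-stage decoding
        codes.append(samples)
        if i < 6:  # the last stage's embedding is never needed again
            y_emb[:, :prefix_len] += embed[i](y[:, :prefix_len, i + 1])
            y_emb[:, prefix_len:] += embed[i](samples)

out = torch.stack(codes, dim=-1)  # (1, T - prefix_len, 8), as the method returned
print(prefix_len, out.shape)      # 225 torch.Size([1, 375, 8])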
