diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py index b048034707..966fabfdd4 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/dit.py @@ -92,9 +92,9 @@ def __init__(self, mel_dim, text_dim, out_dim): def forward( self, - x: float["b n d"], # noqa: F722 - cond: float["b n d"], # noqa: F722 - text_embed: float["b n d"], # noqa: F722 + x: float["b n d"], # noqa: F722 + cond: float["b n d"], # noqa: F722 + text_embed: float["b n d"], # noqa: F722 drop_audio_cond=False, ): if drop_audio_cond: # cfg for cond audio