From 46c077081bcb1614f1512bb9075f81438963d966 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 27 Jan 2025 16:11:37 +0800 Subject: [PATCH] Update utils.py --- egs/wenetspeech4tts/TTS/f5-tts/model/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py index 09c5fc4680..fae5fadb61 100644 --- a/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py +++ b/egs/wenetspeech4tts/TTS/f5-tts/model/utils.py @@ -38,7 +38,7 @@ def default(v, d): def lens_to_mask( - t: int["b"], length: int | None = None # noqa: F722 F821 + t: int["b"], length: int | None = None # noqa: F722 F821 ) -> bool["b n"]: # noqa: F722 F821 if not exists(length): length = t.amax() @@ -48,7 +48,7 @@ def lens_to_mask( def mask_from_start_end_indices( - seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 + seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821 ): max_seq_len = seq_len.max().item() seq = torch.arange(max_seq_len, device=start.device).long() @@ -58,7 +58,7 @@ def mask_from_start_end_indices( def mask_from_frac_lengths( - seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 + seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821 ): lengths = (frac_lengths * seq_len).long() max_start = seq_len - lengths @@ -71,7 +71,7 @@ def mask_from_frac_lengths( def maybe_masked_mean( - t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 + t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821 ) -> float["b d"]: # noqa: F722 F821 if not exists(mask): return t.mean(dim=1)