Skip to content

Commit

Permalink
Update utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr authored Jan 27, 2025
1 parent 0f75112 commit 46c0770
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions egs/wenetspeech4tts/TTS/f5-tts/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def default(v, d):


def lens_to_mask(
t: int["b"], length: int | None = None # noqa: F722 F821
t: int["b"], length: int | None = None # noqa: F722 F821
) -> bool["b n"]: # noqa: F722 F821
if not exists(length):
length = t.amax()
Expand All @@ -48,7 +48,7 @@ def lens_to_mask(


def mask_from_start_end_indices(
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
seq_len: int["b"], start: int["b"], end: int["b"] # noqa: F722 F821
):
max_seq_len = seq_len.max().item()
seq = torch.arange(max_seq_len, device=start.device).long()
Expand All @@ -58,7 +58,7 @@ def mask_from_start_end_indices(


def mask_from_frac_lengths(
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
seq_len: int["b"], frac_lengths: float["b"] # noqa: F722 F821
):
lengths = (frac_lengths * seq_len).long()
max_start = seq_len - lengths
Expand All @@ -71,7 +71,7 @@ def mask_from_frac_lengths(


def maybe_masked_mean(
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
t: float["b n d"], mask: bool["b n"] = None # noqa: F722 F821
) -> float["b d"]: # noqa: F722 F821
if not exists(mask):
return t.mean(dim=1)
Expand Down

0 comments on commit 46c0770

Please sign in to comment.