Commit

code clean
yuekaizhang committed Jan 20, 2025
1 parent 29de94e commit 4553664
Showing 9 changed files with 82 additions and 195 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503, F722, F821"]
args: ["--max-line-length=88", "--extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503"]
#exclude:

# What are we ignoring here?
21 changes: 18 additions & 3 deletions egs/wenetspeech4tts/TTS/f5-tts/infer.py
@@ -1,3 +1,19 @@
#!/usr/bin/env python3
# Modified from https://github.com/SWivid/F5-TTS/blob/main/src/f5_tts/eval/eval_infer_batch.py
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece sherpa-onnx
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
manifest=/path/seed_tts_eval/seedtts_testset/zh/meta.lst
python3 f5-tts/generate_averaged_model.py \
--epoch 56 \
--avg 14 --decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
--exp-dir exp/f5_small
accelerate launch f5-tts/infer.py --nfe 16 --model-path $model_path --manifest-file $manifest --output-dir $output_dir --decoder-dim 768 --nhead 12 --num-decoder-layers 18
bash local/compute_wer.sh $output_dir $manifest
"""
import argparse
import logging
import math
@@ -62,7 +78,7 @@ def get_parser():
parser.add_argument(
"--manifest-file",
type=str,
default="/home/yuekaiz/seed_tts_eval/seedtts_testset/zh/meta_head.lst",
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
help="The manifest file in seed_tts_eval format",
)

@@ -180,7 +196,6 @@ def get_inference_prompt(
batch_accum[bucket_i] += total_mel_len

if batch_accum[bucket_i] >= infer_batch_size:
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
prompts_all.append(
(
utts[bucket_i],
@@ -282,7 +297,7 @@ def main():

model = get_model(args).eval().to(device)
checkpoint = torch.load(args.model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, args.model_path)
else:
_ = load_checkpoint(
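
The checkpoint branch above accepts both EMA-style and plain checkpoints. A minimal sketch of the key handling that load_F5_TTS_pretrained_checkpoint implies, assuming (as in upstream F5-TTS) that EMA weights are stored under an "ema_model." prefix alongside bookkeeping entries such as "initted" and "step":

    import torch

    def load_f5_state_dict(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
        # Sketch only: mirrors the ema_model_state_dict / model_state_dict branch above.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        if "ema_model_state_dict" in checkpoint:
            # Assumption: strip the "ema_model." prefix and drop EMA bookkeeping keys.
            state_dict = {
                k.replace("ema_model.", ""): v
                for k, v in checkpoint["ema_model_state_dict"].items()
                if k not in ("initted", "step")
            }
        else:
            state_dict = checkpoint.get("model_state_dict", checkpoint)
        model.load_state_dict(state_dict, strict=False)
        return model
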
1 change: 0 additions & 1 deletion egs/wenetspeech4tts/TTS/f5-tts/optim.py

This file was deleted.

111 changes: 46 additions & 65 deletions egs/wenetspeech4tts/TTS/f5-tts/train.py
@@ -20,8 +20,17 @@
# limitations under the License.
"""
Usage:
# docker: ghcr.io/swivid/f5-tts:main
# pip install k2==1.24.4.dev20241030+cuda12.4.torch2.4.0 -f https://k2-fsa.github.io/k2/cuda.html
# pip install kaldialign lhotse tensorboard bigvganinference sentencepiece
# huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir bigvgan_v2_24khz_100band_256x
world_size=8
exp_dir=exp/ft-tts
exp_dir=exp/f5-tts-small
python3 f5-tts/train.py --max-duration 700 --filter-min-duration 0.5 --filter-max-duration 20 \
--num-buckets 6 --dtype "bfloat16" --save-every-n 5000 --valid-interval 10000 \
--base-lr 7.5e-5 --warmup-steps 20000 --num-epochs 60 \
--num-decoder-layers 18 --nhead 12 --decoder-dim 768 \
--exp-dir ${exp_dir} --world-size ${world_size}
"""

import argparse
@@ -45,13 +54,10 @@
from model.cfm import CFM
from model.dit import DiT
from model.utils import convert_char_to_pinyin
from optim import Eden, ScaledAdam
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch import Tensor

# from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LinearLR, SequentialLR
from torch.utils.tensorboard import SummaryWriter
from tts_datamodule import TtsDataModule
from utils import MetricsTracker
@@ -87,12 +93,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=1024,
help="Embedding dimension in the decoder model.",
)

parser.add_argument(
"--nhead",
type=int,
default=16,
help="Number of attention heads in the Decoder layers.",
)

parser.add_argument(
"--num-decoder-layers",
type=int,
@@ -156,7 +164,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=Path,
default="exp/valle_dev",
default="exp/f5",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -169,7 +177,7 @@
default="f5-tts/vocab.txt",
help="Path to the unique text tokens file",
)
# /home/yuekaiz//HF/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt

parser.add_argument(
"--pretrained-model-path",
type=str,
@@ -180,15 +188,9 @@
parser.add_argument(
"--optimizer-name",
type=str,
default="ScaledAdam",
default="AdamW",
help="The optimizer.",
)
parser.add_argument(
"--scheduler-name",
type=str,
default="Eden",
help="The scheduler.",
)
parser.add_argument(
"--base-lr", type=float, default=0.05, help="The base learning rate."
)
@@ -203,7 +205,7 @@
parser.add_argument(
"--decay-steps",
type=int,
default=None,
default=1000000,
help="""Number of steps that affects how rapidly the learning rate
decreases. We suggest not to change this.""",
)
@@ -286,20 +288,14 @@ def get_parser():
default=0.0,
help="Keep only utterances with duration > this.",
)

parser.add_argument(
"--filter-max-duration",
type=float,
default=20.0,
help="Keep only utterances with duration < this.",
)

parser.add_argument(
"--visualize",
type=str2bool,
default=False,
help="visualize model results in eval step.",
)

parser.add_argument(
"--oom-check",
type=str2bool,
@@ -383,6 +379,7 @@ def get_tokenizer(vocab_file_path: str):

def get_model(params):
vocab_char_map, vocab_size = get_tokenizer(params.tokens)
# bigvgan 100 dim features
n_mel_channels = 100
n_fft = 1024
sampling_rate = 24_000
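
The constants above describe the 100-band BigVGAN v2 front end downloaded in the usage block. A minimal torchaudio sketch of a matching mel extractor; the hop length of 256 is inferred from the "256x" suffix of the checkpoint name and the 0-12 kHz mel range is an assumption, so treat the exact settings as illustrative rather than the recipe's own extractor:

    import torch
    import torchaudio

    mel_extractor = torchaudio.transforms.MelSpectrogram(
        sample_rate=24_000,
        n_fft=1024,
        win_length=1024,
        hop_length=256,   # assumed from "bigvgan_v2_24khz_100band_256x"
        n_mels=100,
        f_min=0,
        f_max=12_000,     # assumed upper mel edge
        power=1.0,        # magnitude spectrogram before log compression
    )

    wav = torch.randn(1, 24_000)                              # one second of dummy audio
    log_mel = torch.log(mel_extractor(wav).clamp(min=1e-5))   # (1, 100, n_frames)
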
@@ -421,7 +418,6 @@ def get_model(params):
def load_F5_TTS_pretrained_checkpoint(
model, ckpt_path, device: str = "cpu", dtype=torch.float32
):
# model = model.to(dtype)
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
if "ema_model_state_dict" in checkpoint:
checkpoint["model_state_dict"] = {
@@ -641,14 +637,6 @@ def compute_validation_loss(
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value

# if params.visualize:
# output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
# output_dir.mkdir(parents=True, exist_ok=True)
# if isinstance(model, DDP):
# model.module.visualize(predicts, batch, output_dir=output_dir)
# else:
# model.visualize(predicts, batch, output_dir=output_dir)

return tot_loss


@@ -744,11 +732,11 @@ def train_one_epoch(
scaler.scale(loss).backward()
if params.batch_idx_train >= params.accumulate_grad_steps:
if params.batch_idx_train % params.accumulate_grad_steps == 0:
if params.optimizer_name not in ["ScaledAdam", "Eve"]:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

scaler.step(optimizer)
scaler.update()
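
With the ScaledAdam special case removed, gradient clipping always follows the standard GradScaler pattern shown above: unscale, clip, step, update. A self-contained sketch of that loop, with a toy model standing in for the CFM model:

    import torch
    from torch.amp import GradScaler

    model = torch.nn.Linear(10, 10).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)
    scaler = GradScaler("cuda")

    for _ in range(10):
        x = torch.randn(4, 10, device="cuda")
        with torch.autocast("cuda", dtype=torch.float16):
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()
        # Gradients must be unscaled before clipping, otherwise the threshold
        # would be compared against loss-scaled gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
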
@@ -757,10 +745,7 @@
# optimizer.step()

for k in range(params.accumulate_grad_steps):
if isinstance(scheduler, Eden):
scheduler.step_batch(params.batch_idx_train)
else:
scheduler.step()
scheduler.step()

set_batch_count(model, params.batch_idx_train)
except: # noqa
@@ -940,16 +925,18 @@ def run(rank, world_size, args):

logging.info(f"Device: {device}")
tokenizer = get_tokenizer(params.tokens)
print("the class type of tokenizer is: ", type(tokenizer))
logging.info(params)

logging.info("About to create model")

model = get_model(params)

if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
if "ema_model_state_dict" in checkpoint or 'model_state_dict' in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(model, params.pretrained_model_path)
if "ema_model_state_dict" in checkpoint or "model_state_dict" in checkpoint:
model = load_F5_TTS_pretrained_checkpoint(
model, params.pretrained_model_path
)
else:
_ = load_checkpoint(
params.pretrained_model_path,
@@ -984,27 +971,24 @@ def run(rank, world_size, args):

model_parameters = model.parameters()

if params.optimizer_name == "ScaledAdam":
optimizer = ScaledAdam(
model_parameters,
lr=params.base_lr,
clipping_scale=2.0,
)
elif params.optimizer_name == "AdamW":
optimizer = torch.optim.AdamW(
model_parameters,
lr=params.base_lr,
betas=(0.9, 0.95),
weight_decay=1e-2,
eps=1e-8,
)
else:
raise NotImplementedError()
optimizer = torch.optim.AdamW(
model_parameters,
lr=params.base_lr,
betas=(0.9, 0.95),
weight_decay=1e-2,
eps=1e-8,
)

warmup_scheduler = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps)
decay_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps)
warmup_scheduler = LinearLR(
optimizer, start_factor=1e-8, end_factor=1.0, total_iters=params.warmup_steps
)
decay_scheduler = LinearLR(
optimizer, start_factor=1.0, end_factor=1e-8, total_iters=params.decay_steps
)
scheduler = SequentialLR(
optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[params.warmup_steps]
optimizer,
schedulers=[warmup_scheduler, decay_scheduler],
milestones=[params.warmup_steps],
)

optimizer.zero_grad()
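
The Eden/ScaledAdam pair is replaced by plain AdamW plus a two-phase schedule: a LinearLR warm-up over --warmup-steps, chained through SequentialLR into a long linear decay over --decay-steps. A minimal sketch using the values from the usage block and defaults above (base-lr 7.5e-5, 20000 warm-up steps, 1000000 decay steps):

    import torch
    from torch.optim.lr_scheduler import LinearLR, SequentialLR

    toy_model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(
        toy_model.parameters(), lr=7.5e-5, betas=(0.9, 0.95), weight_decay=1e-2
    )

    warmup_steps, decay_steps = 20_000, 1_000_000
    warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
    decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

    # One scheduler.step() per optimizer step: the LR ramps up for 20k steps,
    # then decays linearly towards zero over the next 1M steps.
    for step in range(5):
        optimizer.step()
        scheduler.step()
        print(step, scheduler.get_last_lr())
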
Expand Down Expand Up @@ -1062,8 +1046,6 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])

for epoch in range(params.start_epoch, params.num_epochs + 1):
if isinstance(scheduler, Eden):
scheduler.step_epoch(epoch - 1)

fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@@ -1140,7 +1122,6 @@ def scan_pessimistic_batches_for_oom(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
print(23333)
dtype = torch.float32
if params.dtype in ["bfloat16", "bf16"]:
dtype = torch.bfloat16
50 changes: 0 additions & 50 deletions egs/wenetspeech4tts/TTS/f5-tts/tts_datamodule.py
@@ -24,8 +24,6 @@
from typing import Any, Dict, Optional

import torch

# from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures; SpeechSynthesisDataset,
CutConcatenate,
@@ -185,22 +183,6 @@ def train_dataloaders(
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# train = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )

if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
@@ -249,22 +231,6 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# validate = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
validate = SpeechSynthesisDataset(
return_text=True,
@@ -296,22 +262,6 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
raise NotImplementedError(
"On-the-fly feature extraction is not implemented yet."
)
# sampling_rate = 22050
# config = MatchaFbankConfig(
# n_fft=1024,
# n_mels=80,
# sampling_rate=sampling_rate,
# hop_length=256,
# win_length=1024,
# f_min=0,
# f_max=8000,
# )
# test = SpeechSynthesisDataset(
# return_text=True,
# return_tokens=False,
# feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)),
# return_cuts=self.args.return_cuts,
# )
else:
test = SpeechSynthesisDataset(
return_text=True,
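
The datamodule keeps lhotse's DynamicBucketingSampler for duration-bucketed batching; only the commented-out MatchaFbank on-the-fly extraction path is dropped. A minimal sketch of how such a sampler is typically configured, using the bucket and duration values from the train.py usage block (the manifest path is hypothetical):

    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler

    # Hypothetical manifest path for illustration only.
    cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")

    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=700,   # seconds of audio per batch, cf. --max-duration 700
        num_buckets=6,      # cf. --num-buckets 6
        shuffle=True,
        drop_last=True,
    )

    for batch_cuts in sampler:
        # Each element is a CutSet holding one mini-batch of similar-length cuts.
        break
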