SLED-TTS: Efficient Speech Language Modeling via Energy Distance in Continuous Space

Streamable text-to-speech model using a language modeling approach, without vector quantization.

HuggingFace | WeChat AI | ICT/CAS

News

  • We have added a guide on training a continuous autoregressive model in bf16; see the BF16 Support section below.

Key features

  • Autoregressive Continuous Modeling: SLED models speech in a continuous latent space, using a special type of maximum mean discrepancy (the energy distance; see the formula after this list) as the training objective.
  • Streaming Synthesis: SLED supports streaming synthesis, enabling speech generation to start as soon as the text stream begins.
  • Voice Cloning: Capable of generating speech based on a 3-second prefix or reference utterance as prompt.
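
For reference, the generalized energy distance between two distributions P and Q, which is the maximum mean discrepancy induced by the negative-distance kernel, can be written as follows (the exact conditional form used as SLED's training objective is defined in the paper):

D_E(P, Q) = 2 E||X − Y|| − E||X − X'|| − E||Y − Y'||,   where X, X' ~ P and Y, Y' ~ Q independently.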

Demo

You can see SLED in action on the demo page.

Available Models on Hugging Face

We have made SLED available on Hugging Face, currently offering two distinct English models for different use cases:

  1. SLED-TTS-Libriheavy: This model is trained on the Libriheavy dataset and provides high-quality text-to-speech synthesis.

  2. SLED-TTS-Streaming-Libriheavy: This variant supports streaming decoding, which generates a 0.6-second speech chunk for every 5 text tokens received. It’s ideal for applications requiring low-latency audio generation.

The Mandarin models are on the way! Alternatively, you can train your own SLED-TTS models by following the guidelines below.
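
To run a released checkpoint locally, one option is to download it from the Hugging Face Hub and point the inference scripts at the local copy. Below is a minimal sketch, assuming the huggingface_hub package; the repository id is illustrative, so substitute the actual model id from our Hugging Face page.

# Download a released SLED-TTS checkpoint and reuse the local path for inference.
# The repository id below is an assumption; use the id shown on the Hugging Face page.
from huggingface_hub import snapshot_download

checkpoint_dir = snapshot_download("ICTNLP/SLED-TTS-Libriheavy")
print(checkpoint_dir)  # pass this directory as --model_name_or_path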

Usage

We provide the training and inference code for SLED-TTS.

Installation

git clone https://github.com/ictnlp/SLED-TTS.git
cd SLED-TTS
pip install -e ./

We currently use the sum of the first 8 embedding vectors from Encodec_24khz as the continuous latent vector. Before running inference or training, make sure Encodec_24khz is downloaded and cached in your Hugging Face cache directory.
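
A minimal sketch for pre-caching the codec, assuming the facebook/encodec_24khz checkpoint from the transformers library is the model referred to above:

from transformers import AutoProcessor, EncodecModel

# Downloading once stores the weights in the local Hugging Face cache,
# so later SLED-TTS runs can load them without further network access.
encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")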

Inference

CHECKPOINT=/path/to/checkpoint
CFG=2.0
SEED=0

Offline Inference

python scripts/run_offline.py \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
    --seed ${SEED}

Streaming Inference

python scripts/run_stream.py \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
    --seed ${SEED}
# Note: run_stream.py simulates generation in a streaming environment so that its quality can be evaluated.
# The current code does not expose an actual streaming API.

Voice Cloning

You can specify the speech prompt with --prompt_text and --prompt_audio.

python scripts/run_voice_clone.py \
    --prompt_text "Were I in the warm room with all the splendor and magnificence!" \
    --prompt_audio "example_prompt.flac" \
    --model_name_or_path ${CHECKPOINT} \
    --cfg ${CFG} \
    --input "Perhaps the other trees from the forest will come to look at me!" \
    --seed ${SEED}

Training

Data Processing #TODO

Training Offline Model

OUTPUT_DIR=./runs/libriheavy
mkdir -p $OUTPUT_DIR
LOG_FILE=${OUTPUT_DIR}/log

BATCH_SIZE=8
UPDATE_FREQ=8
# assuming 8 processes per node, the effective global batch size is WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ, which should equal 512

torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
    ./scripts/train_libriheavy.py \
    --training_cfg 0.1 \
    --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --remove_unused_columns False \
    --label_names audio_inputs \
    --group_by_speech_length \
    --do_train \
    --do_eval \
    --eval_strategy steps \
    --eval_steps 10000 \
    --prediction_loss_only \
    --per_device_train_batch_size ${BATCH_SIZE} \
    --per_device_eval_batch_size 24 \
    --gradient_accumulation_steps ${UPDATE_FREQ} \
    --bf16 \
    --learning_rate 5e-4 \
    --weight_decay 0.01 \
    --adam_beta1 0.9 \
    --adam_beta2 0.999 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --max_steps 300000 \
    --lr_scheduler_type "linear" \
    --warmup_steps 32000 \
    --logging_first_step \
    --logging_steps 100 \
    --save_steps 10000 \
    --save_total_limit 10 \
    --output_dir ${OUTPUT_DIR} \
    --report_to tensorboard \
    --disable_tqdm True \
    --ddp_timeout 3600 --overwrite_output_dir

Training Streaming Model

The streaming model is fine-tuned from the offline checkpoint (--finetune_path). The --stream_n 5 --stream_m 45 setting appears to correspond to the streaming schedule described above: 45 Encodec frames, about 0.6 seconds of audio at Encodec's 75 Hz frame rate, are generated for every 5 text tokens.

OUTPUT_DIR=./runs/libriheavy_stream
mkdir -p $OUTPUT_DIR
LOG_FILE=${OUTPUT_DIR}/log

BATCH_SIZE=8
UPDATE_FREQ=8
# assuming 8 processes per node, the effective global batch size is WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ, which should equal 512

torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
    ./scripts/train_libriheavy_stream.py \
    --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \
    --stream_n 5 --stream_m 45 \
    --training_cfg 0.1 \
    --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
    --dataloader_num_workers 8 \
    --dataloader_pin_memory True \
    --remove_unused_columns False \
    --label_names audio_inputs \
    --group_by_speech_length \
    --do_train \
    --do_eval \
    --eval_strategy steps \
    --eval_steps 10000 \
    --prediction_loss_only \
    --per_device_train_batch_size ${BATCH_SIZE} \
    --per_device_eval_batch_size 24 \
    --gradient_accumulation_steps ${UPDATE_FREQ} \
    --bf16 \
    --learning_rate 3e-4 \
    --weight_decay 0.01 \
    --adam_beta1 0.9 \
    --adam_beta2 0.999 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --max_steps 100000 \
    --lr_scheduler_type "linear" \
    --warmup_steps 10000 \
    --logging_first_step \
    --logging_steps 100 \
    --save_steps 10000 \
    --save_total_limit 10 \
    --output_dir ${OUTPUT_DIR} \
    --report_to tensorboard \
    --disable_tqdm True \
    --ddp_timeout 3600 --overwrite_output_dir

BF16 Support

By setting the --bf16 flag, the model loads in bf16 during inference and in fp32 during training (for mixed-precision training). To enable pure bf16 training, change

torch_dtype = None #torch.bfloat16 if training_args.bf16 else None
to

torch_dtype = torch.bfloat16 if training_args.bf16 else None

However, Encodec should always run in fp32 to preserve the precision of the latents. Therefore, we load Encodec in fp32 and downcast the encoded latent to bf16.
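
A minimal sketch of this precision handling, assuming the facebook/encodec_24khz checkpoint from transformers (tensor shapes and the latent extraction are illustrative only; SLED's actual latent is the sum of the first 8 quantizer embeddings):

import torch
from transformers import EncodecModel

# Keep the codec itself in fp32 so the continuous latents are computed at full precision.
encodec = EncodecModel.from_pretrained("facebook/encodec_24khz").float().eval()
waveform = torch.randn(1, 1, 24000)  # dummy 1-second mono waveform at 24 kHz

with torch.no_grad():
    latent_fp32 = encodec.encoder(waveform)  # fp32 continuous features from the encoder

# Only the encoded latent is downcast to bf16 before being fed to the language model.
latent_bf16 = latent_fp32.to(torch.bfloat16)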

Code Contributors

Acknowledgement

This work is inspired by the following great works:

  • A Proper Loss Is All You Need: Autoregressive Image Generation in Continuous Space via Score Maximization
  • Autoregressive Image Generation without Vector Quantization
  • A Spectral Energy Distance for Parallel Speech Synthesis

Citation

#TODO
