
Stable ClearAudio

This repository is a fork of Stability AI's stable-audio-tools, adapted for audio restoration of music. Our focus is on developing and training models for tasks like denoising, de-reverberation, and enhancing the quality of musical recordings.

While the core functionalities from the original repository are preserved, this fork introduces specific configurations and models tailored for audio restoration.

Install

The library can be installed from PyPI with:

$ pip install stable-audio-tools

To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run:

$ pip install .
$ pip install frechet-audio-distance==0.3.1 --no-deps
$ pip install resampy==0.4.3

Requirements

Requires PyTorch 2.0 or later for Flash Attention support

Development for the repo is done in Python 3.8.10

Interface

A basic Gradio interface is provided to test out trained models.

For example, to create an interface for an audio restoration model, you can run:

$ python3 ./run_gradio.py --model-config /path/to/your/restoration/model/config.json --ckpt-path /path/to/your/model.ckpt

The run_gradio.py script accepts the following command line arguments:

  • --pretrained-name
    • Hugging Face repository name for a Stable Audio Tools model
    • Will prioritize model.safetensors over model.ckpt in the repo
    • Optional, used in place of model-config and ckpt-path when using pre-trained model checkpoints on Hugging Face
  • --model-config
    • Path to the model config file for a local model
  • --ckpt-path
    • Path to unwrapped model checkpoint file for a local model
  • --pretransform-ckpt-path
    • Path to an unwrapped pretransform checkpoint, replaces the pretransform in the model, useful for testing out fine-tuned decoders
    • Optional
  • --share
    • If true, a publicly shareable link will be created for the Gradio demo
    • Optional
  • --username and --password
    • Used together to set a login for the Gradio demo
    • Optional
  • --model-half
    • If true, the model weights will be converted to half-precision
    • Optional
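
For example, to serve a model hosted on Hugging Face with half-precision weights and a publicly shareable link, a hypothetical invocation (the repository name below is a placeholder) would be:

$ python3 ./run_gradio.py --pretrained-name your-org/your-restoration-model --model-half --share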

Training

Prerequisites

Before starting your training run, you'll need a model config file, as well as a dataset config file. For more information about those, refer to the Configurations section below.

The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with:

$ wandb login

Start training

To start a training run, run the train.py script in the repo root with:

$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name clearaudio_train

The --name parameter will set the project name for your Weights & Biases run. The example above uses clearaudio_train.

Training wrappers and model unwrapping

stable-audio-tools uses PyTorch Lightning to facilitate multi-GPU and multi-node training.

When a model is being trained, it is wrapped in a "training wrapper", which is a pl.LightningModule that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states.

The checkpoint files created during training include this training wrapper, which greatly increases the size of the checkpoint file.

unwrap_model.py in the repo root will take in a wrapped model checkpoint and save a new checkpoint file including only the model itself.

That can be run from the repo root with:

$ python3 ./unwrap_model.py --model-config /path/to/model/config --ckpt-path /path/to/wrapped/ckpt --name unwrapped_clearaudio_model

Unwrapped model checkpoints are required for:

  • Inference scripts
  • Using a model as a pretransform for another model (e.g. using an autoencoder model for latent diffusion)
  • Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization)
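
As an illustration of how an unwrapped checkpoint is typically consumed, here is a minimal sketch that loads one for inference. It assumes the upstream stable-audio-tools helpers create_model_from_config and load_ckpt_state_dict are available; import paths may differ in this fork, and the file paths are placeholders.

import json

# Assumed upstream helpers; adjust the imports if this fork organizes them differently.
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict

# Build the model skeleton from the same config used for training.
with open("/path/to/model/config.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

# Load the unwrapped weights (no training wrapper, no optimizer state) and switch to eval mode.
model.load_state_dict(load_ckpt_state_dict("/path/to/unwrapped.ckpt"))
model.eval()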

Fine-tuning

Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.

To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to train.py with the --ckpt-path flag.

To start a fresh training run using a pre-trained unwrapped model, you can pass in the unwrapped checkpoint to train.py with the --pretrained-ckpt-path flag.
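
For example (all paths are placeholders), resuming from a wrapped checkpoint:

$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name clearaudio_finetune --ckpt-path /path/to/wrapped/checkpoint.ckpt

and starting a fresh run from an unwrapped pre-trained checkpoint:

$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name clearaudio_finetune --pretrained-ckpt-path /path/to/unwrapped/checkpoint.ckpt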

Additional training flags

Additional optional flags for train.py include:

  • --config-file
    • The path to the defaults.ini file in the repo root, required if running train.py from a directory other than the repo root
  • --pretransform-ckpt-path
    • Used in various model types such as latent diffusion models to load a pre-trained autoencoder. Requires an unwrapped model checkpoint.
  • --save-dir
    • The directory in which to save the model checkpoints
  • --checkpoint-every
    • The number of steps between saved checkpoints.
    • Default: 10000
  • --batch-size
    • Number of samples per-GPU during training. Should be set as large as your GPU VRAM will allow.
    • Default: 8
  • --val-batch-size
    • Number of samples per-GPU during validation. Should be set as large as (val_dataset_size // num_gpus) // 2 or val_dataset_size // num_gpus. Reduce this number if you get the error "Total length of DataLoader across ranks is zero".
    • Default: 8
  • --num-gpus
    • Number of GPUs per-node to use for training
    • Default: 1
  • --num-nodes
    • Number of GPU nodes being used for training
    • Default: 1
  • --accum-batches
    • Enables and sets the number of batches for gradient batch accumulation. Useful for increasing effective batch size when training on smaller GPUs.
  • --strategy
    • Multi-GPU strategy for distributed training. Setting to deepspeed will enable DeepSpeed ZeRO Stage 2.
    • Default: ddp if --num-gpus > 1, else None
  • --precision
    • Floating-point precision to use during training
    • Default: 16
  • --num-workers
    • Number of CPU workers used by the data loader
  • --seed
    • RNG seed for PyTorch, helps with deterministic training
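
As an illustration, a hypothetical single-node multi-GPU run combining several of these flags (paths are placeholders) could look like:

$ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name clearaudio_train --num-gpus 4 --batch-size 4 --accum-batches 4 --precision 16 --checkpoint-every 5000 --save-dir /path/to/checkpoints

With 4 GPUs, a per-GPU batch size of 4, and 4 accumulation batches, the effective batch size is 4 x 4 x 4 = 64.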

Test Autoencoder

To test the autoencoder, run:

python stable_audio_tools/tests/test_autoencoder.py --config <path_to_model_config.json> --input <path_to_your_clean_audio.wav> --output <path_to_output_audio.wav> --ckpt_path <path_to_your_ckpt>

Example

python train.py \
    --pretrained-ckpt-path stable_audio_tools/checkpoints/model.ckpt \
    --model-config stable_audio_tools/configs/model_configs/audio_restoration/stable_clearaudio_rcp.json \
    --dataset-config stable_audio_tools/configs/dataset_configs/maestro_RCP_intense_eq.json \
    --val-dataset-config stable_audio_tools/configs/dataset_configs/maestro_RCP_intense_eq_valid.json \
    --val-every-n-epoch 1 \
    --checkpoint-every-n-epoch 1 \
    --early-stopping true \
    --early-stopping-patience 5 \
    --save-top-k 3 \
    --name stable-clearaudio-early-stopping-intense-equalizer \
    --save-dir stable_audio_tools/output/checkpoints \
    --strategy ddp_find_unused_parameters_true \
    --batch-size 64

Training Command Explained

This command initiates a training run for the Stable-ClearAudio model, which is designed for audio restoration. It configures the run through various parameters, including model checkpoints, configurations, datasets, and training strategies.

Here's what each parameter does:

  • python train.py: Executes the main training script.

  • --pretrained-ckpt-path: Specifies the path to a pre-trained model checkpoint to start training from. This enables transfer learning by building upon existing weights.

  • --model-config: Defines the model architecture configuration file. This JSON file contains the specific parameters for the Stable-ClearAudio model with RCP (Restoration, Compression, Processing) capabilities.

  • --dataset-config: Specifies the configuration file for the training dataset. This file configures the Maestro dataset with intense equalizer effects for RCP tasks.

  • --val-dataset-config: Provides the configuration file for the validation dataset, similar to the training dataset but specifically for validation purposes.

  • --val-every-n-epoch: Sets the validation frequency to occur after every epoch, allowing continuous monitoring of model performance.

  • --checkpoint-every-n-epoch: Saves model checkpoints after every epoch, ensuring no training progress is lost.

  • --early-stopping: Enables early stopping, which will halt training if the model performance stops improving.

  • --early-stopping-patience: Sets the early stopping patience to 5 epochs, meaning training will stop if the performance doesn't improve for 5 consecutive epochs.

  • --save-top-k: Keeps the 3 best-performing model checkpoints based on validation metrics.

  • --name: Assigns a descriptive name to this training run for identification and logging purposes.

  • --save-dir: Specifies the directory where model checkpoints will be saved.

  • --strategy: Sets the distributed training strategy to DDP (Distributed Data Parallel) with unused parameter detection, useful for training efficiency and handling complex models.

  • --batch-size: Defines the number of samples processed in each training batch. A batch size of 64 balances memory usage against training speed.

Configurations

Training and inference code for stable-audio-tools is based around JSON configuration files that define model hyperparameters, training settings, and information about your training dataset.

Model config

The model config file defines all of the information needed to load a model for training or inference. It also contains the training configuration needed to fine-tune a model or train from scratch.

The following properties are defined in the top level of the model configuration:

  • model_type
    • The type of model being defined, currently limited to one of "autoencoder", "diffusion_autoencoder", "lm", "diffusion_cond_restoration", "cold_diffusion_uncond_restoration".
  • sample_size
    • The length of the audio provided to the model during training, in samples. For diffusion models, this is also the raw audio sample length used for inference.
  • sample_rate
    • The sample rate of the audio provided to the model during training, and generated during inference, in Hz.
  • audio_channels
    • The number of channels of audio provided to the model during training, and generated during inference. Defaults to 2. Set to 1 for mono.
  • model
    • The specific configuration for the model being defined, varies based on model_type
  • training
    • The training configuration for the model, varies based on model_type. Provides parameters for training as well as demos.
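
As a rough sketch, a model config containing only the top-level properties above might look like the following; the sample_size and sample_rate values are placeholders, and the model and training objects are left elided because their contents depend entirely on the chosen model_type:

{
    "model_type": "diffusion_cond_restoration",
    "sample_size": 262144,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": { ... },
    "training": { ... }
}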

Dataset config

stable-audio-tools currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in the dataset config documentation.
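
As an example, a dataset config for a local directory of audio files, following the upstream stable-audio-tools format (field names may differ in this fork), typically looks like:

{
    "dataset_type": "audio_dir",
    "datasets": [
        {
            "id": "my_audio",
            "path": "/path/to/audio/files/"
        }
    ],
    "random_crop": true
}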

Todo

  • Add troubleshooting section
  • Add contribution guidelines
