diff --git a/benchmarks/DASB/FUSS/README.md b/benchmarks/DASB/FUSS/README.md
new file mode 100644
index 000000000..a434f3a8c
--- /dev/null
+++ b/benchmarks/DASB/FUSS/README.md
@@ -0,0 +1,65 @@
+# FUSS Separation Task
+
+This folder defines the **FUSS source separation benchmark** within DASB (Discrete Audio and Speech Benchmark). It enables evaluating discrete audio representations on **general-purpose source separation**, using the [FUSS dataset](https://www.tensorflow.org/datasets/catalog/fuss) (Free Universal Sound Separation).
+
+## Overview
+
+The goal of this task is to perform **source separation** on complex acoustic mixtures of general sounds, going beyond speech and music.
+
+This benchmark supports:
+- Preparing the FUSS dataset for **supervised training and evaluation**
+- Running separation experiments with various discrete audio codecs and backbones (namely Conformer and CRDNN)
+- Computing standard evaluation metrics (e.g., SDR) with the `fast_bss_eval` library
+
+---
+
+## Directory Structure
+
+```
+FUSS
+├── create_fuss.py                 # Creates mixture files (and silent placeholders) from the raw FUSS sources
+├── README.md
+├── separation
+│   ├── fuss_prepare.py            # Prepare FUSS dataset for supervised SS training
+│   ├── train.py                   # Unified training script for all FUSS experiments
+│   ├── utils.py                   # Audio I/O and utility functions
+│   ├── hparams
+│   │   ├── conformer
+│   │   │   ├── train_dac.yaml     # config recipe for Conformer + ...
+│   │   └── crdnn
+│   │       ├── train_dac.yaml     # config recipe for CRDNN + ...
+│   └── metrics
+│       └── bsseval.py             # BSSEval implementation (SDR, SIR, SAR)
+└── experiments
+```
+
+
+---
+
+## Setup
+
+**Install dependencies:**
+
+You may need additional packages for separation and evaluation:
+```bash
+pip install -r ../extra_requirements.txt
+```
+
+---
+
+## Data Preparation
+
+- Download the raw FUSS dataset: follow the instructions in the [official repo](https://github.com/google-research/sound-separation/tree/master/datasets/fuss) to download the data locally
+- Unpack it into a directory ``
+- To validate the data and create the FUSS mixtures, run `create_fuss.py`. This creates mixtures for all three splits, namely `['eval', 'train', 'validation']`
+- Lastly, to create the `.csv` manifests, run `separation/fuss_prepare.py`
+
+## Running a Separation Experiment
+
+```bash
+python FUSS/separation/train.py FUSS/separation/hparams/conformer/train_encodec.yaml \
+    --data_folder= \
+    --output_folder=FUSS/experiments
+```
\ No newline at end of file
diff --git a/benchmarks/DASB/FUSS/create_fuss.py b/benchmarks/DASB/FUSS/create_fuss.py
new file mode 100644
index 000000000..814092107
--- /dev/null
+++ b/benchmarks/DASB/FUSS/create_fuss.py
@@ -0,0 +1,123 @@
+import os
+import argparse
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+
+
+def create_silent_audio(reference_path, target_path):
+    """
+    Create a silent audio file with the same length and sampling rate as the reference audio.
+
+    Args:
+        reference_path (str): Path to the reference audio file.
+        target_path (str): Path where the silent audio will be saved.
+    """
+    # Read the reference audio to get sampling rate and length
+    data, samplerate = sf.read(reference_path)
+    silent_audio = np.zeros_like(data)
+
+    # Save the silent audio
+    sf.write(target_path, silent_audio, samplerate)
+
+
+def create_mixture_audio(directory, required_files, output_path):
+    """
+    Create a mixture audio file that is a linear mix of all existing audio files in the directory.
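+
+    Sources that are missing on disk are skipped, so the mixture is the
+    sample-wise sum of whichever of the required sources exist; no rescaling
+    is applied (the normalization step in the code below is left commented out).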
+ + Args: + directory (str): Path to the directory containing the audio files. + required_files (list): List of required audio file names. + output_path (str): Path where the mixture audio will be saved. + """ + mixture = None + samplerate = None + + for file in required_files: + file_path = os.path.join(directory, file) + if os.path.exists(file_path): + data, sr = sf.read(file_path) + if mixture is None: + mixture = np.zeros_like(data, dtype=np.float32) + samplerate = sr + mixture += data + + if mixture is not None and samplerate is not None: + # Normalize the mixture to prevent clipping + # mixture = mixture / len(required_files) + sf.write(output_path, mixture, samplerate) + + +def ensure_audio_files(directory): + """ + Ensure all required audio files exist in a directory. If not, create silent versions of them. + + Args: + directory (str): Path to the directory containing the audio files. + """ + required_files = [ + "background0_sound.wav", + "foreground0_sound.wav", + "foreground1_sound.wav", + "foreground2_sound.wav", + ] + + # Full paths to the required files + required_paths = { + file: os.path.join(directory, file) for file in required_files + } + + # Check if 'background0_sound.wav' exists + background_path = required_paths["background0_sound.wav"] + if not os.path.exists(background_path): + print(f"Error: {background_path} is missing. Cannot proceed.") + return + + # Ensure other files exist, creating silent versions if necessary + for file, path in required_paths.items(): + if not os.path.exists(path): + # print(f"{file} is missing. Creating a silent version.") + create_silent_audio(background_path, path) + + # Create the mixture audio file + mixture_path = os.path.join(directory, "mixture.wav") + create_mixture_audio(directory, required_files, mixture_path) + + +def process_directories(root_directory): + """ + Walk through each subdirectory and ensure required audio files exist and create mixture files. + + Args: + root_directory (str): Path to the root directory of the FUSS eval set. + """ + for subdir, _, _ in tqdm(os.walk(root_directory)): + ensure_audio_files(subdir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Ensure audio files and create mixture files in each subdirectory." + ) + parser.add_argument( + "root_dir", type=str, help="Path to the root directory of the FUSS." 
+ ) + + args = parser.parse_args() + root_dir = args.root_dir + + required_subdirs = ["eval", "train", "validation"] + missing = [ + d + for d in required_subdirs + if not os.path.isdir(os.path.join(root_dir, d)) + ] + + if missing: + raise FileNotFoundError( + f"Missing required subdirectories in '{root_dir}': {', '.join(missing)}" + ) + + for subdir in required_subdirs: + subdir_path = os.path.join(root_dir, subdir) + process_directories(subdir_path) diff --git a/benchmarks/DASB/FUSS/separation/fuss_prepare.py b/benchmarks/DASB/FUSS/separation/fuss_prepare.py new file mode 100644 index 000000000..bc27decc7 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/fuss_prepare.py @@ -0,0 +1,167 @@ +import csv +import logging +import os +from typing import Optional, Sequence + +from tqdm import tqdm + +import speechbrain as sb + + +__all__ = ["prepare_fuss"] + +SOURCE_NAMES = [ + "background0_sound.wav", + "foreground0_sound.wav", + "foreground1_sound.wav", + "foreground2_sound.wav", +] + +# Workaround to use fastest backend (SoundFile) +try: + import torchaudio + + torchaudio._backend.utils.get_available_backends().pop("ffmpeg", None) +except Exception: + pass + +# Logging configuration +logging.basicConfig( + level=logging.INFO, # format="%(asctime)s [%(levelname)s] %(funcName)s - %(message)s", +) + +_LOGGER = logging.getLogger(__name__) + + +def prepare_fuss( + data_folder: "str", + save_folder: "Optional[str]" = None, + splits: "Sequence[str]" = ("train", "eval", "validation"), +) -> "None": + """Prepare data manifest CSV files for the MUSDB dataset + + Arguments + --------- + data_folder: + The path to the dataset folder. + save_folder: + The path to the folder where the data manifest CSV files will be stored. + Default to `data_folder`. + splits: + The dataset splits to prepare. + num_sources: + The number of speakers (1, 2 or 3). + + Raises + ------ + ValueError + If an invalid argument value is given. + RuntimeError + If one of the expected split folders is missing. + + Examples + -------- + >>> # Expected folder structure: MUSDB/{train, test}//{mixture.wav, bass.wav, others.wav, drums.wav, vocals.wa} + >>> prepare_musdb("MUSDB", num_sources=4) + + """ + if not save_folder: + save_folder = data_folder + + train_data = [] + test_data = [] + valid_data = [] + + # Iterate over train and test splits + for split in splits: + split_dir = os.path.join(data_folder, split) + + # Check if the split directory exists + if not os.path.exists(split_dir): + print(f"Warning: {split_dir} does not exist. Skipping.") + continue + + # Walk through the subdirectories of the split (tracks) + for track_id in tqdm(os.listdir(split_dir), desc=split): + track_dir = os.path.join(split_dir, track_id) + # Ensure the track directory exists and contains the required files + required_files = [ + "mixture.wav", + "background0_sound.wav", + "foreground0_sound.wav", + "foreground1_sound.wav", + "foreground2_sound.wav", + ] + file_paths = {} + + for file_name in required_files: + file_path = os.path.join(track_dir, file_name) + if os.path.exists(file_path): + file_paths[file_name] = file_path + else: + import pdb + + pdb.set_trace() + print( + f"Warning: {file_name} missing in {track_dir}. Skipping track." 
+ ) + file_paths = None + break # If any file is missing, skip the current track + + # If all required files are found, process the track + if file_paths: + # Get the duration of the 'mixture.wav' file + mixture_wav_path = file_paths["mixture.wav"] + info = sb.dataio.dataio.read_audio_info(mixture_wav_path) + duration = info.num_frames / info.sample_rate + + # Prepare the row for the CSV + row = [ + split, + track_id, # ID + duration, # duration + file_paths["mixture.wav"], # mixture_wav + file_paths["background0_sound.wav"], + file_paths["foreground0_sound.wav"], + file_paths["foreground1_sound.wav"], + file_paths["foreground2_sound.wav"], + ] + + # Add the row to the appropriate data list + if split == "train": + train_data.append(row) + elif split == "eval": + test_data.append(row) + elif split == "validation": + valid_data.append(row) + + # Define the CSV file headers + headers = [ + "split", + "ID", + "duration", + "mixture_wav", + "background0_sound_wav", + "foreground0_sound_wav", + "foreground1_sound_wav", + "foreground2_sound_wav", + ] + + # Write the CSV files for each split + for data, split in [ + (train_data, "train"), + (test_data, "eval"), + (valid_data, "validation"), + ]: + output_csv = os.path.join(save_folder, f"{split}.csv") + + with open(output_csv, mode="w", newline="") as file: + writer = csv.writer(file) + writer.writerow(headers) + writer.writerows(data) + print(f"CSV file created for {split}: {output_csv}") + + _LOGGER.info( + "----------------------------------------------------------------------", + ) + _LOGGER.info("Done!") diff --git a/benchmarks/DASB/FUSS/separation/hparams/conformer/train_dac.yaml b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_dac.yaml new file mode 100644 index 000000000..30fcd3d58 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_dac.yaml @@ -0,0 +1,219 @@ +# ########################################################################################### +# Model: Conformer with DAC audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: dac + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# DAC parameters +# sample_rate: [16000, 24000, 44000, 44000] +# vocab_size: [1024, 1024, 1024, 1024] +# max_num_codebooks: [12, 32, 9, 18] +# model_type: [16khz, 24khz, 44khz, 44khz] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +sample_rate: 24000 # NOTE: must match DAC's model type +vocab_size: 1024 +num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type +model_type: 24khz +model_bitrate: 8kbps + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (1024) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.discrete.dac.DAC + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: 
!new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/conformer/train_encodec.yaml b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_encodec.yaml new file mode 100644 index 000000000..439ceeb7e --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_encodec.yaml @@ -0,0 +1,220 @@ +# ########################################################################################### +# Model: Conformer with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
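+
+# NOTE: for the 24 kHz EnCodec model, bitrate and codebook count are tied:
+# each codebook has 1024 entries (10 bits) at 75 frames per second, i.e.
+# 0.75 kbps per codebook, which is what the `num_codebooks * 75 / 100`
+# expression below encodes. Worked examples matching the commented lists:
+#   num_codebooks: 2  ->  bandwidth: 1.5
+#   num_codebooks: 8  ->  bandwidth: 6.0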
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 1024 +num_codebooks: 2 +bandwidth: !ref * 75 / 100 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + 
out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/conformer/train_sqcodec.yaml b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_sqcodec.yaml new file mode 100644 index 000000000..040937b02 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_sqcodec.yaml @@ -0,0 +1,224 @@ +# ########################################################################################### +# Model: Conformer with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
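+
+# NOTE: unlike the residual-VQ codecs, SQCodec uses scalar quantization: the
+# vocab_size below is 19683 = 3**9, so each token presumably encodes 9 ternary
+# values, which is why embedding_dim is 9 and scalar_embedding is True.
+# Per the comment next to embedding_strg, with `concat` the per-codebook
+# embeddings are concatenated to form the encoder input, so their total width
+# must match encoder_dim (e.g. 4 codebooks -> 1024 / 4 = 256 per codebook).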
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 16000 +vocab_size: 19683 +num_codebooks: 4 +bandwidth: 2 + +# Embedding parameters +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False +encoder_dim: 1024 +embedding_dim: 9 +hidden_dim: 256 +# if set to concat, you need to set embedding_dim to match the encoder_dim after concatenation. Eg, if you have 4 codebook, embedding_dim shoudl set to encoder_dim/4 +embedding_strg: concat # option are concat and att_pool +scalar_embedding: True + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:sq_codec.SQCodec + save_path: !ref + config: config.yaml + checkpoint: ckpt_00190000.pth + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + init: !ref + scalar: !ref + hidden_dim: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + 
max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/conformer/train_wavtokenizer.yaml b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_wavtokenizer.yaml new file mode 100644 index 000000000..e5ebb0313 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/conformer/train_wavtokenizer.yaml @@ -0,0 +1,223 @@ +# ########################################################################################### +# Model: Conformer with Wavtokenizer audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: wavtokenizer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
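+
+# NOTE: WavTokenizer is single-codebook (num_codebooks: 1, vocab_size: 4096);
+# the checkpoint name below ("frame75") suggests roughly 75 tokens per second
+# at 24 kHz, so a ~10 s FUSS clip corresponds to about 750 tokens per source,
+# well within the Conformer's max_length of 2000.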
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 4096 +num_codebooks: 1 +bandwidth: 2 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + +# Modules +codec: !new:speechbrain.lobes.models.discrete.wavtokenizer.WavTokenizer + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: 
conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_dac.yaml b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_dac.yaml new file mode 100644 index 000000000..846d0fcbf --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_dac.yaml @@ -0,0 +1,229 @@ +# ########################################################################################### +# Model: CRDNN with DAC audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: dac + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
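+
+# NOTE: this recipe mirrors hparams/conformer/train_dac.yaml; the codec,
+# embedding, attention MLP and linear head are the same, and essentially only
+# the encoder changes (a CRDNN of CNN + bidirectional LSTM + DNN blocks with
+# LeakyReLU instead of a Conformer). Its input_shape [null, null, embedding_dim]
+# stands for [batch, time, features], with batch and time left dynamic.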
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# DAC parameters +# sample_rate: [16000, 24000, 44000, 44000] +# vocab_size: [1024, 1024, 1024, 1024] +# max_num_codebooks: [12, 32, 9, 18] +# model_type: [16khz, 24khz, 44khz, 44khz] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +sample_rate: 24000 # NOTE: must match DAC's model type +vocab_size: 1024 +num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type +model_type: 24khz +model_bitrate: 8kbps + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (1024) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.discrete.dac.DAC + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + 
cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_encodec.yaml b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_encodec.yaml new file mode 100644 index 000000000..4cc82e8a5 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_encodec.yaml @@ -0,0 +1,230 @@ +# ########################################################################################### +# Model: CRDNN with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
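+
+# NOTE: use_pit enables permutation-invariant training: the four estimated
+# sources have no fixed order, so the loss is evaluated over candidate
+# source-to-target assignments and the best one is kept (4 sources -> 4! = 24
+# permutations). The BSSEval metric at the bottom of this file exposes the
+# matching permutation_invariant option, so SDR/SIR/SAR can likewise be
+# computed on the best assignment.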
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 1024 +num_codebooks: 2 +bandwidth: !ref * 75 / 100 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: 
!ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_sqcodec.yaml b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_sqcodec.yaml new file mode 100644 index 000000000..2d1983a8d --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_sqcodec.yaml @@ -0,0 +1,234 @@ +# ########################################################################################### +# Model: CRDNN with SQCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: sqcodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
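+
+# NOTE: in the dataloader section below, the training shuffle flag is derived
+# with `!apply:str.__eq__ [!ref <sorting>, random]` (the reference presumably
+# points at the sorting key above), so shuffling is enabled only while
+# sorting: random; any other sorting value implicitly disables it.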
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 16000 +vocab_size: 19683 +num_codebooks: 4 +bandwidth: 2 + +# Embedding parameters +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False +encoder_dim: 1024 +embedding_dim: 9 +hidden_dim: 256 +# if set to concat, you need to set embedding_dim to match the encoder_dim after concatenation. Eg, if you have 4 codebook, embedding_dim shoudl set to encoder_dim/4 +embedding_strg: concat # option are concat and att_pool +scalar_embedding: True + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:sq_codec.SQCodec + save_path: !ref + config: config.yaml + checkpoint: ckpt_00190000.pth + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + init: !ref + scalar: !ref + hidden_dim: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, 
null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_wavtokenizer.yaml b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_wavtokenizer.yaml new file mode 100644 index 000000000..a3b15bd8d --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/hparams/crdnn/train_wavtokenizer.yaml @@ -0,0 +1,233 @@ +# ########################################################################################### +# Model: CRDNN with WavTokenizer audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: wavtokenizer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
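+
+# NOTE: the DropFreq / DropChunk transforms defined below are wrapped in an
+# Augmenter that applies exactly two augmentations per call
+# (min/max_augmentations: 2) with probability augment_prob. With augment: False
+# the augmenter is still instantiated but is presumably skipped by the training
+# script, so augment_prob: 0.75 has no effect in this configuration.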
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 4096 +num_codebooks: 1 +bandwidth: 2 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + +# Modules +codec: !new:speechbrain.lobes.models.discrete.wavtokenizer.WavTokenizer + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + 
dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/FUSS/separation/metrics/bsseval.py b/benchmarks/DASB/FUSS/separation/metrics/bsseval.py new file mode 100644 index 000000000..3b9222313 --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/metrics/bsseval.py @@ -0,0 +1,153 @@ +import numpy as np +import torch +import json +from fast_bss_eval import bss_eval_sources +from speechbrain.utils.metric_stats import MetricStats + + +__all__ = ["BSSEval"] + + +class BSSEval(MetricStats): + def __init__( + self, n_sources, source_names=None, permutation_invariant=True + ): + """ + A subclass of MetricStats for evaluating source separation algorithms. + + Args: + n_sources (int): Number of sources to evaluate. + source_names (list, optional): Names of the sources. Defaults to None. + permutation_invariant (bool): Whether to apply permutation invariance when matching sources. + """ + self.n_sources = n_sources + self.source_names = source_names or [ + f"Source {i + 1}" for i in range(n_sources) + ] + self.permutation_invariant = permutation_invariant + + # Initialize storage for metrics + self.metrics = dict() + + def compute_metrics(self, reference_sources, estimated_sources): + """ + Computes SDR, SIR, and SAR for the given reference and estimated sources. + + Args: + reference_sources (ndarray): Array of ground truth sources (shape: [n_sources, n_samples]). + estimated_sources (ndarray): Array of estimated sources (shape: [n_sources, n_samples]). + + Returns: + dict: A dictionary containing SDR, SIR, and SAR values for each source. 
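+
+        Example (illustrative sketch; random torch tensors stand in for real separations):
+            >>> metric = BSSEval(n_sources=4)
+            >>> ref = torch.randn(4, 16000)
+            >>> est = torch.randn(4, 16000)
+            >>> metric.compute_metrics(ref, est)  # -> {"SDR": ..., "SIR": ..., "SAR": ...}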
+ """ + # Define epsilon + epsilon = 1e-10 + # Identify rows that are all zeros + is_all_zeros = torch.all(reference_sources == 0, axis=1) + + # Create a mask to add epsilon only to all-zero rows + reference_sources[is_all_zeros] += epsilon + try: + sdr, sir, sar, perm = bss_eval_sources( + reference_sources, + estimated_sources, + compute_permutation=self.permutation_invariant, + load_diag=1e-5, + ) + is_all_zeros = is_all_zeros[ + perm + ] # Apply permutation to silent mask + sdr_mean = sdr[~is_all_zeros].mean().detach().cpu().numpy().item() + sir_mean = sir[~is_all_zeros].mean().detach().cpu().numpy().item() + sar_mean = sar[~is_all_zeros].mean().detach().cpu().numpy().item() + except Exception as e: + print(f"Exception occured when computing BBSEval: {e}", flush=True) + sdr_mean, sir_mean, sar_mean = np.nan, np.nan, np.nan + return {"SDR": sdr_mean, "SIR": sir_mean, "SAR": sar_mean} + + def add( + self, + reference_sources: torch.Tensor, + estimated_sources: torch.Tensor, + tag: str = None, + ): + """ + Adds the metrics for a single evaluation instance. + + Args: + reference_sources (tensor): Array of ground truth sources (shape: [n_sources, n_samples]). + estimated_sources (tensor): Array of estimated sources (shape: [n_sources, n_samples]). + """ + # Ensure inputs are numpy arrays + reference_sources = reference_sources.squeeze() + estimated_sources = estimated_sources.squeeze() + + # Validate input shapes + assert ( + reference_sources.shape[0] == self.n_sources + ), "Mismatch in number of reference sources." + assert ( + estimated_sources.shape[0] == self.n_sources + ), "Mismatch in number of estimated sources." + + # Compute metrics + metrics = self.compute_metrics(reference_sources, estimated_sources) + + # Store metrics + for key, values in metrics.items(): + if tag is not None: + key = f"{key}/{tag}" + self.metrics.setdefault(key, []).append(values) + + def summarize(self): + """ + Summarizes the collected metrics. + + Returns: + dict: A dictionary containing mean and standard deviation for each metric. + """ + summary = {} + for metric, values in self.metrics.items(): + values = np.array(values) + values = values[~np.isinf(values)] + summary[metric] = { + "mean": np.nanmean(values, axis=0).tolist(), + "std": np.nanstd(values, axis=0).tolist(), + } + + return summary + + def pretty_print(self): + """ + Prints the summarized metrics in a human-readable format. 
+ """ + summary = self.summarize() + print("Source Separation Evaluation Results:") + for metric, stats in summary.items(): + print(f"\n{metric}:") + for i, source_name in enumerate(self.source_names): + print( + f" {source_name}: Mean = {stats['mean'][i]:.2f}, Std = {stats['std'][i]:.2f}" + ) + + def write_stats(self, path): + results = self.summarize() + with open(path, "w") as outfile: + json.dump(results, outfile, indent=4) + + +if __name__ == "__main__": + n_sources = 2 + source_names = ["Vocals", "Accompaniment"] + stats = BSSEval( + n_sources=n_sources, + source_names=source_names, + permutation_invariant=True, + ) + + # Example ground truth and estimated sources + ref_sources = np.random.randn(n_sources, 10000) + est_sources = np.random.randn(n_sources, 10000) + + stats.add(ref_sources, est_sources) + stats.pretty_print() diff --git a/benchmarks/DASB/FUSS/separation/train.py b/benchmarks/DASB/FUSS/separation/train.py new file mode 100644 index 000000000..817a68c3f --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/train.py @@ -0,0 +1,453 @@ +#!/usr/bin/env/python + +"""Recipe for training a transformer-based speech separation system using EnCodec audio representations. + +To run this recipe: +> python train_encodec.py hparams/.yaml + +Authors + * Luca Della Libera 2024 +""" + +import os +import sys +import warnings +import logging + +import speechbrain as sb +import torch +from hyperpyyaml import load_hyperpyyaml +from speechbrain.dataio.dataio import write_audio +from speechbrain.utils.distributed import if_main_process, run_on_main + +from utils import ( + EncodecHelper, + DacHelper, + SQCodecHelper, + WavTokenizerHelper, +) + + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) +base_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../model") +) +sys.path.append(base_dir) + + +logger = logging.getLogger(__name__) + + +_CACHE = {} + + +class Separation(sb.Brain): + def __init__( + self, + modules=None, + opt_class=None, + hparams=None, + run_opts=None, + checkpointer=None, + ): + super().__init__(modules, opt_class, hparams, run_opts, checkpointer) + + # Read tokenizer type from hparams + tokenizer_type = self.hparams.codec.__class__.__name__ + self.encdec = self._get_encdec_helper(tokenizer_type) + + def _get_encdec_helper(self, tokenizer_type): + if tokenizer_type == "Encodec": + return EncodecHelper(self.hparams.codec, self.device) + elif tokenizer_type == "DAC": + return DacHelper( + self.hparams.codec, self.device, self.hparams.num_codebooks + ) + elif tokenizer_type == "SQCodec": + return SQCodecHelper(self.hparams.codec, self.device) + elif tokenizer_type == "WavTokenizer": + return WavTokenizerHelper(self.hparams.codec, self.device) + else: + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + return self.encdec.sig_to_toks(sig, lens) + + @torch.no_grad() + def toks_to_sig(self, toks): + return self.encdec.toks_to_sig(toks) + + def compute_forward(self, batch, stage): + """Forward pass.""" + batch = batch.to(self.device) + in_sig, in_lens = batch.in_sig # [B, T] + out_sig, out_lens = batch.out_sig # [B, ST] + + # Unflatten + out_sig = out_sig.reshape( + len(out_sig), self.hparams.num_speakers, -1 + ).flatten( + end_dim=-2 + ) # [BS, T] + batch.out_sig = out_sig, out_lens + + # Augment if specified + if stage == sb.Stage.TRAIN and self.hparams.augment: + in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) + + # 
Extract tokens (cache them at first epoch if augmentation is disabled) + key = tuple(sorted(batch.id)) + try: + in_toks, out_toks = _CACHE[key] + in_toks = in_toks.to(self.device) + out_toks = out_toks.to(self.device) + except KeyError: + assert (in_lens == out_lens).all() + sig = torch.cat([in_sig, out_sig]) # [B(1 + S), T] + lens = torch.cat( + [ + in_lens, + out_lens.repeat_interleave(self.hparams.num_speakers), + ] + ) # [B(1 + S), T] + toks = self.sig_to_toks(sig, lens) # [B(1 + S), N, K] + in_toks, out_toks = toks.split( + [len(in_sig), len(out_sig)] + ) # [B, N, K], [BS, N, K] + out_toks = out_toks.reshape( + len(in_sig), + self.hparams.num_speakers, + -1, + self.hparams.num_codebooks, + ).movedim( + -2, -3 + ) # [B, N, S, K] + if self.hparams.use_cache and (not self.hparams.augment): + _CACHE[key] = in_toks.cpu(), out_toks.cpu() + + # Avoid in-place modification from embedding layer + in_toks = in_toks.clone() + + # Forward embedding + attention + in_embs = self.modules.embedding(in_toks) # [B, N, K, H] + # Get merged embedding based on strategy set, deafualt ATT_Pooling + if ( + hasattr(self.hparams, "embedding_strg") + and self.hparams.embedding_strg == "concat" + ): + B, T, N_Q, D = in_embs.shape + in_embs = in_embs.view(B, T, N_Q * D) + + else: + att_w = self.modules.attention_mlp(in_embs) # [B, N, K, 1] + in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( + -2 + ) # [B, N, H] + + # Forward encoder + if hasattr(self.modules.encoder, "encode"): + hyp_embs = self.modules.encoder.encode( + in_embs, in_lens + ) # [B, N, H] + else: + hyp_embs = self.modules.encoder(in_embs) # [B, N, H] + + # Forward head + log_probs = ( + self.modules.head(hyp_embs) + .reshape( + len(hyp_embs), + -1, + self.hparams.num_speakers, + self.hparams.num_codebooks, + self.hparams.vocab_size, + ) + .log_softmax(dim=-1) + ) # [B, N, S, K, C] + return log_probs, out_toks + + def compute_objectives(self, predictions, batch, stage): + """Computes the objectives.""" + log_probs, out_toks = predictions # [B, N, S, K, C], [B, N, S, K] + + IDs = batch.id + in_sig, _ = batch.in_sig + out_sig, out_lens = batch.out_sig + + if not self.hparams.use_pit: + # Cross-entropy loss + loss = self.hparams.ce_loss( + log_probs.flatten(start_dim=1, end_dim=3), # [B, NSK, C] + out_toks.flatten(start_dim=1), # [B, NSK] + length=out_lens, + ) + else: + # Permutation invariant training + from speechbrain.nnet.losses import PitWrapper + + def base_loss(preds, targets): + # preds: [N, K, C, S, S] + # targets: [N, K, S, S] + preds = preds.permute(3, 4, 0, 1, 2) # [S, S, N, K, C] + targets = targets.permute(2, 3, 0, 1) # [S, S, N, K] + loss = self.hparams.ce_loss( + preds.flatten(end_dim=-2), + targets.flatten(), + reduction="none", + ) # [SSNK] + loss = loss.reshape_as(targets) + loss = loss.permute(2, 3, 0, 1) # [N, K, S, S] + return loss + + log_probs = log_probs.movedim(2, -1) # [B, N, K, C, S] + out_toks = out_toks.movedim(2, -1) # [B, N, K, S] + pit_loss = PitWrapper(base_loss) + log_probs_list = [ + x[: int(l * log_probs.shape[1])] + for x, l in zip(log_probs, out_lens) + ] + out_toks_list = [ + x[: int(l * out_toks.shape[1])] + for x, l in zip(out_toks, out_lens) + ] + loss, perm = pit_loss(log_probs_list, out_toks_list) + loss = loss.mean() + log_probs = pit_loss.reorder_tensor(log_probs, perm) + log_probs = log_probs.movedim(-1, 2) # [B, N, S, K, C] + out_toks = out_toks.movedim(-1, 2) # [B, N, S, K] + + # Compute TER + if stage != sb.Stage.TRAIN: + self.ter_metric.append( + IDs, + log_probs.flatten(start_dim=1, 
end_dim=3), + out_toks.flatten(start_dim=1), + out_lens, + ) + + # Vocode + if stage in [sb.Stage.TEST] and self.hparams.compute_metrics: + hyp_toks = log_probs.argmax(dim=-1) # [B, N, S, K] + hyp_sig, rec_sig, out_sig = self.vocode( + IDs, in_sig, out_sig, hyp_toks, out_toks, out_lens + ) + self.bsseval_metric.add(out_sig, hyp_sig, tag="clean-hyp") + self.bsseval_metric.add(out_sig, rec_sig, tag="clean-rec") + self.bsseval_metric.add(rec_sig, hyp_sig, tag="rec-hyp") + self.bsseval_metric.add( + out_sig, + in_sig.unsqueeze(1).repeat(1, self.hparams.num_speakers, 1), + tag="clean-mix", + ) + + return loss + + @torch.no_grad() + def vocode(self, IDs, in_sig, out_sig, hyp_toks, out_toks, lens): + hyp_toks = hyp_toks.movedim(-2, -3).contiguous() # [B, S, N, K] + out_toks = out_toks.movedim(-2, -3).contiguous() # [B, S, N, K] + + hyp_sig = self.toks_to_sig( + hyp_toks.flatten(end_dim=1) # [BS, N, K] + ) # [BS, T] + rec_sig = self.toks_to_sig( + out_toks.flatten(end_dim=1) # [BS, N, K] + ) # [BS, T] + # Adjust length + if out_sig.shape[-1] > hyp_sig.shape[-1]: + pad = [0, out_sig.shape[-1] - hyp_sig.shape[-1]] + hyp_sig = torch.nn.functional.pad( + hyp_sig, pad, mode="replicate" + ) # [BS, T_out] + rec_sig = torch.nn.functional.pad( + rec_sig, pad, mode="replicate" + ) # [BS, T_out] + elif out_sig.shape[-1] < hyp_sig.shape[-1]: + hyp_sig = hyp_sig.narrow(-1, 0, out_sig.shape[-1]) # [BS, T_out] + rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [BS, T_out] + + hyp_sig = hyp_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + rec_sig = rec_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + out_sig = out_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + + if self.hparams.save_audios: + save_folder = os.path.join(self.hparams.output_folder, "audios") + os.makedirs(save_folder, exist_ok=True) + for i in range(len(IDs)): + write_audio( + os.path.join(save_folder, f"{IDs[i]}_hyp.wav"), + hyp_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_rec.wav"), + rec_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_ref.wav"), + out_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_in.wav"), + in_sig[i].cpu(), + self.hparams.sample_rate, + ) + hyp_sig = hyp_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + rec_sig = rec_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + out_sig = out_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + return hyp_sig, rec_sig, out_sig + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch.""" + super().on_stage_start(stage, epoch) + if ( + stage in [sb.Stage.TEST, sb.Stage.VALID] + and self.hparams.compute_metrics + ): + self.bsseval_metric = self.hparams.bsseval_computer() + self.ter_metric = self.hparams.ter_computer() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of each epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + # self.checkpointer.save_and_keep_only() + else: + stage_stats["TER"] = self.ter_metric.summarize("average") * 100 + + # Perform end-of-iteration operations, like annealing, logging, etc. 
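+        # Sketch of the scheduler step below (assuming SpeechBrain's NewBobScheduler
+        # semantics): the learning rate is multiplied by `annealing_factor` once the
+        # relative TER improvement stays below `improvement_threshold` for more than
+        # `patient` epochs; otherwise it is left unchanged.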
+ if stage == sb.Stage.VALID: + _, lr = self.hparams.scheduler(stage_stats["TER"]) + sb.nnet.schedulers.update_learning_rate(self.optimizer, lr) + steps = self.optimizer_step + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr": lr, "steps": steps}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"TER": stage_stats["TER"], "epoch": epoch}, + min_keys=["TER"], + num_to_keep=self.hparams.keep_checkpoints, + keep_recent=False, + ) + + elif stage == sb.Stage.TEST: + if self.hparams.compute_metrics: + stage_stats["BSSEval"] = self.bsseval_metric.summarize() + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + if if_main_process(): + # Save dWER + if self.hparams.compute_metrics: + self.bsseval_metric.write_stats(self.hparams.bsseval_file) + + +if __name__ == "__main__": + # Command-line interface + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Filter warnings + warnings.filterwarnings("once") + warnings.filterwarnings("ignore", module="torch") + + # If --distributed_launch then create ddp_init_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset preparation + from fuss_prepare import prepare_fuss as prepare_data + + prepare_data_kwargs = { + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "splits": hparams["splits"], + } + + run_on_main(prepare_data, kwargs=prepare_data_kwargs) + + # Create the datasets objects + from utils import dataio_prepare + + train_data, valid_data, test_data = dataio_prepare( + debug=run_opts.get("debug", False), **hparams + ) + + # Pretrain the specified modules + if "pretrainer" in hparams: + run_on_main(hparams["pretrainer"].collect_files) + run_on_main(hparams["pretrainer"].load_collected) + + # Use pretrained embeddings + if hparams["pretrain_embedding"]: + embs = hparams["codec"].vocabulary.reshape(-1, hparams["embedding_dim"]) + hparams["embedding"].embedding.weight.data.copy_(embs) + + # Log number of parameters/buffers + codec_params = sum( + [x.numel() for x in hparams["codec"].state_dict().values()] + ) + model_params = sum( + [ + x.numel() + for module in hparams["modules"].values() + for x in module.state_dict().values() + ] + ) + hparams["train_logger"].log_stats( + stats_meta={ + f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", + }, + ) + + # Trainer initialization + brain = Separation( + modules=hparams["modules"], + opt_class=hparams["opt_class"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Train + brain.fit( + brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_kwargs"], + valid_loader_kwargs=hparams["valid_dataloader_kwargs"], + ) + + # Test + if hparams["testing"]: + # Testing + brain.hparams.bsseval_file = os.path.join( + hparams["output_folder"], "bsseval.txt" + ) + brain.evaluate( + test_data, test_loader_kwargs=hparams["test_dataloader_kwargs"], + ) diff --git a/benchmarks/DASB/FUSS/separation/utils.py 
b/benchmarks/DASB/FUSS/separation/utils.py new file mode 100644 index 000000000..bb843dabb --- /dev/null +++ b/benchmarks/DASB/FUSS/separation/utils.py @@ -0,0 +1,359 @@ +"""Common utilities. + +Authors + * Luca Della Libera 2024 +""" + +import os + +import speechbrain as sb +import torch +import torchaudio +from speechbrain.dataio.dataio import merge_csvs +from transformers.models.hubert.modeling_hubert import ( + HubertEncoderStableLayerNorm, +) +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2EncoderStableLayerNorm, +) +from transformers.models.wavlm.modeling_wavlm import WavLMEncoderStableLayerNorm + + +__all__ = ["SBWav2Vec2ForwardWrapper", "dataio_prepare"] + +CHUNK = 10.0 + + +class EncodecHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig.unsqueeze(1), lens) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + sig = self.codec.decode(toks)[:, 0] # [B, T] + return sig + + +class DacHelper: + def __init__(self, codec, device, num_codebooks): + self.codec = codec + self.device = device + self.num_codebooks = num_codebooks + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec( + sig[:, None], n_quantizers=self.num_codebooks + ) # [B, K, N] + toks = toks.movedim(-1, -2) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + qfeats, _, _ = self.codec.quantizer.from_codes( + toks.movedim(-1, -2) + ) # [B, K, N] -> [B, K, N] + sig = self.codec.decode(qfeats)[:, 0] # [B, T] + return sig + + +class SQCodecHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + # sig: [B, T] + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig[:, None]) # [B, K * N] + K = self.codec.n_codebook + N = toks.shape[-1] // K + toks = self._unflatten_codebooks(toks, N, K) # [B, N, K] + return toks + + def _flatten_codebooks(self, arr): + assert ( + len(arr.shape) == 3 + ), "Input array must have 3 dimensions [B, N, K]" + N, B, K = arr.shape + arr = arr.clone() + flattened_arr = arr.permute(1, 2, 0).reshape(B, N * K) + return flattened_arr + + def _unflatten_codebooks(self, flat_arr, N, K): + # flat_arr: [B, N * K] + B = flat_arr.shape[0] + return flat_arr.reshape(B, N, K) + + @torch.no_grad() + def toks_to_sig(self, toks): + toks = toks.permute(2, 0, 1) # [B, N, K] -> [K, B, N] + flat_toks = self._flatten_codebooks(toks).to(torch.int32) + sig = self.codec.decode(flat_toks).squeeze(1) # [B, T] + return sig.to(toks.device) + + +class WavTokenizerHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig) # [B, K, N] + toks = toks.permute(0, 2, 1) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + toks = toks.movedim(-1, -2) # [B, N, K] -> [B, K, N] + sig = self.codec.decode(toks) # [B, T] + return sig.clone() + + +class SBWav2Vec2ForwardWrapper(torch.nn.Module): + """SpeechBrain wav2vec 2.0 wrapper that returns the hidden representations from the specified layer IDs. 
+ + Arguments + --------- + wav2vec2: + The SpeechBrain wav2vec 2.0 module. + layer_ids: + The layer IDs from which the hidden representations are extracted. + + Examples + -------- + >>> import torch + >>> from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + >>> from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM + >>> + >>> encoder = WavLM(source="microsoft/wavlm-large", save_path=HUGGINGFACE_HUB_CACHE) + >>> encoder = SBWav2Vec2ForwardWrapper(encoder, layer_ids=[6, 7]) + >>> + >>> input = torch.rand([10, 16000]) + >>> length = torch.ones(10) + >>> output = encoder(input, length) + + """ + + def __init__(self, wav2vec2, layer_ids): + super().__init__() + self.wav2vec2 = wav2vec2 + # Workaround to deal with hardcoded class name in discrete SSL + # https://github.com/speechbrain/speechbrain/blob/60062c2536e8122253d6ad0e681208f554528950/speechbrain/lobes/models/huggingface_transformers/discrete_ssl.py#L88 + self.__class__.__name__ = self.wav2vec2.__class__.__name__ + self.layer_ids = sorted(layer_ids) + assert hasattr(self.wav2vec2, "model") + assert hasattr(self.wav2vec2.model, "encoder") + assert hasattr(self.wav2vec2.model.encoder, "layers") + # Workaround for early exiting to avoid the computational overhead of forwarding through the whole model + # NOTE: the model is modified in-place + self.wav2vec2.output_all_hiddens = True + self.wav2vec2.model.encoder.layers = self.wav2vec2.model.encoder.layers[ + : max(self.layer_ids) + ] + # NOTE: workaround to account for layer norm applied to the last hidden states when StableLayerNorm variant is used: + # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wavlm/modeling_wavlm.py#L816 + if isinstance( + self.wav2vec2.model.encoder, + ( + HubertEncoderStableLayerNorm, + Wav2Vec2EncoderStableLayerNorm, + WavLMEncoderStableLayerNorm, + ), + ): + self.wav2vec2.model.encoder.layer_norm = torch.nn.Identity() + + def extract_features(self, wav, length=None): + feats = self.wav2vec2(wav, length) # (K, B, N, H) + return feats + + def forward(self, wav, length=None): + return self.extract_features(wav, length) + + +def dataio_prepare( + data_folder, + train_csv, + valid_csv, + test_csv, + sample_rate=16000, + train_remove_if_longer=60.0, + valid_remove_if_longer=60.0, + test_remove_if_longer=60.0, + sorting="ascending", + debug=False, + **hparams, +): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. 
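+
+    Example (illustrative sketch; the CSV paths are placeholders produced by fuss_prepare.py):
+        >>> train_data, valid_data, test_data = dataio_prepare(
+        ...     data_folder="/path/to/fuss",
+        ...     train_csv="save/train.csv",
+        ...     valid_csv="save/validation.csv",
+        ...     test_csv="save/eval.csv",
+        ...     sample_rate=16000,
+        ... )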
+ + """ + if isinstance(train_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in train_csv] + save_folder = os.path.dirname(train_csv[0]) + merge_csvs( + save_folder, csvs, "train.csv", + ) + train_csv = os.path.join(save_folder, "train.csv") + + if isinstance(valid_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in valid_csv] + save_folder = os.path.dirname(valid_csv[0]) + merge_csvs( + save_folder, csvs, "valid.csv", + ) + valid_csv = os.path.join(save_folder, "valid.csv") + + if isinstance(test_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in test_csv] + save_folder = os.path.dirname(test_csv[0]) + merge_csvs( + save_folder, csvs, "test.csv", + ) + test_csv = os.path.join(save_folder, "test.csv") + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=train_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort training data to speed up training + train_data = train_data.filtered_sorted( + sort_key="duration", + reverse=sorting == "descending", + key_max_value={"duration": train_remove_if_longer}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=valid_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort validation data to speed up validation + valid_data = valid_data.filtered_sorted( + sort_key="duration", + reverse=not debug, + key_max_value={"duration": valid_remove_if_longer}, + ) + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=test_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort the test data to speed up testing + test_data = test_data.filtered_sorted( + sort_key="duration", + reverse=not debug, + key_max_value={"duration": test_remove_if_longer}, + ) + + # train_data = train_data.overfit_test(32, 32) + # valid_data = valid_data.overfit_test(8, 8) + # test_data = test_data.overfit_test(3, 3) + + datasets = [train_data, valid_data, test_data] + + # Define audio pipeline + takes = [ + "mixture_wav", + "background0_sound_wav", + "foreground0_sound_wav", + "foreground1_sound_wav", + "foreground2_sound_wav", + ] + provides = ["in_sig", "out_sig"] + + def audio_pipeline(mix_wav, *src_wavs): + # Mixed signal + try: + original_sample_rate = sb.dataio.dataio.read_audio_info( + mix_wav + ).sample_rate + # total_frames = sb.dataio.dataio.read_audio_info( + # mix_wav + # ).num_frames + + # start = randint(0, total_frames - int(CHUNK * original_sample_rate)) + + # Source signals + src_sigs = [] + for src_wav in src_wavs: + assert ( + original_sample_rate + == sb.dataio.dataio.read_audio_info(src_wav).sample_rate + ) + src_sig = sb.dataio.dataio.read_audio( + dict(file=src_wav) + ) # ,start=start, stop=start + int(CHUNK * original_sample_rate))) + src_sigs.append(src_sig) + src_sigs = torch.stack(src_sigs) # [S, T] + max_vals = torch.max(torch.abs(src_sigs), dim=1, keepdim=True)[ + 0 + ] # Find peak per source item + src_sigs = torch.where( + max_vals > 0, src_sigs / max_vals, src_sigs + ) # Normalize only non-silent signals + + out_sig = torchaudio.functional.resample( + src_sigs, original_sample_rate, sample_rate, + ) + in_sig = out_sig.sum(0) # [T] + except Exception as e: + print(e) + yield in_sig + + # Flatten as SpeechBrain's dataloader does not support multichannel audio + out_sig = out_sig.flatten() # [S * T] + yield out_sig + + sb.dataio.dataset.add_dynamic_item( + [train_data, valid_data, test_data], audio_pipeline, takes, provides + ) + + # Set output + sb.dataio.dataset.set_output_keys(datasets, ["id"] + provides) + + return train_data, valid_data, test_data + + 
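+# Note on the dynamic items produced by `audio_pipeline` above (sketch of shapes):
+# for S sources of T samples each, "in_sig" is the on-the-fly mixture of shape [T]
+# (sum of the peak-normalized sources), while "out_sig" is the stacked sources
+# flattened to [S * T]; train.py reshapes it back to [B, S, T] in compute_forward.
+
+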
+if __name__ == "__main__":
+    from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+    from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import (
+        Wav2Vec2,
+    )
+
+    for source in [
+        "facebook/wav2vec2-large-960h-lv60-self",
+        "facebook/hubert-large-ll60k",
+        "microsoft/wavlm-large",
+    ]:
+        layer_ids = [3, 7]
+        encoder1 = Wav2Vec2(
+            source=source, save_path=HUGGINGFACE_HUB_CACHE, output_norm=True,
+        )
+        encoder1 = SBWav2Vec2ForwardWrapper(
+            encoder1, layer_ids=layer_ids
+        ).eval()
+
+        encoder2 = Wav2Vec2(
+            source=source,
+            save_path=HUGGINGFACE_HUB_CACHE,
+            output_norm=True,
+            output_all_hiddens=True,
+        ).eval()
+
+        input = torch.ones([1, 16000])
+        with torch.no_grad():
+            output1 = encoder1(input)[layer_ids]
+            output2 = encoder2(input)[layer_ids]
+
+        print((output1 == output2).all())
diff --git a/benchmarks/DASB/MUSDB/README.md b/benchmarks/DASB/MUSDB/README.md
new file mode 100644
index 000000000..8f339ef1f
--- /dev/null
+++ b/benchmarks/DASB/MUSDB/README.md
@@ -0,0 +1,65 @@
+# MUSDB Separation Task
+
+This folder defines the **MUSDB source separation benchmark** within DASB (Discrete Audio Separation Benchmark). It enables evaluating discrete audio representations on **music source separation**, using the [MUSDB-18 dataset](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav).
+
+## Overview
+
+The goal of this task is to perform **source separation** on musical mixtures containing the `[bass, drums, other, vocals]` source types.
+
+This benchmark supports:
+- Preparing the MUSDB dataset for **supervised training and evaluation**
+- Running separation experiments using various discrete audio codecs and backbones (namely Conformer and CRDNN)
+- Computing standard evaluation metrics (e.g., SDR, SIR, SAR) using the Fast-BSSEval library.
+
+---
+
+## Directory Structure
+
+```
+MUSDB
+├── create_musdb.py              # Generates chunked training data from raw MUSDB mixtures
+├── README.md
+├── separation
+│   ├── musdb_prepare.py         # Prepare MUSDB dataset for supervised SS training
+│   ├── train.py                 # Unified training script for all MUSDB experiments
+│   ├── utils.py                 # Audio I/O and utility functions
+│   ├── hparams
+│   │   ├── conformer
+│   │   │   ├── train_dac.yaml   # config recipe for Conformer
+│   │   │   ├── ...
+│   │   └── crdnn
+│   │       ├── train_dac.yaml   # config recipe for CRDNN
+│   │       ├── ...
+│   └── metrics
+│       └── bsseval.py           # BSSEval implementation (SDR, SIR, SAR)
+└── experiments
+```
+
+
+---
+
+## Setup
+
+**Install dependencies:**
+
+You may need additional packages for separation and evaluation:
+```bash
+pip install -r ../extra_requirements.txt
+```
+
+---
+
+## Data Preparation
+
+- Download the raw MUSDB dataset: please follow the instructions on the [official project page](https://sigsep.github.io/datasets/musdb.html#musdb18-hq-uncompressed-wav) to download the data locally
+- Unpack it into a directory ``
+- To validate the data and create the MUSDB mixtures, run `create_musdb.py`. Note that this script only creates the `train` set and requires a `--num_chunks` argument defining the number of random chunks to take per track. For the `validation` and `eval` sets, run `create_musdb_eval.py` instead.
+- Lastly, to create the `.csv` manifests, run `separation/musdb_prepare.py`
+
+## Running Separation Experiment
+
+```python
+python MUSDB/separation/train.py MUSDB/separation/hparams/conformer/train_encodec.yaml \
+    --data_folder= \
+    --output_folder=MUSDB/experiments
+```
\ No newline at end of file
diff --git a/benchmarks/DASB/MUSDB/create_musdb.py b/benchmarks/DASB/MUSDB/create_musdb.py
new file mode 100644
index 000000000..aeebc5641
--- /dev/null
+++ b/benchmarks/DASB/MUSDB/create_musdb.py
@@ -0,0 +1,188 @@
+import os
+import argparse
+import numpy as np
+import soundfile as sf
+from tqdm import tqdm
+from copy import copy
+from concurrent.futures import ProcessPoolExecutor
+
+
+def apply_random_gain(audio, min_gain=0.25, max_gain=1.25):
+    """
+    Apply a random gain to a numpy array representing an audio signal.
+
+    Args:
+        audio (numpy.ndarray): Input audio signal.
+        min_gain (float): Minimum gain value.
+        max_gain (float): Maximum gain value.
+
+    Returns:
+        numpy.ndarray: Audio signal with random gain applied.
+    """
+    gain = np.random.uniform(min_gain, max_gain)
+    return audio * gain
+
+
+def ensure_audio_files(directory, required_files):
+    """
+    Ensure all required audio files exist in a directory.
+    """
+    required_paths = {
+        file: os.path.join(directory, file) for file in required_files
+    }
+    if not all(os.path.exists(p) for p in required_paths.values()):
+        print(f"Error: Missing files in {directory}. Cannot proceed.")
+        return False
+    return True
+
+
+def random_chunk_indices(total_samples, chunk_samples, num_chunks):
+    """
+    Generate random start indices for chunks within the range of the audio length.
+    Ensures chunks do not exceed the total length.
+    """
+    max_start = total_samples - chunk_samples
+    if max_start <= 0:
+        return [
+            0
+        ] * num_chunks  # Only one possible chunk if audio is shorter than chunk size
+    return np.random.randint(0, max_start + 1, size=num_chunks)
+
+
+def process_track(
+    split, track, track_path, target_dir, chunk_size, num_chunks, required_files
+):
+    """
+    Process a single track by randomly sampling chunks and saving them.
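+
+    Example call (illustrative sketch; paths are placeholders):
+        process_track(
+            "train", "Song A", "/raw/musdb/train/Song A", "/chunks/musdb",
+            chunk_size=5, num_chunks=10,
+            required_files=["bass.wav", "drums.wav", "other.wav", "vocals.wav"],
+        )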
+ """ + if not ensure_audio_files(track_path, required_files): + return + + audio_data = {} + sample_rate = None + total_samples = None + + # Load all required files and convert to mono if needed + for file in required_files: + file_path = os.path.join(track_path, file) + audio, sr = sf.read(file_path) + if len(audio.shape) == 2: + audio = np.mean(audio, axis=1) + if sample_rate is None: + sample_rate = sr + if total_samples is None: + total_samples = len(audio) + audio_data[file] = audio + + chunk_samples = int(chunk_size * sample_rate) + start_indices = random_chunk_indices( + total_samples, chunk_samples, num_chunks + ) + + # Save randomly sampled chunks + for i, start in enumerate(start_indices): + end = start + chunk_samples + chunk_sum = np.zeros(chunk_samples) # Initialize for mixture + for file in required_files: + chunk = copy(audio_data[file][start:end]) + if not split == "eval": + chunk = apply_random_gain(chunk) + chunk_sum += chunk # Add to mixture + + new_track_name = f"{track.replace(' ', '_')}_chunk{i:02d}" + save_path = os.path.join(target_dir, split, new_track_name) + os.makedirs(save_path, exist_ok=True) + sf.write(os.path.join(save_path, file), chunk, sample_rate) + + # Save the computed mixture + sf.write(os.path.join(save_path, "mixture.wav"), chunk_sum, sample_rate) + + +def process_tracks(root_dir, target_dir, chunk_size, num_chunks, max_workers=4): + """ + Processes the dataset by randomly sampling chunks from each track using parallel processing. + Args: + root_dir (str): Root directory containing the dataset. + target_dir (str): Target directory to save the processed chunks. + chunk_size (int): Size of each chunk in seconds. + num_chunks (int): Number of random chunks per track. + max_workers (int): Maximum number of parallel workers. + """ + required_files = ["bass.wav", "drums.wav", "other.wav", "vocals.wav"] + tasks = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for split in ["train"]: # os.listdir(root_dir): + split_path = os.path.join(root_dir, split) + if not os.path.isdir(split_path): + continue + + for track in os.listdir(split_path): + track_path = os.path.join(split_path, track) + if not os.path.isdir(track_path): + continue + tasks.append( + executor.submit( + process_track, + split, + track, + track_path, + target_dir, + chunk_size, + num_chunks, + required_files, + ) + ) + + # Wait for all tasks to complete + for task in tqdm(tasks, desc="Processing tracks"): + task.result() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Extract random audio chunks from tracks in parallel." 
+ ) + + parser.add_argument( + "root_dir", + type=str, + help="Path to the root directory containing source audio tracks.", + ) + + parser.add_argument( + "target_dir", + type=str, + help="Path to the directory where processed chunks will be saved.", + ) + + parser.add_argument( + "--chunk_size", + type=int, + default=5, + help="Size of each audio chunk in seconds (default: 5)", + ) + + parser.add_argument( + "--num_chunks", + type=int, + default=1000, + help="Number of random chunks to extract per track (default: 1000)", + ) + + parser.add_argument( + "--max_workers", + type=int, + default=32, + help="Maximum number of parallel workers (default: 32)", + ) + + args = parser.parse_args() + + process_tracks( + root_dir=args.root_dir, + target_dir=args.target_dir, + chunk_size=args.chunk_size, + num_chunks=args.num_chunks, + max_workers=args.max_workers, + ) diff --git a/benchmarks/DASB/MUSDB/create_musdb_eval.py b/benchmarks/DASB/MUSDB/create_musdb_eval.py new file mode 100644 index 000000000..dad30fd62 --- /dev/null +++ b/benchmarks/DASB/MUSDB/create_musdb_eval.py @@ -0,0 +1,148 @@ +import os +import argparse +import numpy as np +import soundfile as sf +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor + + +def ensure_audio_files(directory, required_files): + """ + Ensure all required audio files exist in a directory. + """ + required_paths = { + file: os.path.join(directory, file) for file in required_files + } + if not all(os.path.exists(p) for p in required_paths.values()): + print(f"Error: Missing files in {directory}. Cannot proceed.") + return False + return True + + +def process_track( + split, track, track_path, target_dir, chunk_size, required_files +): + """ + Process a single track by sequentially partitioning it into non-overlapping chunks. + """ + if not ensure_audio_files(track_path, required_files): + return + + audio_data = {} + sample_rate = None + total_samples = None + + # Load all required files and convert to mono if needed + for file in required_files: + file_path = os.path.join(track_path, file) + audio, sr = sf.read(file_path) + if len(audio.shape) == 2: + audio = np.mean(audio, axis=1) + if sample_rate is None: + sample_rate = sr + if total_samples is None: + total_samples = len(audio) + audio_data[file] = audio + + chunk_samples = int(chunk_size * sample_rate) + num_chunks = total_samples // chunk_samples + + # Save sequentially sampled chunks + for i in range(num_chunks): + start = i * chunk_samples + end = start + chunk_samples + chunk_sum = np.zeros(chunk_samples) # Initialize for mixture + + new_track_name = f"{track.replace(' ', '_')}_chunk{i:02d}" + save_path = os.path.join(target_dir, split, new_track_name) + os.makedirs(save_path, exist_ok=True) + + for file in required_files: + chunk = audio_data[file][start:end] + if len(chunk) < chunk_samples: + chunk = np.pad(chunk, (0, chunk_samples - len(chunk))) + chunk_sum += chunk # Add to mixture + sf.write(os.path.join(save_path, file), chunk, sample_rate) + + # Save the computed mixture + sf.write(os.path.join(save_path, "mixture.wav"), chunk_sum, sample_rate) + + +def process_tracks(root_dir, target_dir, chunk_size, max_workers=4): + """ + Processes the dataset by sequentially partitioning tracks using parallel processing. + Args: + root_dir (str): Root directory containing the dataset. + target_dir (str): Target directory to save the processed chunks. + chunk_size (int): Size of each chunk in seconds. + max_workers (int): Maximum number of parallel workers. 
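+
+    Example (illustrative sketch; directories are placeholders):
+        process_tracks("/raw/musdb", "/chunks/musdb", chunk_size=5, max_workers=8)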
+ """ + required_files = ["bass.wav", "drums.wav", "other.wav", "vocals.wav"] + tasks = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for split in ["validation", "eval"]: + split_path = os.path.join(root_dir, split) + if not os.path.isdir(split_path): + continue + + for track in os.listdir(split_path): + track_path = os.path.join(split_path, track) + if not os.path.isdir(track_path): + continue + tasks.append( + executor.submit( + process_track, + split, + track, + track_path, + target_dir, + chunk_size, + required_files, + ) + ) + + # Wait for all tasks to complete + for task in tqdm(tasks, desc="Processing tracks"): + task.result() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Process audio tracks in parallel with chunking." + ) + + parser.add_argument( + "root_dir", + type=str, + help="Path to the root directory containing MUSDB source audio tracks.", + ) + + parser.add_argument( + "target_dir", + type=str, + help="Path to the target directory where processed chunks will be saved.", + ) + + parser.add_argument( + "--chunk_size", + type=int, + default=5, + help="Chunk size in seconds (default: 5)", + ) + + parser.add_argument( + "--max_workers", + type=int, + default=8, + help="Maximum number of parallel workers (default: 8)", + ) + + args = parser.parse_args() + + process_tracks( + root_dir=args.root_dir, + target_dir=args.target_dir, + chunk_size=args.chunk_size, + max_workers=args.max_workers, + ) diff --git a/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_dac.yaml b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_dac.yaml new file mode 100644 index 000000000..30fcd3d58 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_dac.yaml @@ -0,0 +1,219 @@ +# ########################################################################################### +# Model: Conformer with DAC audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: dac + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# DAC parameters +# sample_rate: [16000, 24000, 44000, 44000] +# vocab_size: [1024, 1024, 1024, 1024] +# max_num_codebooks: [12, 32, 9, 18] +# model_type: [16khz, 24khz, 44khz, 44khz] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +sample_rate: 24000 # NOTE: must match DAC's model type +vocab_size: 1024 +num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type +model_type: 24khz +model_bitrate: 8kbps + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (1024) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.discrete.dac.DAC + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: 
!new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_encodec.yaml b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_encodec.yaml new file mode 100644 index 000000000..439ceeb7e --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_encodec.yaml @@ -0,0 +1,220 @@ +# ########################################################################################### +# Model: Conformer with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
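+# NOTE: in this recipe the EnCodec bandwidth is derived from the number of codebooks
+# (bandwidth = num_codebooks * 0.75 kbps), matching the commented pairs further below:
+# 2 codebooks -> 1.5 kbps, 4 -> 3.0, 8 -> 6.0, 16 -> 12.0, 32 -> 24.0.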
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 1024 +num_codebooks: 2 +bandwidth: !ref * 75 / 100 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + 
out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_sqcodec.yaml b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_sqcodec.yaml new file mode 100644 index 000000000..040937b02 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_sqcodec.yaml @@ -0,0 +1,224 @@ +# ########################################################################################### +# Model: Conformer with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
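+# NOTE: the SQ-Codec settings further below use vocab_size: 19683, which equals 3**9 and
+# presumably corresponds to 9 scalar-quantized dimensions with 3 levels each (hence
+# embedding_dim: 9 together with scalar_embedding: True).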
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 16000 +vocab_size: 19683 +num_codebooks: 4 +bandwidth: 2 + +# Embedding parameters +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False +encoder_dim: 1024 +embedding_dim: 9 +hidden_dim: 256 +# if set to concat, you need to set embedding_dim to match the encoder_dim after concatenation. Eg, if you have 4 codebook, embedding_dim shoudl set to encoder_dim/4 +embedding_strg: concat # option are concat and att_pool +scalar_embedding: True + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:sq_codec.SQCodec + save_path: !ref + config: config.yaml + checkpoint: ckpt_00190000.pth + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + init: !ref + scalar: !ref + hidden_dim: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + 
max_length: !ref + encoder_module: conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_wavtokenizer.yaml b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_wavtokenizer.yaml new file mode 100644 index 000000000..e5ebb0313 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/conformer/train_wavtokenizer.yaml @@ -0,0 +1,223 @@ +# ########################################################################################### +# Model: Conformer with Wavtokenizer audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: wavtokenizer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
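+# NOTE: `num_speakers` counts the separated sources; for MUSDB these are the four
+# stems (bass, drums, other, vocals) rather than actual speakers.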
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 4096 +num_codebooks: 1 +bandwidth: 2 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.GELU +d_model: 256 +nhead: 4 +num_layers: 6 +d_ffn: 2048 +max_length: 2000 +causal: False + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + +# Modules +codec: !new:speechbrain.lobes.models.discrete.wavtokenizer.WavTokenizer + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR + input_size: !ref + tgt_vocab: -1 + d_model: !ref + nhead: !ref + num_encoder_layers: !ref + num_decoder_layers: 0 + d_ffn: !ref + dropout: !ref + activation: !ref + max_length: !ref + encoder_module: 
conformer + normalize_before: True + causal: !ref + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_dac.yaml b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_dac.yaml new file mode 100644 index 000000000..846d0fcbf --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_dac.yaml @@ -0,0 +1,229 @@ +# ########################################################################################### +# Model: CRDNN with DAC audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: dac + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
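+# NOTE: the `head` linear layer defined below predicts all sources and codebooks
+# jointly (out_features = vocab_size * num_speakers * num_codebooks); train.py
+# reshapes the logits to [batch, frames, sources, codebooks, vocab] before the loss.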
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# DAC parameters +# sample_rate: [16000, 24000, 44000, 44000] +# vocab_size: [1024, 1024, 1024, 1024] +# max_num_codebooks: [12, 32, 9, 18] +# model_type: [16khz, 24khz, 44khz, 44khz] +# model_bitrate: [8kbps, 8kbps, 8kbps, 16kbps] +sample_rate: 24000 # NOTE: must match DAC's model type +vocab_size: 1024 +num_codebooks: 2 # NOTE: must be smaller or equal to the maximum number of codebooks for the given model type +model_type: 24khz +model_bitrate: 8kbps + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (1024) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.discrete.dac.DAC + model_type: !ref + model_bitrate: !ref + load_pretrained: True + tag: latest + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + 
cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_encodec.yaml b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_encodec.yaml new file mode 100644 index 000000000..4cc82e8a5 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_encodec.yaml @@ -0,0 +1,230 @@ +# ########################################################################################### +# Model: CRDNN with EnCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: encodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
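+# NOTE: for the 24 kHz EnCodec model each codebook contributes 0.75 kbps
+# (75 frames/s x 10 bits), so the `bandwidth` expression below
+# (num_codebooks * 75 / 100) selects the bitrate that yields exactly
+# `num_codebooks` codebooks (e.g. 2 codebooks -> 1.5 kbps).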
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 1024 +num_codebooks: 2 +bandwidth: !ref * 75 / 100 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:speechbrain.lobes.models.huggingface_transformers.encodec.Encodec + source: facebook/encodec_24khz # Only the 24kHz version supports mono audio + save_path: !ref + sample_rate: !ref + bandwidth: !ref + flat_embeddings: False + freeze: True + renorm_embeddings: False + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: 
!ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_sqcodec.yaml b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_sqcodec.yaml new file mode 100644 index 000000000..2d1983a8d --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_sqcodec.yaml @@ -0,0 +1,234 @@ +# ########################################################################################### +# Model: CRDNN with SQCodec audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: sqcodec + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
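+# NOTE (assumption): vocab_size 19683 = 3^9 presumably reflects SQ-Codec's scalar
+# quantizer (9 ternary dimensions per code, consistent with embedding_dim: 9 and
+# scalar_embedding: True below); double-check the SQ-Codec config before changing it.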
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 16000 +vocab_size: 19683 +num_codebooks: 4 +bandwidth: 2 + +# Embedding parameters +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False +encoder_dim: 1024 +embedding_dim: 9 +hidden_dim: 256 +# if set to concat, you need to set embedding_dim to match the encoder_dim after concatenation. Eg, if you have 4 codebook, embedding_dim shoudl set to encoder_dim/4 +embedding_strg: concat # option are concat and att_pool +scalar_embedding: True + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +# Modules +codec: !new:sq_codec.SQCodec + save_path: !ref + config: config.yaml + checkpoint: ckpt_00190000.pth + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + init: !ref + scalar: !ref + hidden_dim: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, 
null, !ref ] + activation: !ref + dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_wavtokenizer.yaml b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_wavtokenizer.yaml new file mode 100644 index 000000000..a3b15bd8d --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/hparams/crdnn/train_wavtokenizer.yaml @@ -0,0 +1,233 @@ +# ########################################################################################### +# Model: CRDNN with WavTokenizer audio representations +# Authors: Luca Della Libera 2024 +# ########################################################################################### + +experiment_name: wavtokenizer + +# Seed needs to be set at top of YAML +seed: 0 +__set_seed: !apply:torch.manual_seed [!ref ] +testing: False # If set to True, the test evlaution is done, otherwise skipped. 
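+# NOTE: WavTokenizer uses a single codebook (num_codebooks: 1), so the attention-MLP
+# pooling over codebooks in train.py reduces to passing through the single
+# embedding per frame.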
+ +# Data preparation +data_folder: !PLACEHOLDER +train_csv: !ref /train.csv +valid_csv: !ref /validation.csv +test_csv: !ref /eval.csv +splits: [train, validation, eval] +num_speakers: 4 +add_noise: False +version: wav16k/min + +# Output folders +output_folder: !ref results// +save_folder: !ref /save +cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + +# Save options +compute_metrics: True +save_audios: True + +# Preprocessing parameters +train_remove_if_longer: 1000.0 # Seconds +valid_remove_if_longer: 1000.0 # Seconds +test_remove_if_longer: 1000.0 # Seconds +sorting: random +use_cache: True + +# Training parameters +num_epochs: 40 +grad_accumulation_factor: 16 +train_batch_size: 1 +valid_batch_size: 1 +test_batch_size: 1 +dataloader_workers: 8 +nonfinite_patience: 10 +max_grad_norm: 5.0 +precision: fp32 +ckpt_interval_minutes: 6000 +keep_checkpoints: 1 +augment: False +augment_prob: 0.75 +use_pit: True + +# Optimizer parameters +lr: 0.0003578 # @orion_step1: --lr~"loguniform(0.00005,0.001)" +weight_decay: 0.01 +improvement_threshold: 0.0025 +annealing_factor: 0.9 +patient: 1 + +# EnCodec parameters +# sample_rate: [24000, 24000, 24000, 24000] +# vocab_size: [1024, 1024, 1024, 1024] +# num_codebooks: [2, 4, 8, 16, 32] +# bandwidth: [1.5, 3.0, 6.0, 12.0, 24.0] +sample_rate: 24000 +vocab_size: 4096 +num_codebooks: 1 +bandwidth: 2 + +# Embedding parameters +embedding_dim: 1024 +pretrain_embedding: False # If True, must match the codec's embedding size (128) +freeze_embedding: False + +# Encoder parameters +dropout: 0.1 +activation: !name:torch.nn.LeakyReLU +rnn_class: !name:speechbrain.nnet.RNN.LSTM +rnn_layers: 4 +time_pooling_size: 1 +rnn_bidirectional: True +rnn_neurons: 256 +dnn_blocks: 2 +dnn_neurons: 256 +cnn_blocks: 2 +cnn_channels: (16, 16) +inter_layer_pooling_size: (2, 2) +cnn_kernelsize: (3, 3) + +# Augmentation +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: 0 # Min frequency band dropout probability + drop_freq_high: 1 # Max frequency band dropout probability + drop_freq_count_low: 1 # Min number of frequency bands to drop + drop_freq_count_high: 3 # Max number of frequency bands to drop + drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: 1 # Min number of audio chunks to drop + drop_length_high: 5 # Max number of audio chunks to drop + drop_count_low: 1000 # Min length of audio chunks to drop + drop_count_high: 2000 # Max length of audio chunks to drop + +augmentation: !new:speechbrain.augment.augmenter.Augmenter + parallel_augment: False + concat_original: False + repeat_augment: 1 + shuffle_augmentations: False + min_augmentations: 2 + max_augmentations: 2 + augment_prob: !ref + augmentations: [!ref , !ref ] + +model_hub: novateur/WavTokenizer-medium-music-audio-75token +config: wavtokenizer_mediumdata_music_audio_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +checkpoint: wavtokenizer_medium_music_audio_320_24k_v2.ckpt + +# Modules +codec: !new:speechbrain.lobes.models.discrete.wavtokenizer.WavTokenizer + source: !ref + save_path: !ref + checkpoint: !ref + config: !ref + sample_rate: !ref + freeze: True + +embedding: !new:model.custom_model.Discrete_EmbeddingLayer + num_codebooks: !ref + vocab_size: !ref + emb_dim: !ref + freeze: !ref + +attention_mlp: !new:model.custom_model.AttentionMLP + input_dim: !ref + hidden_dim: !ref + +encoder: !new:speechbrain.lobes.models.CRDNN.CRDNN + input_shape: [null, null, !ref ] + activation: !ref + 
dropout: !ref + cnn_blocks: !ref + cnn_channels: !ref + cnn_kernelsize: !ref + inter_layer_pooling_size: !ref + time_pooling: True + using_2d_pooling: False + time_pooling_size: !ref + rnn_class: !ref + rnn_layers: !ref + rnn_neurons: !ref + rnn_bidirectional: !ref + dnn_blocks: !ref + dnn_neurons: !ref + rnn_re_init: True + use_rnnp: False + +head: !new:torch.nn.Linear + in_features: !ref + out_features: !ref * * + +modules: + embedding: !ref + attention_mlp: !ref + encoder: !ref + head: !ref + +model: !new:torch.nn.ModuleList + [[!ref , + !ref , + !ref , + !ref ]] + +# Loss functions +ce_loss: !name:speechbrain.nnet.losses.nll_loss + label_smoothing: 0.0 + allowed_len_diff: 0 + reduction: mean + +# Optimizers +opt_class: !name:torch.optim.AdamW + lr: !ref + betas: (0.9, 0.98) + eps: 1.e-8 + weight_decay: !ref + +# Schedulers +scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: !ref + annealing_factor: !ref + patient: !ref + +# Dataloaders +train_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + shuffle: !apply:str.__eq__ [!ref , random] + +valid_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +test_dataloader_kwargs: + batch_size: !ref + num_workers: !ref + pin_memory: True + +# Performance metrics +ter_computer: !name:speechbrain.utils.metric_stats.MetricStats + metric: !name:speechbrain.nnet.losses.classification_error + reduction: batch + +bsseval_computer: !name:metrics.bsseval.BSSEval + n_sources: !ref + permutation_invariant: !ref + +# Counters, checkpointers, loggers, etc. +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + scheduler: !ref + counter: !ref + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref /train_log.txt diff --git a/benchmarks/DASB/MUSDB/separation/metrics/bsseval.py b/benchmarks/DASB/MUSDB/separation/metrics/bsseval.py new file mode 100644 index 000000000..3b9222313 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/metrics/bsseval.py @@ -0,0 +1,153 @@ +import numpy as np +import torch +import json +from fast_bss_eval import bss_eval_sources +from speechbrain.utils.metric_stats import MetricStats + + +__all__ = ["BSSEval"] + + +class BSSEval(MetricStats): + def __init__( + self, n_sources, source_names=None, permutation_invariant=True + ): + """ + A subclass of MetricStats for evaluating source separation algorithms. + + Args: + n_sources (int): Number of sources to evaluate. + source_names (list, optional): Names of the sources. Defaults to None. + permutation_invariant (bool): Whether to apply permutation invariance when matching sources. + """ + self.n_sources = n_sources + self.source_names = source_names or [ + f"Source {i + 1}" for i in range(n_sources) + ] + self.permutation_invariant = permutation_invariant + + # Initialize storage for metrics + self.metrics = dict() + + def compute_metrics(self, reference_sources, estimated_sources): + """ + Computes SDR, SIR, and SAR for the given reference and estimated sources. + + Args: + reference_sources (ndarray): Array of ground truth sources (shape: [n_sources, n_samples]). + estimated_sources (ndarray): Array of estimated sources (shape: [n_sources, n_samples]). + + Returns: + dict: A dictionary containing SDR, SIR, and SAR values for each source. 
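+
+        Example (illustrative, random signals):
+            >>> bss = BSSEval(n_sources=2)
+            >>> ref = torch.randn(2, 16000)  # [n_sources, n_samples]
+            >>> est = torch.randn(2, 16000)
+            >>> bss.compute_metrics(ref, est)  # {"SDR": ..., "SIR": ..., "SAR": ...}
+
+        Note that the returned values are averaged over the non-silent sources
+        of the given instance.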
+ """ + # Define epsilon + epsilon = 1e-10 + # Identify rows that are all zeros + is_all_zeros = torch.all(reference_sources == 0, axis=1) + + # Create a mask to add epsilon only to all-zero rows + reference_sources[is_all_zeros] += epsilon + try: + sdr, sir, sar, perm = bss_eval_sources( + reference_sources, + estimated_sources, + compute_permutation=self.permutation_invariant, + load_diag=1e-5, + ) + is_all_zeros = is_all_zeros[ + perm + ] # Apply permutation to silent mask + sdr_mean = sdr[~is_all_zeros].mean().detach().cpu().numpy().item() + sir_mean = sir[~is_all_zeros].mean().detach().cpu().numpy().item() + sar_mean = sar[~is_all_zeros].mean().detach().cpu().numpy().item() + except Exception as e: + print(f"Exception occured when computing BBSEval: {e}", flush=True) + sdr_mean, sir_mean, sar_mean = np.nan, np.nan, np.nan + return {"SDR": sdr_mean, "SIR": sir_mean, "SAR": sar_mean} + + def add( + self, + reference_sources: torch.Tensor, + estimated_sources: torch.Tensor, + tag: str = None, + ): + """ + Adds the metrics for a single evaluation instance. + + Args: + reference_sources (tensor): Array of ground truth sources (shape: [n_sources, n_samples]). + estimated_sources (tensor): Array of estimated sources (shape: [n_sources, n_samples]). + """ + # Ensure inputs are numpy arrays + reference_sources = reference_sources.squeeze() + estimated_sources = estimated_sources.squeeze() + + # Validate input shapes + assert ( + reference_sources.shape[0] == self.n_sources + ), "Mismatch in number of reference sources." + assert ( + estimated_sources.shape[0] == self.n_sources + ), "Mismatch in number of estimated sources." + + # Compute metrics + metrics = self.compute_metrics(reference_sources, estimated_sources) + + # Store metrics + for key, values in metrics.items(): + if tag is not None: + key = f"{key}/{tag}" + self.metrics.setdefault(key, []).append(values) + + def summarize(self): + """ + Summarizes the collected metrics. + + Returns: + dict: A dictionary containing mean and standard deviation for each metric. + """ + summary = {} + for metric, values in self.metrics.items(): + values = np.array(values) + values = values[~np.isinf(values)] + summary[metric] = { + "mean": np.nanmean(values, axis=0).tolist(), + "std": np.nanstd(values, axis=0).tolist(), + } + + return summary + + def pretty_print(self): + """ + Prints the summarized metrics in a human-readable format. 
+ """ + summary = self.summarize() + print("Source Separation Evaluation Results:") + for metric, stats in summary.items(): + print(f"\n{metric}:") + for i, source_name in enumerate(self.source_names): + print( + f" {source_name}: Mean = {stats['mean'][i]:.2f}, Std = {stats['std'][i]:.2f}" + ) + + def write_stats(self, path): + results = self.summarize() + with open(path, "w") as outfile: + json.dump(results, outfile, indent=4) + + +if __name__ == "__main__": + n_sources = 2 + source_names = ["Vocals", "Accompaniment"] + stats = BSSEval( + n_sources=n_sources, + source_names=source_names, + permutation_invariant=True, + ) + + # Example ground truth and estimated sources + ref_sources = np.random.randn(n_sources, 10000) + est_sources = np.random.randn(n_sources, 10000) + + stats.add(ref_sources, est_sources) + stats.pretty_print() diff --git a/benchmarks/DASB/MUSDB/separation/musdb_prepare.py b/benchmarks/DASB/MUSDB/separation/musdb_prepare.py new file mode 100644 index 000000000..485160839 --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/musdb_prepare.py @@ -0,0 +1,165 @@ +import csv +import logging +import os +from typing import Optional, Sequence + +from tqdm import tqdm + +import speechbrain as sb + + +__all__ = ["prepare_musdb"] + +SOURCE_NAMES = [ + "bass.wav", + "drums.wav", + "other.wav", + "vocals.wav", +] + +# Workaround to use fastest backend (SoundFile) +try: + import torchaudio + + torchaudio._backend.utils.get_available_backends().pop("ffmpeg", None) +except Exception: + pass + +# Logging configuration +logging.basicConfig( + level=logging.INFO, # format="%(asctime)s [%(levelname)s] %(funcName)s - %(message)s", +) + +_LOGGER = logging.getLogger(__name__) + + +def prepare_musdb( + data_folder: "str", + save_folder: "Optional[str]" = None, + splits: "Sequence[str]" = ("train", "eval", "validation"), +) -> "None": + """Prepare data manifest CSV files for the MUSDB dataset + + Arguments + --------- + data_folder: + The path to the dataset folder. + save_folder: + The path to the folder where the data manifest CSV files will be stored. + Default to `data_folder`. + splits: + The dataset splits to prepare. + num_sources: + The number of speakers (1, 2 or 3). + + Raises + ------ + ValueError + If an invalid argument value is given. + RuntimeError + If one of the expected split folders is missing. + + Examples + -------- + >>> # Expected folder structure: MUSDB/{train, test}//{mixture.wav, bass.wav, other.wav, drums.wav, vocals.wa} + >>> prepare_musdb("MUSDB", num_sources=4) + + """ + if not save_folder: + save_folder = data_folder + + train_data = [] + test_data = [] + valid_data = [] + + # Iterate over train and test splits + for split in splits: + split_dir = os.path.join(data_folder, split) + + # Check if the split directory exists + if not os.path.exists(split_dir): + print(f"Warning: {split_dir} does not exist. Skipping.") + continue + + # Walk through the subdirectories of the split (tracks) + tracks = os.listdir(split_dir) + for i, track_id in enumerate(tqdm(tracks, desc=split)): + track_dir = os.path.join(split_dir, track_id) + # Ensure the track directory exists and contains the required files + required_files = [ + "mixture.wav", + "bass.wav", + "drums.wav", + "other.wav", + "vocals.wav", + ] + file_paths = {} + + for file_name in required_files: + file_path = os.path.join(track_dir, file_name) + if os.path.exists(file_path): + file_paths[file_name] = file_path + else: + print( + f"Warning: {file_name} missing in {track_dir}. Skipping track." 
+ ) + file_paths = None + break # If any file is missing, skip the current track + + # If all required files are found, process the track + if file_paths: + # Get the duration of the 'mixture.wav' file + mixture_wav_path = file_paths["mixture.wav"] + info = sb.dataio.dataio.read_audio_info(mixture_wav_path) + duration = info.num_frames / info.sample_rate + + # Prepare the row for the CSV + row = [ + split, + track_id, # ID + duration, # duration + file_paths["mixture.wav"], # mixture_wav + file_paths["bass.wav"], + file_paths["drums.wav"], + file_paths["other.wav"], + file_paths["vocals.wav"], + ] + + # Add the row to the appropriate data list + if split == "train": + train_data.append(row) + elif split == "eval": + test_data.append(row) + elif split == "validation": + valid_data.append(row) + + # Define the CSV file headers + headers = [ + "split", + "ID", + "duration", + "mixture_wav", + "bass_wav", + "drums_wav", + "other_wav", + "vocals_wav", + ] + + # Write the CSV files for each split + for data, split in [ + (train_data, "train"), + (test_data, "eval"), + (valid_data, "validation"), + ]: + output_csv = os.path.join(save_folder, f"{split}.csv") + + with open(output_csv, mode="w", newline="") as file: + writer = csv.writer(file) + writer.writerow(headers) + writer.writerows(data) + print(f"CSV file created for {split}: {output_csv}") + + _LOGGER.info( + "----------------------------------------------------------------------", + ) + _LOGGER.info("Done!") diff --git a/benchmarks/DASB/MUSDB/separation/train.py b/benchmarks/DASB/MUSDB/separation/train.py new file mode 100644 index 000000000..746491afb --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/train.py @@ -0,0 +1,453 @@ +#!/usr/bin/env/python + +"""Recipe for training a transformer-based speech separation system using EnCodec audio representations. 
+ +To run this recipe: +> python train_encodec.py hparams/.yaml + +Authors + * Luca Della Libera 2024 +""" + +import os +import sys +import warnings +import logging + +import speechbrain as sb +import torch +from hyperpyyaml import load_hyperpyyaml +from speechbrain.dataio.dataio import write_audio +from speechbrain.utils.distributed import if_main_process, run_on_main + +from utils import ( + EncodecHelper, + DacHelper, + SQCodecHelper, + WavTokenizerHelper, +) + + +base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) +sys.path.append(base_dir) +base_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../model") +) +sys.path.append(base_dir) + + +logger = logging.getLogger(__name__) + + +_CACHE = {} + + +class Separation(sb.Brain): + def __init__( + self, + modules=None, + opt_class=None, + hparams=None, + run_opts=None, + checkpointer=None, + ): + super().__init__(modules, opt_class, hparams, run_opts, checkpointer) + + # Read tokenizer type from hparams + tokenizer_type = self.hparams.codec.__class__.__name__ + self.encdec = self._get_encdec_helper(tokenizer_type) + + def _get_encdec_helper(self, tokenizer_type): + if tokenizer_type == "Encodec": + return EncodecHelper(self.hparams.codec, self.device) + elif tokenizer_type == "DAC": + return DacHelper( + self.hparams.codec, self.device, self.hparams.num_codebooks + ) + elif tokenizer_type == "SQCodec": + return SQCodecHelper(self.hparams.codec, self.device) + elif tokenizer_type == "WavTokenizer": + return WavTokenizerHelper(self.hparams.codec, self.device) + else: + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}") + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + return self.encdec.sig_to_toks(sig, lens) + + @torch.no_grad() + def toks_to_sig(self, toks): + return self.encdec.toks_to_sig(toks) + + def compute_forward(self, batch, stage): + """Forward pass.""" + batch = batch.to(self.device) + in_sig, in_lens = batch.in_sig # [B, T] + out_sig, out_lens = batch.out_sig # [B, ST] + + # Unflatten + out_sig = out_sig.reshape( + len(out_sig), self.hparams.num_speakers, -1 + ).flatten( + end_dim=-2 + ) # [BS, T] + batch.out_sig = out_sig, out_lens + + # Augment if specified + if stage == sb.Stage.TRAIN and self.hparams.augment: + in_sig, in_lens = self.hparams.augmentation(in_sig, in_lens) + + # Extract tokens (cache them at first epoch if augmentation is disabled) + key = tuple(sorted(batch.id)) + try: + in_toks, out_toks = _CACHE[key] + in_toks = in_toks.to(self.device) + out_toks = out_toks.to(self.device) + except KeyError: + assert (in_lens == out_lens).all() + sig = torch.cat([in_sig, out_sig]) # [B(1 + S), T] + lens = torch.cat( + [ + in_lens, + out_lens.repeat_interleave(self.hparams.num_speakers), + ] + ) # [B(1 + S), T] + toks = self.sig_to_toks(sig, lens) # [B(1 + S), N, K] + in_toks, out_toks = toks.split( + [len(in_sig), len(out_sig)] + ) # [B, N, K], [BS, N, K] + out_toks = out_toks.reshape( + len(in_sig), + self.hparams.num_speakers, + -1, + self.hparams.num_codebooks, + ).movedim( + -2, -3 + ) # [B, N, S, K] + if self.hparams.use_cache and (not self.hparams.augment): + _CACHE[key] = in_toks.cpu(), out_toks.cpu() + + # Avoid in-place modification from embedding layer + in_toks = in_toks.clone() + + # Forward embedding + attention + in_embs = self.modules.embedding(in_toks) # [B, N, K, H] + # Get merged embedding based on strategy set, deafualt ATT_Pooling + if ( + hasattr(self.hparams, "embedding_strg") + and self.hparams.embedding_strg == "concat" + ): + B, T, 
N_Q, D = in_embs.shape + in_embs = in_embs.view(B, T, N_Q * D) + + else: + att_w = self.modules.attention_mlp(in_embs) # [B, N, K, 1] + in_embs = torch.matmul(att_w.transpose(2, -1), in_embs).squeeze( + -2 + ) # [B, N, H] + + # Forward encoder + if hasattr(self.modules.encoder, "encode"): + hyp_embs = self.modules.encoder.encode( + in_embs, in_lens + ) # [B, N, H] + else: + hyp_embs = self.modules.encoder(in_embs) # [B, N, H] + + # Forward head + log_probs = ( + self.modules.head(hyp_embs) + .reshape( + len(hyp_embs), + -1, + self.hparams.num_speakers, + self.hparams.num_codebooks, + self.hparams.vocab_size, + ) + .log_softmax(dim=-1) + ) # [B, N, S, K, C] + return log_probs, out_toks + + def compute_objectives(self, predictions, batch, stage): + """Computes the objectives.""" + log_probs, out_toks = predictions # [B, N, S, K, C], [B, N, S, K] + + IDs = batch.id + in_sig, _ = batch.in_sig + out_sig, out_lens = batch.out_sig + + if not self.hparams.use_pit: + # Cross-entropy loss + loss = self.hparams.ce_loss( + log_probs.flatten(start_dim=1, end_dim=3), # [B, NSK, C] + out_toks.flatten(start_dim=1), # [B, NSK] + length=out_lens, + ) + else: + # Permutation invariant training + from speechbrain.nnet.losses import PitWrapper + + def base_loss(preds, targets): + # preds: [N, K, C, S, S] + # targets: [N, K, S, S] + preds = preds.permute(3, 4, 0, 1, 2) # [S, S, N, K, C] + targets = targets.permute(2, 3, 0, 1) # [S, S, N, K] + loss = self.hparams.ce_loss( + preds.flatten(end_dim=-2), + targets.flatten(), + reduction="none", + ) # [SSNK] + loss = loss.reshape_as(targets) + loss = loss.permute(2, 3, 0, 1) # [N, K, S, S] + return loss + + log_probs = log_probs.movedim(2, -1) # [B, N, K, C, S] + out_toks = out_toks.movedim(2, -1) # [B, N, K, S] + pit_loss = PitWrapper(base_loss) + log_probs_list = [ + x[: int(l * log_probs.shape[1])] + for x, l in zip(log_probs, out_lens) + ] + out_toks_list = [ + x[: int(l * out_toks.shape[1])] + for x, l in zip(out_toks, out_lens) + ] + loss, perm = pit_loss(log_probs_list, out_toks_list) + loss = loss.mean() + log_probs = pit_loss.reorder_tensor(log_probs, perm) + log_probs = log_probs.movedim(-1, 2) # [B, N, S, K, C] + out_toks = out_toks.movedim(-1, 2) # [B, N, S, K] + + # Compute TER + if stage != sb.Stage.TRAIN: + self.ter_metric.append( + IDs, + log_probs.flatten(start_dim=1, end_dim=3), + out_toks.flatten(start_dim=1), + out_lens, + ) + + # Vocode + if stage in [sb.Stage.TEST] and self.hparams.compute_metrics: + hyp_toks = log_probs.argmax(dim=-1) # [B, N, S, K] + hyp_sig, rec_sig, out_sig = self.vocode( + IDs, in_sig, out_sig, hyp_toks, out_toks, out_lens + ) + self.bsseval_metric.add(out_sig, hyp_sig, tag="clean-hyp") + self.bsseval_metric.add(out_sig, rec_sig, tag="clean-rec") + self.bsseval_metric.add(rec_sig, hyp_sig, tag="rec-hyp") + self.bsseval_metric.add( + out_sig, + in_sig.unsqueeze(1).repeat(1, self.hparams.num_speakers, 1), + tag="clean-mix", + ) + + return loss + + @torch.no_grad() + def vocode(self, IDs, in_sig, out_sig, hyp_toks, out_toks, lens): + hyp_toks = hyp_toks.movedim(-2, -3).contiguous() # [B, S, N, K] + out_toks = out_toks.movedim(-2, -3).contiguous() # [B, S, N, K] + + hyp_sig = self.toks_to_sig( + hyp_toks.flatten(end_dim=1) # [BS, N, K] + ) # [BS, T] + rec_sig = self.toks_to_sig( + out_toks.flatten(end_dim=1) # [BS, N, K] + ) # [BS, T] + # Adjust length + if out_sig.shape[-1] > hyp_sig.shape[-1]: + pad = [0, out_sig.shape[-1] - hyp_sig.shape[-1]] + hyp_sig = torch.nn.functional.pad( + hyp_sig, pad, mode="replicate" + ) # [BS, 
T_out] + rec_sig = torch.nn.functional.pad( + rec_sig, pad, mode="replicate" + ) # [BS, T_out] + elif out_sig.shape[-1] < hyp_sig.shape[-1]: + hyp_sig = hyp_sig.narrow(-1, 0, out_sig.shape[-1]) # [BS, T_out] + rec_sig = rec_sig.narrow(-1, 0, out_sig.shape[-1]) # [BS, T_out] + + hyp_sig = hyp_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + rec_sig = rec_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + out_sig = out_sig.reshape(len(hyp_toks), -1) # [B, ST_out] + + if self.hparams.save_audios: + save_folder = os.path.join(self.hparams.output_folder, "audios") + os.makedirs(save_folder, exist_ok=True) + for i in range(len(IDs)): + write_audio( + os.path.join(save_folder, f"{IDs[i]}_hyp.wav"), + hyp_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_rec.wav"), + rec_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_ref.wav"), + out_sig[i].cpu(), + self.hparams.sample_rate, + ) + write_audio( + os.path.join(save_folder, f"{IDs[i]}_in.wav"), + in_sig[i].cpu(), + self.hparams.sample_rate, + ) + hyp_sig = hyp_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + rec_sig = rec_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + out_sig = out_sig.reshape( + len(IDs), self.hparams.num_speakers, -1 + ) # [B, S, T_out] + return hyp_sig, rec_sig, out_sig + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch.""" + super().on_stage_start(stage, epoch) + if ( + stage in [sb.Stage.TEST, sb.Stage.VALID] + and self.hparams.compute_metrics + ): + self.bsseval_metric = self.hparams.bsseval_computer() + self.ter_metric = self.hparams.ter_computer() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of each epoch.""" + # Compute/store important stats + stage_stats = {"loss": stage_loss} + + if stage == sb.Stage.TRAIN: + self.train_stats = stage_stats + # self.checkpointer.save_and_keep_only() + else: + stage_stats["TER"] = self.ter_metric.summarize("average") * 100 + + # Perform end-of-iteration operations, like annealing, logging, etc. 
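+        # NOTE: `scheduler` is a NewBobScheduler keyed on validation TER: roughly,
+        # when the relative improvement stays below `improvement_threshold` for more
+        # than `patient` epochs, the learning rate is annealed by `annealing_factor`.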
+ if stage == sb.Stage.VALID: + _, lr = self.hparams.scheduler(stage_stats["TER"]) + sb.nnet.schedulers.update_learning_rate(self.optimizer, lr) + steps = self.optimizer_step + self.hparams.train_logger.log_stats( + stats_meta={"epoch": epoch, "lr": lr, "steps": steps}, + train_stats=self.train_stats, + valid_stats=stage_stats, + ) + self.checkpointer.save_and_keep_only( + meta={"TER": stage_stats["TER"], "epoch": epoch}, + min_keys=["TER"], + num_to_keep=self.hparams.keep_checkpoints, + keep_recent=False, + ) + + elif stage == sb.Stage.TEST: + if self.hparams.compute_metrics: + stage_stats["BSSEval"] = self.bsseval_metric.summarize() + self.hparams.train_logger.log_stats( + stats_meta={"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stage_stats, + ) + if if_main_process(): + # Save dWER + if self.hparams.compute_metrics: + self.bsseval_metric.write_stats(self.hparams.bsseval_file) + + +if __name__ == "__main__": + # Command-line interface + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Filter warnings + warnings.filterwarnings("once") + warnings.filterwarnings("ignore", module="torch") + + # If --distributed_launch then create ddp_init_group with the right communication protocol + sb.utils.distributed.ddp_init_group(run_opts) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Dataset preparation + from musdb_prepare import prepare_musdb as prepare_data + + prepare_data_kwargs = { + "data_folder": hparams["data_folder"], + "save_folder": hparams["save_folder"], + "splits": hparams["splits"], + } + + run_on_main(prepare_data, kwargs=prepare_data_kwargs) + + # Create the datasets objects + from utils import dataio_prepare + + train_data, valid_data, test_data = dataio_prepare( + debug=run_opts.get("debug", False), **hparams + ) + + # Pretrain the specified modules + if "pretrainer" in hparams: + run_on_main(hparams["pretrainer"].collect_files) + run_on_main(hparams["pretrainer"].load_collected) + + # Use pretrained embeddings + if hparams["pretrain_embedding"]: + embs = hparams["codec"].vocabulary.reshape(-1, hparams["embedding_dim"]) + hparams["embedding"].embedding.weight.data.copy_(embs) + + # Log number of parameters/buffers + codec_params = sum( + [x.numel() for x in hparams["codec"].state_dict().values()] + ) + model_params = sum( + [ + x.numel() + for module in hparams["modules"].values() + for x in module.state_dict().values() + ] + ) + hparams["train_logger"].log_stats( + stats_meta={ + f"Codec parameters/buffers (M)": f"{codec_params / 1e6:.2f}", + "Model parameters/buffers (M)": f"{model_params / 1e6:.2f}", + }, + ) + + # Trainer initialization + brain = Separation( + modules=hparams["modules"], + opt_class=hparams["opt_class"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + # Train + brain.fit( + brain.hparams.epoch_counter, + train_data, + valid_data, + train_loader_kwargs=hparams["train_dataloader_kwargs"], + valid_loader_kwargs=hparams["valid_dataloader_kwargs"], + ) + + # Test + if hparams["testing"]: + # Testing + brain.hparams.bsseval_file = os.path.join( + hparams["output_folder"], "bsseval.txt" + ) + brain.evaluate( + test_data, test_loader_kwargs=hparams["test_dataloader_kwargs"], + ) diff --git a/benchmarks/DASB/MUSDB/separation/utils.py 
b/benchmarks/DASB/MUSDB/separation/utils.py new file mode 100644 index 000000000..e74ae22cc --- /dev/null +++ b/benchmarks/DASB/MUSDB/separation/utils.py @@ -0,0 +1,338 @@ +"""Common utilities. + +Authors + * Luca Della Libera 2024 +""" + +import os + +import speechbrain as sb +import torch +import torchaudio +from speechbrain.dataio.dataio import merge_csvs +from transformers.models.hubert.modeling_hubert import ( + HubertEncoderStableLayerNorm, +) +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2EncoderStableLayerNorm, +) +from transformers.models.wavlm.modeling_wavlm import WavLMEncoderStableLayerNorm + + +__all__ = ["SBWav2Vec2ForwardWrapper", "dataio_prepare"] + +CHUNK = 5.0 + + +class EncodecHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig.unsqueeze(1), lens) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + sig = self.codec.decode(toks)[:, 0] # [B, T] + return sig + + +class DacHelper: + def __init__(self, codec, device, num_codebooks): + self.codec = codec + self.device = device + self.num_codebooks = num_codebooks + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec( + sig[:, None], n_quantizers=self.num_codebooks + ) # [B, K, N] + toks = toks.movedim(-1, -2) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + qfeats, _, _ = self.codec.quantizer.from_codes( + toks.movedim(-1, -2) + ) # [B, K, N] -> [B, K, N] + sig = self.codec.decode(qfeats)[:, 0] # [B, T] + return sig + + +class SQCodecHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + # sig: [B, T] + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig[:, None]) # [B, K * N] + K = self.codec.n_codebook + N = toks.shape[-1] // K + toks = self._unflatten_codebooks(toks, N, K) # [B, N, K] + return toks + + def _flatten_codebooks(self, arr): + assert ( + len(arr.shape) == 3 + ), "Input array must have 3 dimensions [B, N, K]" + N, B, K = arr.shape + arr = arr.clone() + flattened_arr = arr.permute(1, 2, 0).reshape(B, N * K) + return flattened_arr + + def _unflatten_codebooks(self, flat_arr, N, K): + # flat_arr: [B, N * K] + B = flat_arr.shape[0] + return flat_arr.reshape(B, N, K) + + @torch.no_grad() + def toks_to_sig(self, toks): + toks = toks.permute(2, 0, 1) # [B, N, K] -> [K, B, N] + flat_toks = self._flatten_codebooks(toks).to(torch.int32) + sig = self.codec.decode(flat_toks).squeeze(1) # [B, T] + return sig.to(toks.device) + + +class WavTokenizerHelper: + def __init__(self, codec, device): + self.codec = codec + self.device = device + + @torch.no_grad() + def sig_to_toks(self, sig, lens): + self.codec.to(self.device).eval() + toks, _ = self.codec.encode(sig) # [B, K, N] + toks = toks.permute(0, 2, 1) # [B, N, K] + return toks + + @torch.no_grad() + def toks_to_sig(self, toks): + self.codec.to(self.device).eval() + toks = toks.movedim(-1, -2) # [B, N, K] -> [B, K, N] + sig = self.codec.decode(toks) # [B, T] + return sig.clone() + + +class SBWav2Vec2ForwardWrapper(torch.nn.Module): + """SpeechBrain wav2vec 2.0 wrapper that returns the hidden representations from the specified layer IDs. 
+ + Arguments + --------- + wav2vec2: + The SpeechBrain wav2vec 2.0 module. + layer_ids: + The layer IDs from which the hidden representations are extracted. + + Examples + -------- + >>> import torch + >>> from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + >>> from speechbrain.lobes.models.huggingface_transformers.wavlm import WavLM + >>> + >>> encoder = WavLM(source="microsoft/wavlm-large", save_path=HUGGINGFACE_HUB_CACHE) + >>> encoder = SBWav2Vec2ForwardWrapper(encoder, layer_ids=[6, 7]) + >>> + >>> input = torch.rand([10, 16000]) + >>> length = torch.ones(10) + >>> output = encoder(input, length) + + """ + + def __init__(self, wav2vec2, layer_ids): + super().__init__() + self.wav2vec2 = wav2vec2 + # Workaround to deal with hardcoded class name in discrete SSL + # https://github.com/speechbrain/speechbrain/blob/60062c2536e8122253d6ad0e681208f554528950/speechbrain/lobes/models/huggingface_transformers/discrete_ssl.py#L88 + self.__class__.__name__ = self.wav2vec2.__class__.__name__ + self.layer_ids = sorted(layer_ids) + assert hasattr(self.wav2vec2, "model") + assert hasattr(self.wav2vec2.model, "encoder") + assert hasattr(self.wav2vec2.model.encoder, "layers") + # Workaround for early exiting to avoid the computational overhead of forwarding through the whole model + # NOTE: the model is modified in-place + self.wav2vec2.output_all_hiddens = True + self.wav2vec2.model.encoder.layers = self.wav2vec2.model.encoder.layers[ + : max(self.layer_ids) + ] + # NOTE: workaround to account for layer norm applied to the last hidden states when StableLayerNorm variant is used: + # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/wavlm/modeling_wavlm.py#L816 + if isinstance( + self.wav2vec2.model.encoder, + ( + HubertEncoderStableLayerNorm, + Wav2Vec2EncoderStableLayerNorm, + WavLMEncoderStableLayerNorm, + ), + ): + self.wav2vec2.model.encoder.layer_norm = torch.nn.Identity() + + def extract_features(self, wav, length=None): + feats = self.wav2vec2(wav, length) # (K, B, N, H) + return feats + + def forward(self, wav, length=None): + return self.extract_features(wav, length) + + +def dataio_prepare( + data_folder, + train_csv, + valid_csv, + test_csv, + sample_rate=16000, + train_remove_if_longer=60.0, + valid_remove_if_longer=60.0, + test_remove_if_longer=60.0, + sorting="ascending", + debug=False, + **hparams, +): + """This function prepares the datasets to be used in the brain class. + It also defines the data processing pipeline through user-defined functions. 
+ + """ + if isinstance(train_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in train_csv] + save_folder = os.path.dirname(train_csv[0]) + merge_csvs( + save_folder, csvs, "train.csv", + ) + train_csv = os.path.join(save_folder, "train.csv") + + if isinstance(valid_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in valid_csv] + save_folder = os.path.dirname(valid_csv[0]) + merge_csvs( + save_folder, csvs, "valid.csv", + ) + valid_csv = os.path.join(save_folder, "valid.csv") + + if isinstance(test_csv, (list, tuple)): + csvs = [os.path.basename(x) for x in test_csv] + save_folder = os.path.dirname(test_csv[0]) + merge_csvs( + save_folder, csvs, "test.csv", + ) + test_csv = os.path.join(save_folder, "test.csv") + + train_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=train_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort training data to speed up training + train_data = train_data.filtered_sorted( + sort_key="duration", + reverse=sorting == "descending", + key_max_value={"duration": train_remove_if_longer}, + ) + + valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=valid_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort validation data to speed up validation + valid_data = valid_data.filtered_sorted( + sort_key="duration", + reverse=not debug, + key_max_value={"duration": valid_remove_if_longer}, + ) + + test_data = sb.dataio.dataset.DynamicItemDataset.from_csv( + csv_path=test_csv, replacements={"DATA_ROOT": data_folder}, + ) + # Sort the test data to speed up testing + test_data = test_data.filtered_sorted( + sort_key="duration", + reverse=not debug, + key_max_value={"duration": test_remove_if_longer}, + ) + + datasets = [train_data, valid_data, test_data] + + # Define audio pipeline + takes = ["mixture_wav", "bass_wav", "drums_wav", "other_wav", "vocals_wav"] + provides = ["in_sig", "out_sig"] + + def audio_pipeline(mix_wav, *src_wavs): + # Mixed signal + try: + original_sample_rate = sb.dataio.dataio.read_audio_info( + mix_wav + ).sample_rate + + # Source signals + src_sigs = [] + for src_wav in src_wavs: + assert ( + original_sample_rate + == sb.dataio.dataio.read_audio_info(src_wav).sample_rate + ) + src_sig = sb.dataio.dataio.read_audio( + dict(file=src_wav) + ) # ,start=start, stop=start + int(CHUNK * original_sample_rate))) + src_sigs.append(src_sig) + src_sigs = torch.stack(src_sigs) # [S, T] + + out_sig = torchaudio.functional.resample( + src_sigs, original_sample_rate, sample_rate, + ) + in_sig = out_sig.sum(0) # [T] + except Exception as e: + print(e) + yield in_sig + + # Flatten as SpeechBrain's dataloader does not support multichannel audio + out_sig = out_sig.flatten() # [S * T] + yield out_sig + + sb.dataio.dataset.add_dynamic_item( + [train_data, valid_data, test_data], audio_pipeline, takes, provides + ) + + # Set output + sb.dataio.dataset.set_output_keys(datasets, ["id"] + provides) + + return train_data, valid_data, test_data + + +if __name__ == "__main__": + from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE + from speechbrain.lobes.models.huggingface_transformers.wav2vec2 import ( + Wav2Vec2, + ) + + for source in [ + "facebook/wav2vec2-large-960h-lv60-self", + "facebook/hubert-large-ll60k", + "microsoft/wavlm-large", + ]: + layer_ids = [3, 7] + encoder1 = Wav2Vec2( + source=source, save_path=HUGGINGFACE_HUB_CACHE, output_norm=True, + ) + encoder1 = SBWav2Vec2ForwardWrapper( + encoder1, layer_ids=layer_ids + ).eval() + + encoder2 = Wav2Vec2( + source=source, + 
save_path=HUGGINGFACE_HUB_CACHE, + output_norm=True, + output_all_hiddens=True, + ).eval() + + input = torch.ones([1, 16000]) + with torch.no_grad(): + output1 = encoder1(input)[layer_ids] + output2 = encoder2(input)[layer_ids] + + print((output1 == output2).all()) diff --git a/benchmarks/DASB/extra_requirements.txt b/benchmarks/DASB/extra_requirements.txt index dffb3cd07..09e72e947 100644 --- a/benchmarks/DASB/extra_requirements.txt +++ b/benchmarks/DASB/extra_requirements.txt @@ -1,4 +1,5 @@ beartype +fast_bss_eval==0.1.4 jsonlines kaldiio librosa>=0.9.2 @@ -7,6 +8,7 @@ onnxruntime>=1.16.3 orion orion[profet] scikit-learn +soundfile==0.12.1 speechbrain>=1.0.0 speechtokenizer>=0.1.2 tensorboard diff --git a/benchmarks/DASB/model/custom_model.py b/benchmarks/DASB/model/custom_model.py index 972d35c66..4c67ecbeb 100644 --- a/benchmarks/DASB/model/custom_model.py +++ b/benchmarks/DASB/model/custom_model.py @@ -1,4 +1,12 @@ +import os +import sys import torch +from sq_codec import decimal_to_ternary_matrix + +base_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../model") +) +sys.path.append(base_dir) class AttentionMLP(torch.nn.Module): @@ -60,19 +68,23 @@ def __init__( init=False, freeze=False, hidden_dim=None, + scalar=False, ): super(Discrete_EmbeddingLayer, self).__init__() + self.scalar = scalar self.vocab_size = vocab_size + self.emb_dim = emb_dim self.num_codebooks = ( len(num_codebooks) if isinstance(num_codebooks, list) else num_codebooks ) self.freeze = freeze - self.embedding = torch.nn.Embedding( - self.num_codebooks * vocab_size, emb_dim - ).requires_grad_(not self.freeze) - self.init = init + if not self.scalar: + self.embedding = torch.nn.Embedding( + self.num_codebooks * vocab_size, emb_dim + ).requires_grad_(not self.freeze) + self.init = init # Add a linear layer to match dimensions if necessary if hidden_dim is not None and hidden_dim != emb_dim: @@ -96,16 +108,35 @@ def forward(self, in_tokens): ------- in_embs : torch.Tensor """ - with torch.set_grad_enabled(not self.freeze): - # Add unique token IDs across diffrent codebooks by adding num_codebooks * vocab_size - in_tokens += torch.arange( - 0, - self.num_codebooks * self.vocab_size, - self.vocab_size, - device=in_tokens.device, - ) - # Forward Pass to embedding and - in_embs = self.embedding(in_tokens) - if self.proj_layer is not None: - in_embs = self.proj_layer(in_embs) - return in_embs + + if self.scalar: + with torch.no_grad(): + in_tokens = in_tokens.permute(2, 0, 1) + in_embs = [] + for i in range(self.num_codebooks): + tmp_list = ( + decimal_to_ternary_matrix( + in_tokens[i, :, :], D=self.emb_dim + ) + - 1 + ) + in_embs.append(tmp_list) + in_embs = ( + torch.stack(in_embs, dim=0).float().to(in_tokens.device) + ) # Shape: (num_codebooks, B, D, T) + # Permute to match (B, T, num_codebook, D) + in_embs = in_embs.permute(1, 3, 0, 2) # Shape: (3, 150, 4, 9) + else: + with torch.set_grad_enabled(not self.freeze): + # Add unique token IDs across diffrent codebooks by adding num_codebooks * vocab_size + in_tokens += torch.arange( + 0, + self.num_codebooks * self.vocab_size, + self.vocab_size, + device=in_tokens.device, + ) + # Forward Pass to embedding and + in_embs = self.embedding(in_tokens) + if self.proj_layer is not None: + in_embs = self.proj_layer(in_embs) + return in_embs diff --git a/benchmarks/DASB/model/sq_codec.py b/benchmarks/DASB/model/sq_codec.py index 0e1ffe3f8..ef1283b05 100644 --- a/benchmarks/DASB/model/sq_codec.py +++ b/benchmarks/DASB/model/sq_codec.py @@ -124,7 +124,9 @@ def 
build_codec_model(self, config):
         exp_model_config = OmegaConf.load(config)
         scalar_codec = ScalarModel(**exp_model_config.generator.config)
         device = next(iter(scalar_codec.parameters())).device
-        parameter_dict = torch.load(self.ckpt_path, map_location=device)
+        parameter_dict = torch.load(
+            self.ckpt_path, map_location=device, weights_only=False
+        )
         scalar_codec.load_state_dict(parameter_dict["codec_model"])
         return scalar_codec
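
A note on the `weights_only=False` change above: starting with PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which restricts unpickling to tensors and a small allowlist of types. The SQ-Codec checkpoint is a pickled dictionary read through its `"codec_model"` key and may carry additional Python objects, so the explicit flag restores the previous full-unpickling behaviour. The snippet below is a minimal, self-contained sketch of the same pattern; the checkpoint file and the `nn.Linear` module are stand-ins for illustration, not part of the recipe:

```python
import torch
import torch.nn as nn

# Stand-in for a codec checkpoint: a plain pickled dict whose "codec_model"
# entry holds the state dict (mirrors how build_codec_model reads it above).
ckpt = {"codec_model": nn.Linear(8, 8).state_dict(), "step": 100_000}
torch.save(ckpt, "demo_ckpt.pt")

# PyTorch >= 2.6 defaults to weights_only=True; passing weights_only=False
# (as in the patch) keeps full unpickling for checkpoints that store extra
# Python objects alongside the raw weights.
parameter_dict = torch.load("demo_ckpt.pt", map_location="cpu", weights_only=False)

model = nn.Linear(8, 8)
model.load_state_dict(parameter_dict["codec_model"])
```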