diff --git a/README.md b/README.md
index 4241ae8..a8ad814 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,74 @@ This should produce a `sample.wav` file in your project root directory.
 
 _For repeated sampling we highly recommend using the gradio interface instead, as the minimal example needs to load the model every time it is run._
 
+
+## Fine-Tuning / Training
+
+We provide a script, `training.py`, that demonstrates how to fine-tune or adapt Zonos on your own data or on a public dataset such as Mozilla Common Voice.
+
+### Requirements
+
+- A GPU with sufficient VRAM (6GB+ recommended). CPU training is possible but extremely slow.
+- PyTorch and torchaudio
+- The Hugging Face `datasets` library
+- Enough disk space to download your chosen dataset (e.g., Common Voice can be quite large depending on the language).
+- If you plan to train the “hybrid” version, you must install the CUDA-specific requirements (see Installation).
+
+### Usage
+
+#### 1. Clone or download the Zonos repository:
+```bash
+git clone https://github.com/Zyphra/Zonos.git
+cd Zonos
+```
+#### 2. Install dependencies (e.g., in a virtual environment):
+```bash
+uv sync
+uv sync --extra compile  # optional; only required for the hybrid model
+uv pip install -e .
+```
+#### 3. Edit training parameters if needed. By default, `training.py` uses:
+
+- **Model:** Zyphra/Zonos-v0.1-transformer
+- **Dataset:** mozilla-foundation/common_voice_17_0
+- **Language:** uz
+- `num_epochs=10`, `batch_size=4`, `learning_rate=1e-4`
+- **Output directory:** `checkpoints`
+
+#### 4. Run the training script:
+```bash
+uv run training.py
+```
+
+#### The script will:
+
+- Load the specified dataset and the Zonos model
+- Resample audio to 16 kHz (and to 44.1 kHz for the DAC autoencoder when computing target codes)
+- Compute a speaker embedding for each sample
+- Prepare text/language conditioning
+- Forward through Zonos to predict audio codes
+- Compute a cross-entropy loss over the predicted codes
+- Save a checkpoint (model, optimizer, and scheduler state) to the `checkpoints` folder after each epoch
+
+#### 5. Customizing training:
+
+The script does not expose command-line flags; the options below are parameters of `train_zonos()` and are set in the call at the bottom of `training.py` (see the example below).
+
+- **Change model:** Set `model_path` to another pretrained checkpoint (e.g., the hybrid).
+- **Change dataset:** Set `dataset_name` to another Hugging Face dataset (or a custom dataset in the same format).
+- **Modify hyperparameters:** e.g., `learning_rate=1e-5`, `num_epochs=5`, etc.
+- **Output directory:** Set `output_dir` to choose a different folder for checkpoints.
+- **Emotion labels:** Set `emotion_labels_path` to provide per-sample emotion vectors (see below).
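+
+#### Optional: emotion labels
+
+`train_zonos()` also accepts an `emotion_labels_path` pointing to a JSON file (see `emotion_labels.json`) that maps dataset sample indices to 8-dimensional emotion vectors used for conditioning; samples without an entry fall back to a neutral default. As a minimal sketch (the sample count of 1000 is only an illustration, use your dataset's size), such a file could be generated with:
+```python
+import json
+
+NEUTRAL = [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077]  # default used by training.py
+labels = {str(i): NEUTRAL for i in range(1000)}  # hypothetical sample count
+with open("emotion_labels.json", "w") as f:
+    json.dump(labels, f, indent=2)
+```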
+
+#### Example: fine-tuning the hybrid model
+
+Edit the `train_zonos(...)` call at the bottom of `training.py`, for example:
+```python
+train_zonos(
+    model_path="Zyphra/Zonos-v0.1-hybrid",
+    dataset_name="mozilla-foundation/common_voice_17_0",
+    language="uz",
+    output_dir="checkpoints",
+    num_epochs=10,
+    batch_size=8,
+    learning_rate=1e-4,
+)
+```
+Then run the script with `uv run training.py`.
+
 ## Features
 
 - Zero-shot TTS with voice cloning: Input desired text and a 10-30s speaker sample to generate high quality TTS output
diff --git a/emotion_labels.json b/emotion_labels.json
new file mode 100644
index 0000000..557f49d
--- /dev/null
+++ b/emotion_labels.json
@@ -0,0 +1,13 @@
+{
+    "0": [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077],
+    "1": [0.1000, 0.5000, 0.0500, 0.0500, 0.0500, 0.0500, 0.1000, 0.1000],
+    "2": [0.4000, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.1500, 0.2000],
+    "3": [0.1000, 0.0500, 0.5000, 0.0500, 0.0500, 0.0500, 0.1000, 0.1000],
+    "4": [0.1000, 0.4000, 0.1000, 0.1000, 0.0500, 0.0500, 0.1000, 0.1000],
+    "5": [0.3500, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.1500, 0.2500],
+    "6": [0.1000, 0.0500, 0.0500, 0.4000, 0.1000, 0.1000, 0.1000, 0.1000],
+    "7": [0.1000, 0.0500, 0.0500, 0.0500, 0.0500, 0.0500, 0.5000, 0.1500],
+    "8": [0.3000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000],
+    "9": [0.1000, 0.4500, 0.0500, 0.0500, 0.0500, 0.0500, 0.1000, 0.1500],
+    "10": [0.1500, 0.0500, 0.4000, 0.0500, 0.1000, 0.0500, 0.1000, 0.1000]
+}
\ No newline at end of file
diff --git a/training.py b/training.py
new file mode 100644
index 0000000..30f53fc
--- /dev/null
+++ b/training.py
@@ -0,0 +1,334 @@
+import torch
+import torchaudio
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import AdamW
+from zonos.model import Zonos
+from zonos.conditioning import make_cond_dict
+import os
+from tqdm import tqdm
+from datasets import load_dataset
+import torch.nn as nn
+import json
+import numpy as np
+from typing import Dict, List, Optional, Tuple
+
+class EmotionEncoder:
+    """Encodes emotions into embeddings using a predefined emotion set."""
+    def __init__(self):
+        self.emotions = {
+            "neutral": 0,
+            "happy": 1,
+            "sad": 2,
+            "angry": 3,
+            "fearful": 4,
+            "disgust": 5,
+            "surprised": 6
+        }
+        self.emotion_embeddings = torch.nn.Embedding(len(self.emotions), 256)  # 256-dim emotion embedding
+
+    def encode(self, emotion: str) -> torch.Tensor:
+        if emotion not in self.emotions:
+            emotion = "neutral"  # default to neutral if emotion not found
+        emotion_idx = torch.tensor([self.emotions[emotion]])
+        return self.emotion_embeddings(emotion_idx)
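+
+# NOTE: EmotionEncoder above is not currently wired into train_zonos(); the training loop
+# passes the raw 8-dimensional emotion vectors from emotion_labels.json (or the neutral
+# default) directly to make_cond_dict() rather than learned emotion embeddings.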
+
+class AudioPreprocessor:
+    """Handles audio preprocessing including normalization and augmentation."""
+    def __init__(self, sample_rate: int = 16000, max_length_seconds: float = 30.0):
+        self.sample_rate = sample_rate
+        self.max_length = int(max_length_seconds * sample_rate)
+
+        # Mel spectrogram parameters
+        self.n_fft = 1024
+        self.win_length = int(0.025 * sample_rate)  # 25 ms window
+        self.hop_length = int(0.01 * sample_rate)   # 10 ms hop
+        self.n_mels = 80
+
+        # Calculate maximum mel length to ensure consistent sizes
+        self.max_mel_length = (self.max_length - self.n_fft) // self.hop_length + 3
+
+        self.mel_transform = torchaudio.transforms.MelSpectrogram(
+            sample_rate=sample_rate,
+            n_fft=self.n_fft,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+            n_mels=self.n_mels,
+            f_min=0,
+            f_max=8000,
+            window_fn=torch.hann_window,
+            normalized=True
+        )
+
+    def process(self, wav: torch.Tensor, augment: bool = True) -> torch.Tensor:
+        """Process audio and return processed audio."""
+        # Ensure audio is 2D (channels, time)
+        if wav.dim() == 1:
+            wav = wav.unsqueeze(0)
+
+        # Convert to mono if stereo
+        if wav.size(0) > 1:
+            wav = wav.mean(0, keepdim=True)
+
+        # Ensure consistent length
+        if wav.size(-1) > self.max_length:
+            start = 0  # Always take from start for consistency
+            wav = wav[..., start:start + self.max_length]
+        else:
+            # Pad with zeros if too short
+            pad_length = self.max_length - wav.size(-1)
+            wav = torch.nn.functional.pad(wav, (0, pad_length))
+
+        # Normalize audio to [-1, 1]
+        wav = wav / (torch.max(torch.abs(wav)) + 1e-8)
+
+        if augment:
+            # Random volume adjustment
+            wav = wav * (0.8 + 0.4 * torch.rand(1))
+
+            # Random additive noise at an amplitude of -20 to -30 dBFS
+            if torch.rand(1) < 0.5:
+                noise_level = 10 ** ((-torch.rand(1) * 10 - 20) / 20)
+                noise = torch.randn_like(wav) * noise_level
+                wav = wav + noise
+
+        return wav
+
+class ZonosHFDataset(Dataset):
+    def __init__(
+        self,
+        dataset_name="mozilla-foundation/common_voice_17_0",
+        language="uz",
+        split="train",
+        sampling_rate=16000,
+        emotion_labels_path: Optional[str] = None
+    ):
+        self.sampling_rate = sampling_rate
+        self.language = language
+        self.audio_processor = AudioPreprocessor(sampling_rate)
+
+        print(f"Loading {dataset_name} dataset...")
+        self.dataset = load_dataset(dataset_name, language, split=split)
+        print(f"Dataset loaded with {len(self.dataset)} samples")
+
+        # Load emotion labels if provided
+        self.emotion_labels = {}
+        if emotion_labels_path and os.path.exists(emotion_labels_path):
+            with open(emotion_labels_path, 'r') as f:
+                self.emotion_labels = json.load(f)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def get_emotion(self, idx: int) -> List[float]:
+        """Get emotion vector for a sample, defaulting to neutral if not found."""
+        if str(idx) in self.emotion_labels:
+            return self.emotion_labels[str(idx)]
+        return [0.3077, 0.0256, 0.0256, 0.0256, 0.0256, 0.0256, 0.2564, 0.3077]  # default neutral
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+
+        # Load and process audio
+        audio = item['audio']
+        wav = torch.FloatTensor(audio['array'])
+        sr = audio['sampling_rate']
+
+        # Ensure 2D (channels, time)
+        if wav.dim() == 1:
+            wav = wav.unsqueeze(0)
+
+        # Convert to mono if stereo
+        if wav.size(0) > 1:
+            wav = wav.mean(0, keepdim=True)
+
+        # Resample if necessary
+        if sr != self.sampling_rate:
+            wav = torchaudio.functional.resample(wav, sr, self.sampling_rate)
+
+        # Ensure fixed length (30 seconds)
+        target_length = 30 * self.sampling_rate
+        if wav.size(-1) > target_length:
+            wav = wav[..., :target_length]
+        else:
+            # Pad with zeros if too short
+            pad_length = target_length - wav.size(-1)
+            wav = torch.nn.functional.pad(wav, (0, pad_length))
+
+        # Process audio (normalization and augmentation)
+        wav = self.audio_processor.process(wav)
+
+        # Get emotion vector
+        emotion = torch.tensor(self.get_emotion(idx), dtype=torch.float32)
+
+        return {
+            'audio': wav.squeeze(0),  # Remove channel dimension for speaker model
+            'text': item['sentence'],
+            'language': self.language,
+            'emotion': emotion
+        }
+
+def train_zonos(
+    model_path="Zyphra/Zonos-v0.1-transformer",
+    dataset_name="mozilla-foundation/common_voice_17_0",
+    language="uz",
+    output_dir="checkpoints",
+    batch_size=4,  # Reduced batch size
+    learning_rate=1e-4,
+    num_epochs=10,
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    emotion_labels_path: Optional[str] = None
+):
+    # Make sure the checkpoint directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Initialize model
+    model = Zonos.from_pretrained(model_path, device=device)
+    model.train()
+
+    # Create dataset and dataloader
+    dataset = ZonosHFDataset(
+        dataset_name,
+        language,
+        emotion_labels_path=emotion_labels_path
+    )
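+
+    # The collate_fn below stacks the fixed-length audio and emotion tensors into batch
+    # tensors and keeps the per-sample text and language strings as plain Python lists.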
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=2,  # Reduced number of workers
+        collate_fn=lambda x: {
+            'audio': torch.stack([s['audio'] for s in x]),  # Audio is already properly shaped
+            'text': [s['text'] for s in x],
+            'language': [s['language'] for s in x],
+            'emotion': torch.stack([s['emotion'] for s in x])
+        }
+    )
+
+    # Initialize optimizer with warmup
+    optimizer = AdamW(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=learning_rate,
+        epochs=num_epochs,
+        steps_per_epoch=len(dataloader),
+        pct_start=0.1
+    )
+
+    # Training loop
+    for epoch in range(num_epochs):
+        print(f"Epoch {epoch+1}/{num_epochs}")
+        total_loss = 0
+
+        for batch in tqdm(dataloader):
+            optimizer.zero_grad()
+
+            # Move tensors to device
+            audio = batch['audio'].to(device)  # Shape: [batch_size, 480000]
+            emotions = batch['emotion'].to(device)
+
+            # Process each sample in the batch
+            batch_loss = 0
+            for i in range(len(audio)):
+                # Create speaker embedding from the audio
+                speaker = model.make_speaker_embedding(audio[i], dataset.sampling_rate)  # Shape: [1, 256]
+
+                # Prepare conditioning with emotion vector
+                cond_dict = make_cond_dict(
+                    text=batch['text'][i],
+                    speaker=speaker,
+                    language=batch['language'][i],
+                    emotion=emotions[i].tolist()
+                )
+                conditioning = model.prepare_conditioning(cond_dict)
+
+                # Get target codes using the autoencoder
+                # Ensure audio is properly shaped for the autoencoder
+                audio_input = audio[i].unsqueeze(0)  # Add batch dimension: [1, 480000]
+                if audio_input.dim() == 1:
+                    audio_input = audio_input.unsqueeze(0)  # Add batch dimension if needed
+                if audio_input.dim() == 2:
+                    audio_input = audio_input.unsqueeze(1)  # Add channel dimension if needed
+
+                # Resample to 44.1kHz for DAC
+                if dataset.sampling_rate != 44100:
+                    audio_input = torchaudio.functional.resample(
+                        audio_input,
+                        dataset.sampling_rate,
+                        44100
+                    )
+
+                # Ensure proper padding for DAC
+                target_length = int(44100 * 30)  # 30 seconds at 44.1kHz
+                if audio_input.size(-1) < target_length:
+                    pad_length = target_length - audio_input.size(-1)
+                    audio_input = torch.nn.functional.pad(audio_input, (0, pad_length))
+                elif audio_input.size(-1) > target_length:
+                    audio_input = audio_input[..., :target_length]
+
+                with torch.no_grad():
+                    target_codes = model.autoencoder.encode(audio_input)  # Shape: [1, num_codebooks, seq_len]
+                    target_codes = target_codes.to(device)
+
+                # Forward pass
+                output = model(conditioning)  # Shape: [batch_size, num_codebooks, seq_len, vocab_size]
+
+                # Calculate token prediction loss
+                # Ensure output and target_codes have compatible shapes
+                target_codes = target_codes.squeeze(0)  # Remove batch dimension from target_codes
+
+                # Get the minimum sequence length between output and target
+                seq_len = min(output.size(2), target_codes.size(-1))
+
+                # Truncate both tensors to the same sequence length
+                output = output[..., :seq_len, :]  # [batch_size, num_codebooks, seq_len, vocab_size]
+                target_codes = target_codes[..., :seq_len]  # [num_codebooks, seq_len]
+
+                # Reshape for cross entropy - fixed dimensions for 4D tensor
+                output = output.squeeze(0)  # Remove batch dimension: [num_codebooks, seq_len, vocab_size]
+                output = output.permute(1, 0, 2)  # [seq_len, num_codebooks, vocab_size]
+                output = output.contiguous().view(-1, output.size(-1))  # [seq_len * num_codebooks, vocab_size]
+                target_codes = target_codes.transpose(0, 1)  # [seq_len, num_codebooks]
+                target_codes = target_codes.reshape(-1)  # [seq_len * num_codebooks]
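+
+                # Flattening both tensors this way treats every (time step, codebook) position
+                # as an independent classification target for the cross-entropy loss below.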
+
+                token_loss = torch.nn.functional.cross_entropy(output, target_codes)
+
+                batch_loss += token_loss
+
+            # Average loss over the samples actually in this batch
+            batch_loss = batch_loss / len(audio)
+
+            # Backward pass
+            batch_loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            scheduler.step()
+
+            total_loss += batch_loss.item()
+
+        # End of epoch
+        avg_loss = total_loss / len(dataloader)
+        print(f"Average loss: {avg_loss:.4f}")
+
+        # Save checkpoint
+        checkpoint_path = os.path.join(output_dir, f"zonos_checkpoint_epoch_{epoch+1}.pt")
+        torch.save({
+            'epoch': epoch,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'scheduler_state_dict': scheduler.state_dict(),
+            'loss': avg_loss,
+        }, checkpoint_path)
+        print(f"Saved checkpoint to {checkpoint_path}")
+
+if __name__ == "__main__":
+    # Create output directory if it doesn't exist
+    os.makedirs("checkpoints", exist_ok=True)
+
+    # Start training
+    train_zonos(
+        model_path="Zyphra/Zonos-v0.1-transformer",
+        dataset_name="mozilla-foundation/common_voice_17_0",
+        language="uz",
+        output_dir="checkpoints",
+        batch_size=4,
+        num_epochs=10,
+        emotion_labels_path="emotion_labels.json"  # Optional path to emotion labels
+    )
\ No newline at end of file
diff --git a/zonos/backbone/_torch.py b/zonos/backbone/_torch.py
index 1b4287b..992b107 100644
--- a/zonos/backbone/_torch.py
+++ b/zonos/backbone/_torch.py
@@ -6,9 +6,11 @@ from zonos.config import BackboneConfig, InferenceParams
 
 
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> torch.Tensor:
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000, device: torch.device = None) -> torch.Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
-    t = torch.arange(seq_len, device=freqs.device)
+    if device is not None:
+        freqs = freqs.to(device)
+    t = torch.arange(seq_len, device=device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
@@ -36,6 +38,12 @@ def _update_kv_cache(
     """k/v: (batch_size, seqlen, nheads, head_dim) or (batch_size, 1, nheads, head_dim)"""
     assert layer_idx in inference_params.key_value_memory_dict
     kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
+
+    # Ensure kv_cache is on the same device as k and v
+    if kv_cache.device != k.device:
+        kv_cache = kv_cache.to(k.device)
+        inference_params.key_value_memory_dict[layer_idx] = (kv_cache, None)
+
     # Adjust key and value for inference
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + k.shape[0]
@@ -64,17 +72,26 @@ def __init__(self, config: BackboneConfig):
     def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
         # TODO: This function should be pure
         head_dim = self.config.d_model // self.config.attn_cfg["num_heads"]
-        self.freqs_cis = precompute_freqs_cis(16384, head_dim)
+        device = next(self.parameters()).device
+        self.freqs_cis = precompute_freqs_cis(16384, head_dim, device=device)
         return {
             i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype) for i, layer in enumerate(self.layers)
         }
 
     def forward(self, hidden_states: torch.Tensor, inference_params: InferenceParams) -> torch.Tensor:
-        input_pos = torch.arange(0, hidden_states.shape[1], device=hidden_states.device)
+        device = hidden_states.device
+        # Move all tensors to the same device at the start
+        if inference_params.lengths_per_sample.device != device:
+            inference_params.lengths_per_sample = inference_params.lengths_per_sample.to(device)
+        if self.freqs_cis.device != device:
+            self.freqs_cis = self.freqs_cis.to(device)
+
+        input_pos = torch.arange(0, hidden_states.shape[1], device=device)
         input_pos = input_pos + inference_params.lengths_per_sample.unsqueeze(-1)
         freqs_cis = self.freqs_cis[input_pos].expand(hidden_states.shape[0], -1, -1, -1)
+
         for i, layer in enumerate(self.layers):
             hidden_states = layer(hidden_states, inference_params, freqs_cis)
         return self.norm_f(hidden_states)
@@ -94,7 +111,8 @@ def __init__(self, config: BackboneConfig, layer_idx: int) -> None:
         self.head_dim = config.d_model // config.attn_cfg["num_heads"]
 
     def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16):
-        return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype), None
+        device = next(self.parameters()).device
+        return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device), None
 
     def forward(self, x: torch.Tensor, inference_params: InferenceParams, freqs_cis: torch.Tensor) -> torch.Tensor:
         x = x + self.mixer(self.norm(x), inference_params, freqs_cis)
diff --git a/zonos/conditioning.py b/zonos/conditioning.py
index 016cb53..85e6867 100644
--- a/zonos/conditioning.py
+++ b/zonos/conditioning.py
@@ -142,7 +142,7 @@ def normalize_numbers(text: str) -> str:
 PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3
 SPECIAL_TOKEN_IDS = [PAD_ID, UNK_ID, BOS_ID, EOS_ID]
 
-_punctuation = ';:,.!?¡¿—…"«»“”() *~-/\\&'
+_punctuation = ';:,.!?¡¿—…"«»""() *~-/\\&'
 _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
 _letters_ipa = (
     "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
@@ -188,31 +188,59 @@ def clean(texts: list[str], languages: list[str]) -> list[str]:
 
 @cache
 def get_backend(language: str) -> "EspeakBackend":
+    """Get phonemizer backend for a given language."""
     import logging
 
     from phonemizer.backend import EspeakBackend
 
     logger = logging.getLogger("phonemizer")
-    backend = EspeakBackend(
-        language,
-        preserve_punctuation=True,
-        with_stress=True,
-        punctuation_marks=_punctuation,
-        logger=logger,
-    )
-    logger.setLevel(logging.ERROR)
-    return backend
+
+    # Map language codes to their espeak variants
+    language_mapping = {
+        "uz": "uzb",  # Uzbek needs to use 'uzb' for espeak
+    }
+
+    # Use mapped language code if available
+    espeak_language = language_mapping.get(language, language)
+
+    try:
+        backend = EspeakBackend(
+            espeak_language,
+            preserve_punctuation=True,
+            with_stress=True,
+            punctuation_marks=_punctuation,
+            logger=logger,
+        )
+        logger.setLevel(logging.ERROR)
+        return backend
+    except RuntimeError as e:
+        print(f"Warning: Language {language} (espeak: {espeak_language}) failed, falling back to en-us")
+        backend = EspeakBackend(
+            "en-us",
+            preserve_punctuation=True,
+            with_stress=True,
+            punctuation_marks=_punctuation,
+            logger=logger,
+        )
+        logger.setLevel(logging.ERROR)
+        return backend
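+
+# NOTE: Additional language-code aliases can be added to `language_mapping` in get_backend()
+# whenever a dataset's language code differs from the espeak-ng voice name (as with "uz" -> "uzb").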
 
 def phonemize(texts: list[str], languages: list[str]) -> list[str]:
+    """Phonemize text using the appropriate backend for each language."""
     texts = clean(texts, languages)
-
+
     batch_phonemes = []
     for text, language in zip(texts, languages):
         backend = get_backend(language)
-        phonemes = backend.phonemize([text], strip=True)
-        batch_phonemes.append(phonemes[0])
-
+        try:
+            phonemes = backend.phonemize([text], strip=True)
+            batch_phonemes.append(phonemes[0])
+        except Exception:
+            # If phonemization fails, fall back to English
+            fallback_backend = get_backend("en-us")
+            phonemes = fallback_backend.phonemize([text], strip=True)
+            batch_phonemes.append(phonemes[0])
+
     return batch_phonemes
diff --git a/zonos/model.py b/zonos/model.py
index ccb713b..1660613 100644
--- a/zonos/model.py
+++ b/zonos/model.py
@@ -198,7 +198,7 @@ def _prefill(
     def setup_cache(self, batch_size: int, max_seqlen: int, dtype: torch.dtype = torch.bfloat16) -> InferenceParams:
         max_seqlen = find_multiple(max_seqlen, 8)
         key_value_memory_dict = self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
-        lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32)
+        lengths_per_sample = torch.full((batch_size,), 0, dtype=torch.int32, device=self.device)
         return InferenceParams(max_seqlen, batch_size, 0, 0, key_value_memory_dict, lengths_per_sample)
 
     def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None) -> torch.Tensor:
@@ -211,6 +211,27 @@ def prepare_conditioning(self, cond_dict: dict, uncond_dict: dict | None = None
             ]
         )
 
+    def forward(self, prefix_conditioning: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the model.
+
+        Args:
+            prefix_conditioning: Tensor of shape [batch_size, seq_len, d_model] containing the conditioning information
+
+        Returns:
+            Tensor of logits for next-token prediction
+        """
+        # Set up inference parameters for the current batch
+        batch_size = prefix_conditioning.shape[0] // 2  # Divide by 2 because of CFG
+        inference_params = self.setup_cache(batch_size=batch_size * 2, max_seqlen=prefix_conditioning.shape[1])
+
+        # Pass through backbone
+        hidden_states = self.backbone(prefix_conditioning, inference_params)
+
+        # Get logits from the final hidden states
+        logits = self.apply_heads(hidden_states)
+
+        return logits
+
     def can_use_cudagraphs(self) -> bool:
         # Only the mamba-ssm backbone supports CUDA Graphs at the moment
         return self.device.type == "cuda" and "_mamba_ssm" in str(self.backbone.__class__)
diff --git a/zonos/speaker_cloning.py b/zonos/speaker_cloning.py
index f76aa6d..9ce33f8 100644
--- a/zonos/speaker_cloning.py
+++ b/zonos/speaker_cloning.py
@@ -29,9 +29,19 @@ def __init__(
         )
 
     def forward(self, x):
+        # Ensure input is 2D (batch, time)
+        if x.dim() == 1:
+            x = x.unsqueeze(0)
+
+        # Calculate mel spectrogram
         out = self.fbankCal(x)
+
+        # Apply log scale with a small offset for numerical stability
         out = torch.log(out + 1e-6)
-        out = out - out.mean(axis=2).unsqueeze(dim=2)
+
+        # Normalize along the time dimension (last dimension)
+        out = out - out.mean(dim=-1, keepdim=True)
+
         return out