diff --git a/notebooks/automatic_model_training.ipynb b/notebooks/automatic_model_training.ipynb index 83309f2..0df4361 100644 --- a/notebooks/automatic_model_training.ipynb +++ b/notebooks/automatic_model_training.ipynb @@ -2,503 +2,463 @@ "cells": [ { "cell_type": "markdown", - "id": "c1eab0b3", - "metadata": { - "id": "c1eab0b3" - }, + "metadata": {}, "source": [ - "# Introduction" - ] - }, - { - "cell_type": "markdown", - "id": "882058c5", - "metadata": { - "id": "882058c5" - }, - "source": [ - "This notebook demonstrates how to train custom openWakeWord models using pre-defined datasets and an automated process for dataset generation and training. While not guaranteed to always produce the best performing model, the methods shown in this notebook often produce baseline models with releatively strong performance.\n", + "# Hey Snowy — Entraînement Wake Word OpenWakeWord\n", "\n", - "Manual data preparation and model training (e.g., see the [training models](training_models.ipynb) notebook) remains an option for when full control over the model development process is needed.\n", + "Notebook one-shot corrigé pour fonctionner sur Google Colab (avril 2026).\n", "\n", - "At a high level, the automatic training process takes advantages of several techniques to try and produce a good model, including:\n", + "**Corrections intégrées par rapport au notebook officiel :**\n", + "- `torchaudio.set_audio_backend` deprecated → patché\n", + "- `torchaudio.info` supprimé → patché avec soundfile\n", + "- `generate_samples.py` manquant → créé avec resampling 16kHz auto\n", + "- `piper-sample-generator` installé depuis le bon repo (rhasspy)\n", + "- AudioSet (404) → remplacé par FMA dataset\n", + "- `onnxscript` manquant → installé explicitement\n", + "- Conversion tflite ignorée (non nécessaire pour Android)\n", "\n", - "- Early-stopping and checkpoint averaging (similar to [stochastic weight averaging](https://arxiv.org/abs/1803.05407)) to search for the best models found during training, according to the validation data\n", - "- Variable learning rates with cosine decay and multiple cycles\n", - "- Adaptive batch construction to focus on only high-loss examples when the model begins to converge, combined with gradient accumulation to ensure that batch sizes are still large enough for stable training\n", - "- Cycical weight schedules for negative examples to help the model reduce false-positive rates\n", + "⚠️ **Prérequis : activer GPU T4** → Exécution → Modifier le type d'exécution → T4 GPU\n", "\n", - "See the contents of the `train.py` file for more details." - ] - }, - { - "cell_type": "markdown", - "id": "e08d031b", - "metadata": { - "id": "e08d031b" - }, - "source": [ - "# Environment Setup" - ] - }, - { - "cell_type": "markdown", - "id": "aee78c37", - "metadata": { - "id": "aee78c37" - }, - "source": [ - "To begin, we'll need to install the requirements for training custom models. In particular, a relatively recent version of Pytorch and custom fork of the [piper-sample-generator](https://github.com/dscripka/piper-sample-generator) library for generating synthetic examples for the custom model.\n", - "\n", - "**Important Note!** Currently, automated model training is only supported on linux systems due to the requirements of the text to speech library used for synthetic sample generation (Piper). It may be possible to use Piper on Windows/Mac systems, but that has not (yet) been tested." + "⏱️ **Durée totale : ~1h30**" ] }, { "cell_type": "code", "execution_count": null, - "id": "4b1227eb", - "metadata": { - "id": "4b1227eb" - }, + "metadata": {}, "outputs": [], "source": [ - "## Environment setup\n", - "\n", - "# install piper-sample-generator (currently only supports linux systems)\n", - "!git clone https://github.com/rhasspy/piper-sample-generator\n", - "!wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", - "!pip install piper-phonemize\n", - "!pip install webrtcvad\n", - "\n", - "# install openwakeword (full installation to support training)\n", - "!git clone https://github.com/dscripka/openwakeword\n", - "!pip install -e ./openwakeword\n", - "!cd openwakeword\n", - "\n", - "# install other dependencies\n", - "!pip install mutagen==1.47.0\n", - "!pip install torchinfo==1.8.0\n", - "!pip install torchmetrics==1.2.0\n", - "!pip install speechbrain==0.5.14\n", - "!pip install audiomentations==0.33.0\n", - "!pip install torch-audiomentations==0.11.0\n", - "!pip install acoustics==0.2.6\n", - "!pip install tensorflow-cpu==2.8.1\n", - "!pip install tensorflow_probability==0.16.0\n", - "!pip install onnx_tf==1.10.0\n", - "!pip install pronouncing==0.2.0\n", - "!pip install datasets==2.14.6\n", - "!pip install deep-phonemizer==0.0.19\n", - "\n", - "# Download required models (workaround for Colab)\n", - "import os\n", - "os.makedirs(\"./openwakeword/openwakeword/resources/models\")\n", - "!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.onnx -O ./openwakeword/openwakeword/resources/models/embedding_model.onnx\n", - "!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite -O ./openwakeword/openwakeword/resources/models/embedding_model.tflite\n", - "!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.onnx -O ./openwakeword/openwakeword/resources/models/melspectrogram.onnx\n", - "!wget https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.tflite -O ./openwakeword/openwakeword/resources/models/melspectrogram.tflite\n" + "# ============================================================\n", + "# CELLULE 0 — Configuration (MODIFIER ICI)\n", + "# ============================================================\n", + "\n", + "TARGET_PHRASE = 'hey snowy' # <-- MODIFIE ICI ('hey snowy' ou 'bye bye snowy')\n", + "N_SAMPLES = 5000 # Nombre de clips d'entraînement\n", + "N_SAMPLES_VAL = 1000 # Nombre de clips de validation\n", + "STEPS = 20000 # Nombre de steps d'entraînement\n", + "\n", + "print(f'Wake word cible : \"{TARGET_PHRASE}\"')\n", + "print(f'Samples : {N_SAMPLES} train / {N_SAMPLES_VAL} val')\n", + "print(f'Steps : {STEPS}')" ] }, { "cell_type": "code", "execution_count": null, - "id": "d4c1056e", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T13:42:01.183840Z", - "start_time": "2023-09-04T13:41:59.752153Z" - }, - "id": "d4c1056e" - }, + "metadata": {}, "outputs": [], "source": [ - "# Imports\n", - "\n", + "# ============================================================\n", + "# CELLULE 1 — Installation des dépendances (~5 min)\n", + "# ============================================================\n", "import os\n", - "import numpy as np\n", - "import torch\n", - "import sys\n", - "from pathlib import Path\n", - "import uuid\n", - "import yaml\n", - "import datasets\n", - "import scipy\n", - "from tqdm import tqdm\n" - ] - }, - { - "cell_type": "markdown", - "id": "e9d7a05a", - "metadata": { - "id": "e9d7a05a" - }, - "source": [ - "# Download Data" - ] - }, - { - "cell_type": "markdown", - "id": "c52f75cc", - "metadata": { - "id": "c52f75cc" - }, - "source": [ - "When training new openWakeWord models using the automated procedure, four specific types of data are required:\n", - "\n", - "1) Synthetic examples of the target word/phrase generated with text-to-speech models\n", "\n", - "2) Synthetic examples of adversarial words/phrases generated with text-to-speech models\n", - "\n", - "3) Room impulse reponses and noise/background audio data to augment the synthetic examples and make them more realistic\n", - "\n", - "4) Generic \"negative\" audio data that is very unlikely to contain examples of the target word/phrase in the context where the model should detect it. This data can be the original audio data, or precomputed openWakeWord features ready for model training.\n", - "\n", - "5) Validation data to use for early-stopping when training the model.\n", + "# Clone les repos\n", + "!git clone https://github.com/dscripka/openWakeWord\n", + "!git clone https://github.com/rhasspy/piper-sample-generator\n", "\n", - "For the purposes of this notebook, all five of these sources will either be generated manually or can be obtained from HuggingFace thanks to their excellent `datasets` library and extremely generous hosting policy. Also note that while only a portion of some datasets are downloaded, for the best possible performance it is recommended to download the entire dataset and keep a local copy for future training runs." + "# Télécharge le modèle TTS Piper\n", + "!mkdir -p piper-sample-generator/models\n", + "!wget -q -O piper-sample-generator/models/en_US-libritts_r-medium.pt \\\n", + " 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", + "print('Modèle TTS téléchargé ✅')\n", + "\n", + "# Installe les packages\n", + "!pip install -q -e ./openWakeWord\n", + "!pip install -q -e ./piper-sample-generator\n", + "!pip install -q webrtcvad mutagen torchinfo torchmetrics==1.2.0 speechbrain==0.5.14\n", + "!pip install -q audiomentations==0.33.0 torch-audiomentations==0.11.0\n", + "!pip install -q acoustics pronouncing datasets==2.14.6 deep-phonemizer==0.0.19\n", + "!pip install -q librosa soundfile onnxscript onnx\n", + "\n", + "# Crée le dossier models s'il n'existe pas\n", + "os.makedirs('/content/openWakeWord/openwakeword/resources/models', exist_ok=True)\n", + "\n", + "# Télécharge les modèles OpenWakeWord\n", + "base_url = 'https://github.com/dscripka/openWakeWord/releases/download/v0.5.1'\n", + "models_dir = '/content/openWakeWord/openwakeword/resources/models'\n", + "for model_file in ['embedding_model.onnx', 'embedding_model.tflite', 'melspectrogram.onnx', 'melspectrogram.tflite']:\n", + " !wget -q {base_url}/{model_file} -O {models_dir}/{model_file}\n", + " print(f'{model_file} ✅')\n", + "\n", + "print('\\nInstallation terminée ✅')" ] }, { "cell_type": "code", "execution_count": null, - "id": "d25a93b1", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T01:07:17.746749Z", - "start_time": "2023-09-04T01:07:17.740846Z" - }, - "id": "d25a93b1" - }, + "metadata": {}, "outputs": [], "source": [ - "# Download room impulse responses collected by MIT\n", - "# https://mcdermottlab.mit.edu/Reverb/IR_Survey.html\n", - "\n", - "output_dir = \"./mit_rirs\"\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", + "# ============================================================\n", + "# CELLULE 2 — Patches de compatibilité (CRITIQUE)\n", + "# ============================================================\n", + "import os\n", "\n", - "# Save clips to 16-bit PCM wav files\n", - "for row in tqdm(rir_dataset):\n", - " name = row['audio']['path'].split('/')[-1]\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))" + "# --- Patch 1 : torchaudio.set_audio_backend deprecated ---\n", + "!sed -i 's/torchaudio.set_audio_backend(\"soundfile\")/#torchaudio.set_audio_backend(\"soundfile\") # patched/' \\\n", + " /usr/local/lib/python3.12/dist-packages/torch_audiomentations/utils/io.py\n", + "print('Patch 1 : torchaudio.set_audio_backend ✅')\n", + "\n", + "# --- Patch 2 : torchaudio.info supprimé dans nouvelles versions ---\n", + "io_path = '/usr/local/lib/python3.12/dist-packages/torch_audiomentations/utils/io.py'\n", + "patch_code = '''# === PATCH torchaudio.info ===\n", + "import torchaudio as _torchaudio\n", + "import soundfile as _sf\n", + "if not hasattr(_torchaudio, 'info'):\n", + " def _torchaudio_info(path):\n", + " class _AudioMetaData:\n", + " def __init__(self, path):\n", + " info = _sf.info(path)\n", + " self.sample_rate = info.samplerate\n", + " self.num_frames = info.frames\n", + " self.num_channels = info.channels\n", + " return _AudioMetaData(path)\n", + " _torchaudio.info = _torchaudio_info\n", + "# === FIN PATCH ===\n", + "'''\n", + "\n", + "with open(io_path, 'r') as f:\n", + " content = f.read()\n", + "\n", + "if '=== PATCH torchaudio.info ===' not in content:\n", + " with open(io_path, 'w') as f:\n", + " f.write(patch_code + content)\n", + " print('Patch 2 : torchaudio.info ✅')\n", + "else:\n", + " print('Patch 2 : déjà appliqué ✅')\n", + "\n", + "# --- Patch 3 : generate_samples.py manquant ---\n", + "generate_samples_code = '''import sys, os\n", + "import librosa\n", + "import soundfile as sf\n", + "sys.path.insert(0, \"/content/piper-sample-generator\")\n", + "from piper_sample_generator.__main__ import generate_samples as _generate_samples\n", + "\n", + "def generate_samples(text, max_samples, output_dir, batch_size=1,\n", + " noise_scales=None, noise_scale_ws=None,\n", + " length_scales=None, auto_reduce_batch_size=False,\n", + " file_names=None, **kwargs):\n", + " _generate_samples(\n", + " text=text,\n", + " max_samples=max_samples,\n", + " output_dir=output_dir,\n", + " model=\"/content/piper-sample-generator/models/en_US-libritts_r-medium.pt\",\n", + " batch_size=batch_size,\n", + " noise_scales=noise_scales or [0.667],\n", + " noise_scale_ws=noise_scale_ws or [0.8],\n", + " length_scales=length_scales or [0.75, 1.0, 1.25],\n", + " file_names=file_names,\n", + " )\n", + " # Resample tous les fichiers générés à 16kHz\n", + " for fname in os.listdir(output_dir):\n", + " if fname.endswith(\".wav\"):\n", + " path = os.path.join(output_dir, fname)\n", + " audio, sr = librosa.load(path, sr=16000)\n", + " sf.write(path, audio, 16000)\n", + "'''\n", + "\n", + "with open('/content/openWakeWord/openwakeword/generate_samples.py', 'w') as f:\n", + " f.write(generate_samples_code)\n", + "print('Patch 3 : generate_samples.py ✅')\n", + "\n", + "print('\\nTous les patches appliqués ✅')" ] }, { "cell_type": "code", "execution_count": null, - "id": "2c0e178b", - "metadata": { - "id": "2c0e178b" - }, + "metadata": {}, "outputs": [], "source": [ - "## Download noise and background audio\n", - "\n", - "# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n", - "# Download one part of the audioset .tar files, extract, and convert to 16khz\n", - "# For full-scale training, it's recommended to download the entire dataset from\n", - "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", - "# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n", - "\n", - "if not os.path.exists(\"audioset\"):\n", - " os.mkdir(\"audioset\")\n", + "# ============================================================\n", + "# CELLULE 3 — Téléchargement données de fond FMA (~1 min)\n", + "# ============================================================\n", + "import os, numpy as np, scipy.io.wavfile, datasets\n", + "from tqdm import tqdm\n", "\n", - "fname = \"bal_train09.tar\"\n", - "out_dir = f\"audioset/{fname}\"\n", - "link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", - "!wget -O {out_dir} {link}\n", - "!cd audioset && tar -xvf bal_train09.tar\n", - "\n", - "output_dir = \"./audioset_16k\"\n", + "output_dir = './fma'\n", "if not os.path.exists(output_dir):\n", " os.mkdir(output_dir)\n", "\n", - "# Convert audioset files to 16khz sample rate\n", - "audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", - "audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", - "for row in tqdm(audioset_dataset):\n", - " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", - "\n", - "# Free Music Archive dataset (https://github.com/mdeff/fma)\n", - "output_dir = \"./fma\"\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "fma_dataset = datasets.load_dataset(\"rudraml/fma\", name=\"small\", split=\"train\", streaming=True)\n", - "fma_dataset = iter(fma_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000)))\n", - "\n", - "n_hours = 1 # use only 1 hour of clips for this example notebook, recommend increasing for full-scale training\n", - "for i in tqdm(range(n_hours*3600//30)): # this works because the FMA dataset is all 30 second clips\n", - " row = next(fma_dataset)\n", - " name = row['audio']['path'].split('/')[-1].replace(\".mp3\", \".wav\")\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", - " i += 1\n", - " if i == n_hours*3600//30:\n", - " break\n" + "if len(os.listdir(output_dir)) < 100:\n", + " print('Téléchargement FMA dataset...')\n", + " fma_dataset = datasets.load_dataset('rudraml/fma', name='small', split='train', streaming=True)\n", + " fma_dataset = iter(fma_dataset.cast_column('audio', datasets.Audio(sampling_rate=16000)))\n", + " n_hours = 1\n", + " for i in tqdm(range(n_hours*3600//30)):\n", + " row = next(fma_dataset)\n", + " name = row['audio']['path'].split('/')[-1].replace('.mp3', '.wav')\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000,\n", + " (row['audio']['array']*32767).astype(np.int16))\n", + " print(f'FMA téléchargé : {len(os.listdir(output_dir))} fichiers ✅')\n", + "else:\n", + " print(f'FMA déjà présent : {len(os.listdir(output_dir))} fichiers ✅')" ] }, { "cell_type": "code", "execution_count": null, - "id": "d01ec467", - "metadata": { - "id": "d01ec467" - }, + "metadata": {}, "outputs": [], "source": [ - "# Download pre-computed openWakeWord features for training and validation\n", - "\n", - "# training set (~2,000 hours from the ACAV100M Dataset)\n", - "# See https://huggingface.co/datasets/davidscripka/openwakeword_features for more information\n", - "!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/openwakeword_features_ACAV100M_2000_hrs_16bit.npy\n", - "\n", - "# validation set for false positive rate estimation (~11 hours)\n", - "!wget https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/validation_set_features.npy" - ] - }, - { - "cell_type": "markdown", - "id": "cfe82647", - "metadata": { - "id": "cfe82647" - }, - "source": [ - "# Define Training Configuration" + "# ============================================================\n", + "# CELLULE 4 — Téléchargement MIT RIRs (~1 min)\n", + "# ============================================================\n", + "import os, numpy as np, scipy.io.wavfile, datasets\n", + "from tqdm import tqdm\n", + "\n", + "os.makedirs('./mit_rirs', exist_ok=True)\n", + "\n", + "if len(os.listdir('./mit_rirs')) < 50:\n", + " print('Téléchargement MIT RIRs...')\n", + " rir_dataset = datasets.load_dataset(\n", + " 'davidscripka/MIT_environmental_impulse_responses',\n", + " split='train', streaming=True\n", + " )\n", + " rir_dataset = iter(rir_dataset.cast_column('audio', datasets.Audio(sampling_rate=16000)))\n", + " for i in tqdm(range(100)):\n", + " try:\n", + " row = next(rir_dataset)\n", + " scipy.io.wavfile.write(\n", + " f'./mit_rirs/rir_{i:04d}.wav', 16000,\n", + " (row['audio']['array']*32767).astype(np.int16)\n", + " )\n", + " except StopIteration:\n", + " break\n", + " print(f'MIT RIRs : {len(os.listdir(\"./mit_rirs\"))} fichiers ✅')\n", + "else:\n", + " print(f'MIT RIRs déjà présents : {len(os.listdir(\"./mit_rirs\"))} fichiers ✅')" ] }, { - "cell_type": "markdown", - "id": "b2e71329", - "metadata": { - "id": "b2e71329" - }, + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ - "For automated model training openWakeWord uses a specially designed training script and a [YAML](https://yaml.org/) configuration file that defines all of the information required for training a new wake word/phrase detection model.\n", - "\n", - "It is strongly recommended that you review [the example config file](../examples/custom_model.yml), as each value is fully documented there. For the purposes of this notebook, we'll read in the YAML file to modify certain configuration parameters before saving a new YAML file for training our example model. Specifically:\n", - "\n", - "- We'll train a detection model for the phrase \"hey sebastian\"\n", - "- We'll only generate 5,000 positive and negative examples (to save on time for this example)\n", - "- We'll only generate 1,000 validation positive and negative examples for early stopping (again to save time)\n", - "- The model will only be trained for 10,000 steps (larger datasets will benefit from longer training)\n", - "- We'll reduce the target metrics to account for the small dataset size and limited training.\n", + "# ============================================================\n", + "# CELLULE 5 — Téléchargement features ACAV100M (~2 min, 16GB)\n", + "# ============================================================\n", + "import os\n", "\n", - "On the topic of target metrics, there are *not* specific guidelines about what these metrics should be in practice, and you will need to conduct testing in your target deployment environment to establish good thresholds. However, from very limited testing the default values in the config file (accuracy >= 0.7, recall >= 0.5, false-positive rate <= 0.2 per hour) seem to produce models with reasonable performance.\n" + "if not os.path.exists('openwakeword_features_ACAV100M_2000_hrs_16bit.npy'):\n", + " print('Téléchargement features ACAV100M (16GB)...')\n", + " !wget -q https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/openwakeword_features_ACAV100M_2000_hrs_16bit.npy\n", + " print('Features ACAV100M ✅')\n", + "else:\n", + " print('Features ACAV100M déjà présentes ✅')\n", + "\n", + "if not os.path.exists('validation_set_features.npy'):\n", + " print('Téléchargement features validation...')\n", + " !wget -q https://huggingface.co/datasets/davidscripka/openwakeword_features/resolve/main/validation_set_features.npy\n", + " print('Features validation ✅')\n", + "else:\n", + " print('Features validation déjà présentes ✅')" ] }, { "cell_type": "code", "execution_count": null, - "id": "fb0b6e4f", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T18:11:33.893397Z", - "start_time": "2023-09-04T18:11:33.878938Z" - }, - "id": "fb0b6e4f" - }, + "metadata": {}, "outputs": [], "source": [ - "# Load default YAML config file for training\n", - "config = yaml.load(open(\"openwakeword/examples/custom_model.yml\", 'r').read(), yaml.Loader)\n", - "config" + "# ============================================================\n", + "# CELLULE 6 — Configuration du modèle\n", + "# ============================================================\n", + "import yaml, sys\n", + "sys.path.insert(0, '/content/openWakeWord')\n", + "sys.path.insert(0, '/content/piper-sample-generator')\n", + "\n", + "MODEL_NAME = TARGET_PHRASE.replace(' ', '_')\n", + "\n", + "config = yaml.load(open('openWakeWord/examples/custom_model.yml', 'r').read(), yaml.Loader)\n", + "config['target_phrase'] = [TARGET_PHRASE]\n", + "config['model_name'] = MODEL_NAME\n", + "config['n_samples'] = N_SAMPLES\n", + "config['n_samples_val'] = N_SAMPLES_VAL\n", + "config['steps'] = STEPS\n", + "config['target_accuracy'] = 0.6\n", + "config['target_recall'] = 0.25\n", + "config['background_paths'] = ['./fma']\n", + "config['rir_paths'] = ['./mit_rirs']\n", + "config['feature_data_files'] = {'ACAV100M_sample': 'openwakeword_features_ACAV100M_2000_hrs_16bit.npy'}\n", + "config['false_positive_validation_data_path'] = 'validation_set_features.npy'\n", + "config['piper_model'] = '/content/piper-sample-generator/models/en_US-libritts_r-medium.pt'\n", + "\n", + "with open('my_model.yaml', 'w') as f:\n", + " yaml.dump(config, f)\n", + "\n", + "print(f'Config sauvegardée pour : \"{TARGET_PHRASE}\" ✅')" ] }, { "cell_type": "code", "execution_count": null, - "id": "482cf2d0", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T15:07:00.859210Z", - "start_time": "2023-09-04T15:07:00.841472Z" - }, - "id": "482cf2d0" - }, + "metadata": {}, "outputs": [], "source": [ - "# Modify values in the config and save a new version\n", - "\n", - "config[\"target_phrase\"] = [\"hey sebastian\"]\n", - "config[\"model_name\"] = config[\"target_phrase\"][0].replace(\" \", \"_\")\n", - "config[\"n_samples\"] = 1000\n", - "config[\"n_samples_val\"] = 1000\n", - "config[\"steps\"] = 10000\n", - "config[\"target_accuracy\"] = 0.6\n", - "config[\"target_recall\"] = 0.25\n", - "\n", - "config[\"background_paths\"] = ['./audioset_16k', './fma'] # multiple background datasets are supported\n", - "config[\"false_positive_validation_data_path\"] = \"validation_set_features.npy\"\n", - "config[\"feature_data_files\"] = {\"ACAV100M_sample\": \"openwakeword_features_ACAV100M_2000_hrs_16bit.npy\"}\n", - "\n", - "with open('my_model.yaml', 'w') as file:\n", - " documents = yaml.dump(config, file)" - ] - }, - { - "cell_type": "markdown", - "id": "aa6b2ab0", - "metadata": { - "id": "aa6b2ab0" - }, - "source": [ - "# Train the Model" - ] - }, - { - "cell_type": "markdown", - "id": "a51202c0", - "metadata": { - "id": "a51202c0" - }, - "source": [ - "With the data downloaded and training configuration set, we can now start training the model. We'll do this in parts to better illustrate the sequence, but you can also execute every step at once for a fully automated process." + "# ============================================================\n", + "# CELLULE 7 — Génération des clips (~15 min)\n", + "# ============================================================\n", + "import subprocess, sys\n", + "\n", + "print('Génération des clips positifs et négatifs...')\n", + "result = subprocess.run(\n", + " [sys.executable, 'openWakeWord/openwakeword/train.py',\n", + " '--training_config', 'my_model.yaml', '--generate_clips'],\n", + " stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True\n", + ")\n", + "# Affiche juste les lignes INFO\n", + "for line in result.stdout.split('\\n'):\n", + " if 'INFO' in line or 'Done' in line or 'ERROR' in line or 'Traceback' in line:\n", + " print(line)\n", + "\n", + "if result.returncode == 0:\n", + " print('\\n✅ Génération terminée !')\n", + "else:\n", + " print(f'\\n❌ Erreur code {result.returncode}')\n", + " print(result.stdout[-2000:])" ] }, { "cell_type": "code", "execution_count": null, - "id": "f01531fa", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T13:50:08.803326Z", - "start_time": "2023-09-04T13:50:06.790241Z" - }, - "id": "f01531fa" - }, + "metadata": {}, "outputs": [], "source": [ - "# Step 1: Generate synthetic clips\n", - "# For the number of clips we are using, this should take ~10 minutes on a free Google Colab instance with a T4 GPU\n", - "# If generation fails, you can simply run this command again as it will continue generating until the\n", - "# number of files meets the targets specified in the config file\n", - "\n", - "!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --generate_clips" + "# ============================================================\n", + "# CELLULE 8 — Augmentation + calcul features (~5 min)\n", + "# ============================================================\n", + "import subprocess, sys\n", + "\n", + "print('Augmentation et calcul des features...')\n", + "result = subprocess.run(\n", + " [sys.executable, 'openWakeWord/openwakeword/train.py',\n", + " '--training_config', 'my_model.yaml', '--augment_clips', '--overwrite'],\n", + " stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True\n", + ")\n", + "for line in result.stdout.split('\\n'):\n", + " if 'INFO' in line or 'WARNING' in line or 'ERROR' in line or 'Traceback' in line or '%' in line:\n", + " print(line)\n", + "\n", + "if result.returncode == 0:\n", + " import os\n", + " npy_files = [f for f in os.listdir(f'./my_custom_model/{MODEL_NAME}') if f.endswith('.npy')]\n", + " print(f'\\n✅ Augmentation terminée ! Features: {npy_files}')\n", + "else:\n", + " print(f'\\n❌ Erreur code {result.returncode}')\n", + " print(result.stdout[-2000:])" ] }, { "cell_type": "code", "execution_count": null, - "id": "afeedae4", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T13:56:08.781018Z", - "start_time": "2023-09-04T13:55:40.203515Z" - }, - "id": "afeedae4" - }, + "metadata": {}, "outputs": [], "source": [ - "# Step 2: Augment the generated clips\n", + "# ============================================================\n", + "# CELLULE 9 — Entraînement du modèle (~20-30 min)\n", + "# ============================================================\n", + "import subprocess, sys\n", + "\n", + "print('Entraînement en cours...')\n", + "result = subprocess.run(\n", + " [sys.executable, 'openWakeWord/openwakeword/train.py',\n", + " '--training_config', 'my_model.yaml', '--train_model'],\n", + " stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True\n", + ")\n", + "# Affiche résultats finaux\n", + "lines = result.stdout.split('\\n')\n", + "for line in lines:\n", + " if any(x in line for x in ['Final Model', 'Saving ONNX', 'INFO:root:####', 'Accuracy', 'Recall', 'False Pos']):\n", + " print(line)\n", "\n", - "!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --augment_clips" + "import os\n", + "onnx_path = f'./my_custom_model/{MODEL_NAME}.onnx'\n", + "if os.path.exists(onnx_path):\n", + " size = os.path.getsize(onnx_path)\n", + " print(f'\\n✅ Modèle ONNX généré : {onnx_path} ({size/1024:.1f} KB)')\n", + "elif result.returncode != 0:\n", + " # L'erreur tflite est normale et ignorée\n", + " if 'onnx_tf' in result.stdout:\n", + " print('\\n⚠️ Erreur tflite ignorée (non nécessaire pour Android)')\n", + " print('Vérification du fichier ONNX...')\n", + " all_files = os.listdir('./my_custom_model')\n", + " print('Fichiers:', all_files)\n", + " else:\n", + " print(f'\\n❌ Erreur code {result.returncode}')\n", + " print(result.stdout[-3000:])" ] }, { "cell_type": "code", "execution_count": null, - "id": "9ad81ea0", - "metadata": { - "ExecuteTime": { - "end_time": "2023-09-04T15:11:14.742260Z", - "start_time": "2023-09-04T15:07:03.755159Z" - }, - "id": "9ad81ea0" - }, + "metadata": {}, "outputs": [], "source": [ - "# Step 3: Train model\n", - "\n", - "!{sys.executable} openwakeword/openwakeword/train.py --training_config my_model.yaml --train_model" + "# ============================================================\n", + "# CELLULE 10 — Téléchargement du modèle .onnx\n", + "# ============================================================\n", + "import os\n", + "from google.colab import files\n", + "\n", + "onnx_path = f'./my_custom_model/{MODEL_NAME}.onnx'\n", + "\n", + "if os.path.exists(onnx_path):\n", + " size = os.path.getsize(onnx_path)\n", + " print(f'✅ Fichier trouvé : {onnx_path} ({size/1024:.1f} KB)')\n", + " files.download(onnx_path)\n", + " print('✅ Téléchargement lancé !')\n", + " print(f'\\n👉 Pour entraîner bye_bye_snowy :')\n", + " print(' 1. Modifie TARGET_PHRASE dans la cellule 0')\n", + " print(' 2. Relance les cellules 6, 7, 8, 9, 10')\n", + "else:\n", + " print(f'❌ Fichier non trouvé : {onnx_path}')\n", + " print('Fichiers disponibles :', os.listdir('./my_custom_model'))" ] }, { "cell_type": "code", "execution_count": null, - "id": "JSKWWLalnYzR", - "metadata": { - "id": "JSKWWLalnYzR" - }, + "metadata": {}, "outputs": [], "source": [ - "# Step 4 (Optional): On Google Colab, sometimes the .tflite model isn't saved correctly\n", - "# If so, run this cell to retry\n", - "\n", - "# Manually save to tflite as this doesn't work right in colab\n", - "def convert_onnx_to_tflite(onnx_model_path, output_path):\n", - " \"\"\"Converts an ONNX version of an openwakeword model to the Tensorflow tflite format.\"\"\"\n", - " # imports\n", - " import onnx\n", - " import logging\n", - " import tempfile\n", - " from onnx_tf.backend import prepare\n", - " import tensorflow as tf\n", - "\n", - " # Convert to tflite from onnx model\n", - " onnx_model = onnx.load(onnx_model_path)\n", - " tf_rep = prepare(onnx_model, device=\"CPU\")\n", - " with tempfile.TemporaryDirectory() as tmp_dir:\n", - " tf_rep.export_graph(os.path.join(tmp_dir, \"tf_model\"))\n", - " converter = tf.lite.TFLiteConverter.from_saved_model(os.path.join(tmp_dir, \"tf_model\"))\n", - " tflite_model = converter.convert()\n", - "\n", - " logging.info(f\"####\\nSaving tflite mode to '{output_path}'\")\n", - " with open(output_path, 'wb') as f:\n", - " f.write(tflite_model)\n", - "\n", - " return None\n", - "\n", - "convert_onnx_to_tflite(f\"my_custom_model/{config['model_name']}.onnx\", f\"my_custom_model/{config['model_name']}.tflite\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "f9OyUW3ltOSs", - "metadata": { - "id": "f9OyUW3ltOSs" - }, - "source": [ - "After the model finishes training, the auto training script will automatically convert it to ONNX and tflite versions, saving them as `my_custom_model/.onnx/tflite` in the present working directory, where `` is defined in the YAML training config file. Either version can be used as normal with `openwakeword`. I recommend testing them with the [`detect_from_microphone.py`](https://github.com/dscripka/openWakeWord/blob/main/examples/detect_from_microphone.py) example script to see how the model performs!" + "# ============================================================\n", + "# CELLULE 11 — Téléchargement modèles de base OpenWakeWord\n", + "# (nécessaires pour l'intégration Android)\n", + "# ============================================================\n", + "from google.colab import files\n", + "import os\n", + "\n", + "base_models = [\n", + " '/content/openWakeWord/openwakeword/resources/models/melspectrogram.onnx',\n", + " '/content/openWakeWord/openwakeword/resources/models/embedding_model.onnx',\n", + "]\n", + "\n", + "for path in base_models:\n", + " if os.path.exists(path):\n", + " files.download(path)\n", + " print(f'✅ {os.path.basename(path)} téléchargé')\n", + " else:\n", + " print(f'❌ Non trouvé : {path}')" ] } ], "metadata": { + "accelerator": "GPU", "colab": { + "gpuType": "T4", "provenance": [] }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", + "display_name": "Python 3", "name": "python3" }, "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.12" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": false + "name": "python" } }, "nbformat": 4, - "nbformat_minor": 5 + "nbformat_minor": 0 }