diff --git a/.ci/ignore_treon_docker.txt b/.ci/ignore_treon_docker.txt index 6803c637d09..40537e09838 100644 --- a/.ci/ignore_treon_docker.txt +++ b/.ci/ignore_treon_docker.txt @@ -90,4 +90,5 @@ notebooks/kokoro/kokoro.ipynb notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb notebooks/intern-video2-classiciation/intern-video2-classification.ipynb notebooks/flex.2-image-generation/flex.2-image-generation.ipynb -notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb \ No newline at end of file +notebooks/wan2.1-text-to-video/wan2.1-text-to-video.ipynb +notebooks/ace-step-music-generation/ace-step-music-generation.ipynb \ No newline at end of file diff --git a/.ci/skipped_notebooks.yml b/.ci/skipped_notebooks.yml index 88ae13df2f9..f72a44859e6 100644 --- a/.ci/skipped_notebooks.yml +++ b/.ci/skipped_notebooks.yml @@ -574,3 +574,9 @@ skips: - os: - macos-13 +- notebook: notebooks/ace-step-music-generation/ace-step-music-generation.ipynb + skips: + - os: + - macos-13 + - ubuntu-22.04 + - windows-2022 diff --git a/.ci/spellcheck/.pyspelling.wordlist.txt b/.ci/spellcheck/.pyspelling.wordlist.txt index 317159d286c..ed4252687a6 100644 --- a/.ci/spellcheck/.pyspelling.wordlist.txt +++ b/.ci/spellcheck/.pyspelling.wordlist.txt @@ -52,6 +52,8 @@ autogenerated AutoModelForXxx autoregressive autoregressively +AutoEncoder +AutoEncoders AutoTokenizer AWQ awq @@ -201,6 +203,7 @@ denoises denoising denormalization denormalized +demucs depainting deployable DepthAnything @@ -231,6 +234,7 @@ DIT DiT DiT’s DiT’s +DiTs DL DocLayNet docling @@ -291,6 +295,8 @@ FastDraft FastSAM FC feedforward +FeedForward +FFN FFmpeg FIL FEIL @@ -608,6 +614,7 @@ MRPC mRoPE msi MTVQA +mT multiarchitecture Multiclass multiclass @@ -705,6 +712,7 @@ opset optimizable Orca otsl +OSNet OTSL OuteTTS outpainting @@ -780,6 +788,7 @@ PowerShell PPYOLOv PR Prateek +PLR pre Precisions precomputed @@ -945,6 +954,7 @@ SmolVLM softmax softvc SoftVC +SongGen SOTA SoTA soundfile @@ -1125,6 +1135,7 @@ Vladlen VOC Vocoder vocoder +vocoding VQ VQA VQGAN diff --git a/notebooks/ace-step-music-generation/README.md b/notebooks/ace-step-music-generation/README.md new file mode 100644 index 00000000000..c21adb0534e --- /dev/null +++ b/notebooks/ace-step-music-generation/README.md @@ -0,0 +1,33 @@ +# Music generation using ACE Step and OpenVINO + +[ACE-Step](https://ace-step.github.io/) is a novel open-source foundation model for music generation that overcomes key limitations of existing approaches and achieves state-of-the-art performance through a holistic architectural design. Current methods face inherent trade-offs between generation speed, musical coherence, and controllability. ACE-Step bridges this gap by integrating diffusion-based generation with Sana’s Deep Compression AutoEncoder (DCAE) and a lightweight linear transformer. The model achieving superior musical coherence and lyric alignment across melody, harmony, and rhythm metrics. Moreover, ACE-Step preserves fine-grained acoustic details, enabling advanced control mechanisms such as voice cloning, lyric editing, remixing, and track generation (e.g., lyric2vocal, singing2accompaniment). + +ACE-Step adapts a text-to-image diffusion framework for music generation. The core generative model is a diffusion model operating on a compressed mel spectrogram latent representation. This process is guided by conditioning information from three specialized encoders: a text prompt encoder, a lyric encoder, and a speaker encoder. 
Embeddings from these encoders are concatenated and integrated into the diffusion model via cross-attention mechanisms. + +ACE-Step can be used for generating original music from text descriptions, music remixing and style transfer, and editing song lyrics. The model offers a set of controllable features that allow users to precisely control the generation process and enable targeted modifications to existing audio material, as well as perform specialized generation tasks through fine-tuning. + + + +More details about the model can be found using the following resources: [project page](https://ace-step.github.io/), [paper](https://arxiv.org/abs/2506.00045), [original repository](https://github.com/ace-step/ACE-Step). + + +## Notebook Contents + +This notebook demonstrates how to convert and run music generation or editing with ACE Step using OpenVINO. + +The tutorial consists of the following steps: + +- Install prerequisites +- Download and run inference of ACE Step model +- Convert the model to IR format and run inference with OpenVINO +- Download, apply and generate audio with LoRA +- Interactive demo + + +## Installation Instructions + +This is a self-contained example that relies solely on its own code.
+We recommend running the notebook in a virtual environment. You only need a Jupyter server to start. +For details, please refer to [Installation Guide](../../README.md). + + diff --git a/notebooks/ace-step-music-generation/ace-step-music-generation.ipynb b/notebooks/ace-step-music-generation/ace-step-music-generation.ipynb new file mode 100644 index 00000000000..9d9b9f185ba --- /dev/null +++ b/notebooks/ace-step-music-generation/ace-step-music-generation.ipynb @@ -0,0 +1,881 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "b464bcb6-cff9-46e7-bd86-fa40bdd3ad18", + "metadata": {}, + "source": [ + "# Music generation using ACE Step and OpenVINO\n", + "\n", + "
Important note: This notebook requires Python >= 3.10. Please make sure that your environment fulfills this requirement before running it.
\n", + "\n", + "[ACE-Step](https://ace-step.github.io/) is a novel open-source foundation model for music generation that overcomes key limitations of existing approaches and achieves state-of-the-art performance through a holistic architectural design. Current methods face inherent trade-offs between generation speed, musical coherence, and controllability. ACE-Step bridges this gap by integrating diffusion-based generation with Sana’s Deep Compression AutoEncoder (DCAE) and a lightweight linear transformer. The model achieving superior musical coherence and lyric alignment across melody, harmony, and rhythm metrics. Moreover, ACE-Step preserves fine-grained acoustic details, enabling advanced control mechanisms such as voice cloning, lyric editing, remixing, and track generation (e.g., lyric2vocal, singing2accompaniment). \n", + "\n", + "\n", + "ACE-Step adapts a text-to-image diffusion framework for music generation. The core generative model is a diffusion model operating on a compressed mel spectrogram latent representation. This process is guided by conditioning information from three specialized encoders: a text prompt encoder, a lyric encoder, and a speaker encoder. Embeddings from these encoders are concatenated and integrated into the diffusion model via cross-attention mechanisms\n", + "\n", + "\n", + "ACE-Step can be used for generating original music from text descriptions, music remixing and style transfer, edit song lyrics. The model offers a set of controllable features that allow users to precisely control the generation process and enable targeted modifications to existing audio material, as well as perform specialized generation tasks through fine-tuning.\n", + "\n", + "\n", + "\n", + "\n", + "More details about the model can be found using the following resources: [project page](https://ace-step.github.io/), [paper](https://arxiv.org/abs/2506.00045), [original repository](https://github.com/ace-step/ACE-Step).\n", + "\n", + "\n", + "\n", + "\n", + "### Installation Instructions\n", + "\n", + "This is a self-contained example that relies solely on its own code.\n", + "\n", + "We recommend running the notebook in a virtual environment. 
You only need a Jupyter server to start.\n", + "For details, please refer to [Installation Guide](https://github.com/openvinotoolkit/openvino_notebooks/blob/latest/README.md#-installation-guide).\n", + "\n", + "\n", + "\n", + "#### Table of contents:\n", + "\n", + "- [Prerequisites](#Prerequisites)\n", + "- [Music generation with ACE Step Pipeline via PyTorch](#Music-generation-with-ACE-Step-Pipeline-via-PyTorch)\n", + " - [Download checkpoints and load PyTorch models](#Download-checkpoints-and-load-PyTorch-models)\n", + " - [Configure parameters and generate audio](#Configure-parameters-and-generate-audio)\n", + " - [Update audio](#Update-audio)\n", + "- [Music Generation with ACE Step via OpenVINO](#Music-Generation-with-ACE-Step-via-OpenVINO)\n", + " - [Convert model to OpenVINO](#Convert-model-to-OpenVINO)\n", + " - [Select inference device](#Select-inference-device)\n", + " - [Create pipeline, read and compile models](#Create-pipeline,-read-and-compile-models)\n", + " - [Generate audio](#Generate-audio)\n", + " - [Update audio](#Update-audio)\n", + "- [ACE Step with LoRA](#ACE-Step-with-LoRA)\n", + " - [Load LoRA and apply for Transformer models for ACE Step pipeline](#Load-LoRA-and-apply-for-Transformer-models-for-ACE-Step-pipeline)\n", + " - [Convert models with LoRA and load to OpenVINO pipeline](#Convert-models-with-LoRA-and-load-to-OpenVINO-pipeline)\n", + " - [Run Inference with LoRA](#Run-Inference-with-LoRA)\n", + " - [Deactivate LoRA](#Deactivate-LoRA)\n", + "- [Interactive demo](#Interactive-demo)\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "915b561b-3a13-4deb-8a61-9bab9abb6faa", + "metadata": {}, + "source": [ + "## Prerequisites\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16f5d8e3-87e7-4396-97c1-147b6b82d0f8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import requests\n", + "import platform\n", + "from pathlib import Path\n", + "\n", + "if not Path(\"ov_ace_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/ace-step-music-generation/ov_ace_helper.py\")\n", + " open(\"ov_ace_helper.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"gradio_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/ace-step-music-generation/gradio_helper.py\")\n", + " open(\"gradio_helper.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"notebook_utils.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/notebook_utils.py\")\n", + " open(\"notebook_utils.py\", \"w\").write(r.text)\n", + "\n", + "if not Path(\"pip_helper.py\").exists():\n", + " r = requests.get(\n", + " url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/pip_helper.py\",\n", + " )\n", + " open(\"pip_helper.py\", \"w\").write(r.text)\n", + "\n", + "from pip_helper import pip_install\n", + "\n", + "pip_install(\"gradio>=4.19\")\n", + "if platform.system() == \"Darwin\":\n", + " pip_install(\"numpy<2.0\")\n", + "\n", + "pip_install(\n", + " \"git+https://github.com/ace-step/ACE-Step.git@6ae0852b1388de6dc0cca26b31a86d711f723cb3\", \"--extra-index-url\", \"https://download.pytorch.org/whl/cpu\"\n", + ")\n", + "\n", + "pip_install(\"openvino>=2025.1.0\", \"openvino-tokenizers>=2025.1.0\", \"nncf>=2.16.0\")\n", 
+ "\n", + "# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry\n", + "from notebook_utils import collect_telemetry\n", + "\n", + "collect_telemetry(\"ace-step-music-generation.ipynb\")" + ] + }, + { + "cell_type": "markdown", + "id": "358dd910", + "metadata": {}, + "source": [ + "## Music generation with ACE Step Pipeline via PyTorch\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "markdown", + "id": "cc3880f0", + "metadata": {}, + "source": [ + "### Download checkpoints and load PyTorch models\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "The architecture of ACE Step consists of following components.\n", + "\n", + "Linear Diffusion Transformer.\n", + "\n", + "Core Diffusion Model is a Linear Diffusion Transformer (DiT). DiT architecture adapts the Transformer to serve as the backbone for diffusion models, primarily replacing U-Net structures common in image generation. DiTs treat latent representations (e.g., image patches or segments of audio features) as sequences of tokens. These tokens, along with embeddings for the diffusion timestep and any conditioning information, are processed by a series of Transformer blocks. ACE-Step adapt the Linear DiT structure from Sana with several modifications. This significantly reduces model size and memory consumption; And added 1D Convolutional FeedForward Layers (FFN).\n", + "\n", + "Conditioning Encoders.\n", + "\n", + "To guide the generation process, the Linear DiT is conditioned on embeddings from the following encoders:\n", + "- Text Encoder: It is a mT5-base model, which generate 768-dimensional embeddings from textual prompts.\n", + "- Lyric Encoder: The lyric encoder architecture and hyperparameters are adopted from SongGen.\n", + "- Speaker Encoder: The speaker encoder processes a 10 - second unaccompanied vocal segment, which is separated by demucs, into a 512 - dimensional embedding. For full songs with vocals, embeddings from multiple such segments are averaged. A zero vector is used as the speaker embedding for instrumental tracks. The encoder, pre-trained on a large and diverse singing voice corpus, draws architectural inspiration from\n", + "PLR-OSNet, originally designed for face recognition. The model was tuned to prevents the model from over-relying on timbre information for stylistic interpretation, thereby enabling reasonable timbre generation even without explicit speaker input.\n", + "\n", + "Deep Compression AutoEncoders.\n", + "\n", + "For efficient latent space modeling, ACE-Step uses Deep Compression AutoEncoder (DCAE). A DCAE is an encoder to map high-dimensional input (e.g., mel-spectrograms) to a much lower-dimensional latent representation, and a decoder to reconstruct the original input from this latent code. The \"deep compression\" aspect implies a focus on achieving a highly compact latent space while minimizing reconstruction error. For audio, this means capturing salient acoustic features essential for perception and quality within a small number of latent variables. This not only reduces the computational burden for subsequent generative models operating in this latent space but also encourages the generative model to focus on higher-level structural and semantic aspects rather than low-level waveform details. The specific architecture of DCAE (e.g., convolutional layers, quantization if used) is optimized for this trade-off. 
For converting the generated mel-spectrograms back to waveform (vocoding), ACE-Step utilizes a pre-trained universal music vocoder from Fish Audio." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b542df6", + "metadata": {}, + "outputs": [], + "source": [ + "from acestep.pipeline_ace_step import ACEStepPipeline\n", + "\n", + "checkpoint_dir = \"\"\n", + "pipeline = ACEStepPipeline(checkpoint_dir=checkpoint_dir, dtype=\"float32\", cpu_offload=False)\n", + "pipeline.load_checkpoint(checkpoint_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "e23dd34c", + "metadata": {}, + "source": [ + "### Configure parameters and generate audio\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "We will specify some parameters: `prompt`, `lyrics`, `infer_step`, `save_path`, `audio_duration` in seconds, `task`, and the Entropy Rectifying Guidance (ERG) options `use_erg_tag`, `use_erg_lyric`, and `use_erg_diffusion`. When an ERG option is enabled, a temperature is multiplied into the corresponding attention to weaken that condition (for example, the lyric condition with `use_erg_lyric`) and improve diversity.\n", + "\n", + "Some other options can also be specified: `guidance_scale`, `guidance_scale_text`, `guidance_scale_lyric`, `guidance_interval` - guidance interval for the generation, `min_guidance_scale`, `guidance_interval_decay` - guidance interval decay for the generation (the guidance scale decays from `guidance_scale` to `min_guidance_scale` within the interval), `omega_scale` - granularity scale for the generation (higher values can reduce artifacts), `oss_steps` - optimal steps for the generation, `audio2audio_enable` - enable Audio-to-Audio generation using a reference audio, `ref_audio_input` - reference audio for the audio2audio task, and `ref_audio_strength`.\n", + "\n", + "More information about these parameters can be found in the [ACE Step repo](https://github.com/ace-step/ACE-Step)."
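To make the ERG options described above more concrete, the sketch below shows the general idea with a hypothetical, minimal PyTorch snippet (the real logic is implemented in `ov_ace_helper.py`; the module, tensor shapes, and `tau` value here are illustrative only): a forward hook scales a query projection by a small temperature, which flattens the attention and weakens that condition.

```python
import torch

# Toy stand-in for an attention query projection; shapes and tau are illustrative.
to_q = torch.nn.Linear(8, 8)
tau = 0.01

def erg_hook(module, inputs, output):
    # Scaling the queries by a small temperature flattens the attention weights,
    # i.e. it weakens the corresponding condition for this forward pass.
    return output * tau

handle = to_q.register_forward_hook(erg_hook)
weakened_queries = to_q(torch.randn(1, 8))  # hook is active during this call
handle.remove()                             # restore normal behaviour afterwards
```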
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "359fabcb", + "metadata": { + "test_replace": { + " \"audio_duration\": 15.0,\n": " \"audio_duration\": 5.0,\n", + " \"infer_step\": 25,\n": " \"infer_step\": 10,\n", + " \"lyrics\": \"[verse]\\nWoke up to the sunrise glow\\nTook my heart and hit the road[inst]\",\n": " \"lyrics\": \"[verse]\\nWoke up \\n[inst]\",\n" + } + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "inputs = {\n", + " \"prompt\": \"country rock, folk rock, southern rock, bluegrass, country pop\",\n", + " \"lyrics\": \"[verse]\\nWoke up to the sunrise glow\\nTook my heart and hit the road[inst]\",\n", + " \"audio_duration\": 15.0,\n", + " \"infer_step\": 25,\n", + " \"use_erg_tag\": False,\n", + " \"use_erg_lyric\": True,\n", + " \"use_erg_diffusion\": True,\n", + " \"save_path\": Path(\"outputs\").absolute().as_posix(),\n", + " \"task\": \"text2music\",\n", + "}\n", + "\n", + "if not Path(inputs[\"save_path\"]).exists():\n", + " os.mkdir(inputs[\"save_path\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "f559c85c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-08-18 21:08:31.631\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36m__call__\u001b[0m:\u001b[36m1488\u001b[0m - \u001b[1mModel loaded in 0.00 seconds.\u001b[0m\n", + "\u001b[32m2025-08-18 21:08:31.682\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m847\u001b[0m - \u001b[1mcfg_type: apg, guidance_scale: 15.0, omega_scale: 10.0\u001b[0m\n", + "\u001b[32m2025-08-18 21:08:31.684\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m1072\u001b[0m - \u001b[1mstart_idx: 6, end_idx: 18, num_inference_steps: 25\u001b[0m\n", + "100%|███████████████████████████████████████████| 25/25 [00:33<00:00, 1.35s/it]\n", + " 0%| | 0/1 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import IPython.display as ipd\n", + "\n", + "result = pipeline(**inputs)\n", + "\n", + "output_path = result[0]\n", + "display(ipd.Audio(output_path))" + ] + }, + { + "cell_type": "markdown", + "id": "216b725c", + "metadata": {}, + "source": [ + "### Update audio\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "ACE Step provides functions for updating audio. The following tasks are available in ACE Step for that: `retake`, `repaint`, `edit` and `extend`.\n", + "For retaking `retake_variance` and `retake_seeds` can be specified. We can edit prompt with `edit_target_prompt` or lyrics with `edit_target_lyrics`, also we can set `edit_n_min`,`edit_n_max` and `edit_n_avg`. For repainting it is possible to setup `retake_variance`, `retake_seeds`, `repaint_start` and `repaint_end` times. And extend can be configured with `repaint_start` and `repaint_end` options in seconds. Also source audio should be provided via `src_audio_path`.\n", + "\n", + "Let's try to edit style of generated audio." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "02c814a4", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "\n", + "inputs_update_audio = copy.deepcopy(inputs)\n", + "inputs_update_audio.update(\n", + " {\n", + " \"edit_target_prompt\": \"classical, orchestral, strings, piano, 60 bpm, elegant, emotive, timeless, instrumental\",\n", + " \"edit_target_lyrics\": inputs[\"lyrics\"],\n", + " \"edit_n_min\": 0.2,\n", + " \"edit_n_max\": 0.4,\n", + " \"task\": \"edit\",\n", + " \"src_audio_path\": output_path,\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "34d73ee0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-08-18 21:14:15.309\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36m__call__\u001b[0m:\u001b[36m1488\u001b[0m - \u001b[1mModel loaded in 0.00 seconds.\u001b[0m\n", + "\u001b[32m2025-08-18 21:14:16.593\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mtokenize_lyrics\u001b[0m:\u001b[36m465\u001b[0m - \u001b[1mdebbug [verse] --> zh --> ['[en]', '[verse]']\u001b[0m\n", + "\u001b[32m2025-08-18 21:14:16.594\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mtokenize_lyrics\u001b[0m:\u001b[36m465\u001b[0m - \u001b[1mdebbug Woke up to the sunrise glow --> en --> ['[en]', 'w', 'ok', 'e', ' ', 'up', ' ', 'to', ' ', 'the', ' ', 'sun', 'ris', 'e', ' ', 'gl', 'ow']\u001b[0m\n", + "\u001b[32m2025-08-18 21:14:16.596\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mtokenize_lyrics\u001b[0m:\u001b[36m465\u001b[0m - \u001b[1mdebbug Took my heart and hit the road[inst] --> en --> ['[en]', 'to', 'ok', ' ', 'my', ' ', 'he', 'a', 'rt', ' ', 'and', ' ', 'h', 'it', ' ', 'the', ' ', 'ro', 'ad', '[inst]']\u001b[0m\n", + "\u001b[32m2025-08-18 21:14:16.598\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36mflowedit_diffusion_process\u001b[0m:\u001b[36m658\u001b[0m - \u001b[1mflowedit start from 5 to 10\u001b[0m\n", + "100%|███████████████████████████████████████████| 25/25 [00:44<00:00, 1.79s/it]\n", + " 0%| | 0/1 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result = pipeline(**inputs_update_audio)\n", + "\n", + "display(ipd.Audio(result[0]))" + ] + }, + { + "cell_type": "markdown", + "id": "448566eb", + "metadata": {}, + "source": [ + "## Music Generation with ACE Step via OpenVINO\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "markdown", + "id": "ff826412", + "metadata": {}, + "source": [ + "### Convert model to OpenVINO\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "OpenVINO supports PyTorch models via conversion to OpenVINO Intermediate Representation (IR). [OpenVINO model conversion API](https://docs.openvino.ai/2024/openvino-workflow/model-preparation.html#convert-a-model-with-python-convert-model) can be used for these purposes. `ov.convert_model` function accepts original PyTorch model instance and example input for tracing and returns `ov.Model` representing this model in OpenVINO framework. 
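As a minimal, generic illustration of this API (using a toy PyTorch module rather than the ACE-Step components, so the model, file name, and device below are placeholders):

```python
import torch
import openvino as ov

# Toy module standing in for a real PyTorch model.
toy_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
example_input = torch.randn(1, 4)

# Trace the PyTorch module and obtain an ov.Model.
ov_model = ov.convert_model(toy_model, example_input=example_input)

# Save the IR to disk (produces toy_model.xml and toy_model.bin)...
ov.save_model(ov_model, "toy_model.xml")

# ...or compile it for a device and run inference directly.
core = ov.Core()
compiled_model = core.compile_model(ov_model, "CPU")
result = compiled_model(example_input.numpy())
```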
Converted model can be saved on disk using `ov.save_model` function or loading on device using `core.complie_model`.\n", + "\n", + "`ov_ace_helper.py` script contains helper function for model conversion, please check its content if you interested in conversion details. \n", + "\n", + "Let's convert models to IR format. In output folder you will find tokenizer - `openvino_tokenizer.xml`, text encoder model - `ov_text_encoder_model.xml`. Lyric encoder and speaker encoder will be part of transformer models. DCAE decoder and encoder models: `ov_dcae_encoder_model.xml`, `ov_dcae_decoder_model.xml`, vocoder decoder and mel_transform models: `ov_vocoder_decode_model.xml`, `ov_vocoder_mel_transform_model.xml`. And transformer models: `ov_transformer_decoder_model.xml`, `ov_transformer_encoder_model.xml`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14c01092", + "metadata": {}, + "outputs": [], + "source": [ + "from gradio_helper import get_model_compression_format_widgets\n", + "\n", + "model_format = get_model_compression_format_widgets()\n", + "model_format" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9db0d8a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "⌛ Conversion started. Be patient, it may takes some time.\n", + "⌛ Convert Tokenizer\n", + "✅ Tokenizer is converted\n", + "⌛ Convert UMT5 Encoder model\n", + "✅ UMT5 Encoder model converted\n", + "⌛ Convert Sana's Deep Compression AutoEncoder model\n", + "✅ Sana's Deep Compression AutoEncoder model converted\n", + "⌛ Convert Sana's Deep Compression AutoEncoder Decoder model\n", + "✅ Sana's Deep Compression AutoEncoder Decoder model converted\n", + "⌛ Convert Vocoder Mel Tranform model\n", + "✅ Vocoder Mel Tranform model converted\n", + "⌛ Convert Vocoder Decoder model\n", + "✅ Vocoder Decoder model converted\n", + "⌛ Convert Transformer Encoder with Entropy Rectifying Guidance model\n", + "✅ Transformer Encoder with Entropy Rectifying Guidance model converted\n", + "⌛ Convert Transformer Decoder with Entropy Rectifying Guidance model\n", + "✅ Transformer Decoder with Entropy Rectifying Guidance model converted\n" + ] + } + ], + "source": [ + "import nncf\n", + "from ov_ace_helper import convert_models\n", + "\n", + "ov_converted_model_dir = \"ov_models\"\n", + "if model_format.value == \"INT4\":\n", + " weights_compression_config = {\"mode\": nncf.CompressWeightsMode.INT4_ASYM, \"group_size\": 128, \"ratio\": 0.8}\n", + " ov_converted_model_dir += \"_int4\"\n", + "elif model_format.value == \"INT8\":\n", + " weights_compression_config = {\"mode\": nncf.CompressWeightsMode.INT8_ASYM}\n", + " ov_converted_model_dir += \"_int8\"\n", + "else:\n", + " weights_compression_config = None\n", + "\n", + "convert_models(pipeline, model_dir=ov_converted_model_dir, orig_checkpoint_path=checkpoint_dir, quantization_config=weights_compression_config)" + ] + }, + { + "cell_type": "markdown", + "id": "3154daae", + "metadata": {}, + "source": [ + "### Select inference device\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "Select device from dropdown list for running inference using OpenVINO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97f24bb1", + "metadata": {}, + "outputs": [], + "source": [ + "from notebook_utils import device_widget\n", + "\n", + "device = device_widget()\n", + "\n", + "device" + ] + }, + { + "cell_type": "markdown", + "id": "b3b1d996", + "metadata": {}, + "source": [ + "### Create pipeline, 
read and compile models\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a3d15fff", + "metadata": {}, + "outputs": [], + "source": [ + "from ov_ace_helper import OVACEStepPipeline\n", + "\n", + "ov_pipeline = OVACEStepPipeline()\n", + "ov_pipeline.load_models(ov_models_path=ov_converted_model_dir, device=device.value)" + ] + }, + { + "cell_type": "markdown", + "id": "c2e4182b", + "metadata": {}, + "source": [ + "### Generate audio\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ae0c6914", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-08-18 21:18:00.763\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36m__call__\u001b[0m:\u001b[36m1488\u001b[0m - \u001b[1mModel loaded in 0.00 seconds.\u001b[0m\n", + "\u001b[32m2025-08-18 21:18:01.178\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mov_ace_helper\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m571\u001b[0m - \u001b[1mcfg_type: apg, guidance_scale: 15.0, omega_scale: 10.0\u001b[0m\n", + "\u001b[32m2025-08-18 21:18:01.180\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mov_ace_helper\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m796\u001b[0m - \u001b[1mstart_idx: 6, end_idx: 18, num_inference_steps: 25\u001b[0m\n", + "100%|███████████████████████████████████████████| 25/25 [00:20<00:00, 1.20it/s]\n", + " 0%| | 0/1 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import IPython.display as ipd\n", + "\n", + "ov_result = ov_pipeline(**inputs)\n", + "\n", + "ov_out_audio_path = ov_result[0]\n", + "display(ipd.Audio(ov_out_audio_path))" + ] + }, + { + "cell_type": "markdown", + "id": "2e2bed54", + "metadata": {}, + "source": [ + "### Update audio\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2724e9d1", + "metadata": {}, + "outputs": [], + "source": [ + "from gradio_helper import update_audio_widget\n", + "\n", + "update_audio_type = update_audio_widget()\n", + "\n", + "update_audio_type" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c7f7c79", + "metadata": {}, + "outputs": [], + "source": [ + "from gradio_helper import setup_update_audio_widgets\n", + "\n", + "setup_update_audio_options, update_task = setup_update_audio_widgets(update_audio_type)\n", + "\n", + "setup_update_audio_options" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "477887d3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'prompt': 'country rock, folk rock, southern rock, bluegrass, country pop', 'lyrics': '[verse]\\nWoke up to the sunrise glow\\nTook my heart and hit the road[inst]', 'audio_duration': 15.0, 'infer_step': 25, 'use_erg_tag': False, 'use_erg_lyric': True, 'use_erg_diffusion': True, 'save_path': './outputs', 'task': 'repaint', 'src_audio_path': './outputs/output_20250818211825_0.wav', 'retake_variance': 0.2, 'repaint_start': 0.0, 'repaint_end': 5.000000000000001}\n" + ] + } + ], + "source": [ + "from gradio_helper import get_inputs_base_on_setup_widget\n", + "\n", + "inputs_update_audio = get_inputs_base_on_setup_widget(\n", + " base_inputs=inputs, 
source_audio_path=ov_out_audio_path, setup_widgets=setup_update_audio_options, task=update_task\n", + ")\n", + "print(inputs_update_audio)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e583c146", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-08-18 21:19:03.391\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36macestep.pipeline_ace_step\u001b[0m:\u001b[36m__call__\u001b[0m:\u001b[36m1488\u001b[0m - \u001b[1mModel loaded in 0.00 seconds.\u001b[0m\n", + "\u001b[32m2025-08-18 21:19:04.374\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mov_ace_helper\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m571\u001b[0m - \u001b[1mcfg_type: apg, guidance_scale: 15.0, omega_scale: 10.0\u001b[0m\n", + "\u001b[32m2025-08-18 21:19:04.376\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mov_ace_helper\u001b[0m:\u001b[36mtext2music_diffusion_process\u001b[0m:\u001b[36m796\u001b[0m - \u001b[1mstart_idx: 6, end_idx: 18, num_inference_steps: 25\u001b[0m\n", + " 0%| | 0/25 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ov_update_result = ov_pipeline(**inputs_update_audio)\n", + "\n", + "display(ipd.Audio(ov_update_result[0]))" + ] + }, + { + "cell_type": "markdown", + "id": "48cbba57", + "metadata": {}, + "source": [ + "## ACE Step with LoRA\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "\n", + "LoRA is a technique that allows to fine-tune large models with a small number of parameters. ACE Step support LoRA, more information about it can be find [here](https://github.com/ace-step/ACE-Step?tab=readme-ov-file#-applications).\n", + "\n", + "Let's try LoRA. To use LoRA for ACE Step and OpenVINO, LoRA should be applied for the model and model should be converted to IR format." + ] + }, + { + "cell_type": "markdown", + "id": "296f5a17", + "metadata": {}, + "source": [ + "### Load LoRA and apply for Transformer models for ACE Step pipeline\n", + "[back to top ⬆️](#Table-of-contents:)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "53af1865", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d60850be7409415fab974f25221e98d0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 4 files: 0%| | 0/4 [00:00\n", + " \n", + " Your browser does not support the audio element.\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import IPython.display as ipd\n", + "\n", + "inputs[\"prompt\"] = \"chopper rap, female, rap flow, groovy, singing, smooth Rhythm, synthesizer lead, heavy bassline\"\n", + "result = ov_pipeline(**inputs)\n", + "\n", + "orig_path = result[0]\n", + "display(ipd.Audio(result[0]))" + ] + }, + { + "cell_type": "markdown", + "id": "d4f17ee4", + "metadata": {}, + "source": [ + "### Deactivate LoRA\n", + "[back to top ⬆️](#Table-of-contents:)\n", + "\n", + "To unload LoRA and use original IRs we will run `load_lora` and provide `none` instead of path." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a2aac62e", + "metadata": {}, + "outputs": [], + "source": [ + "ov_pipeline.load_lora(\"none\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1c114a00-2529-4a43-833c-bffbabf44ba5", + "metadata": {}, + "source": [ + "## Interactive demo\n", + "[back to top ⬆️](#Table-of-contents:)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "add3ee7a-9395-4cdd-b89b-a3efe5df474b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from gradio_helper import make_demo\n", + "from acestep.data_sampler import DataSampler\n", + "\n", + "data_sampler = DataSampler()\n", + "\n", + "demo = make_demo(pipeline=ov_pipeline, data_sampler=data_sampler)\n", + "\n", + "try:\n", + " demo.queue().launch(debug=True, height=800)\n", + "except Exception:\n", + " demo.queue().launch(debug=True, share=True, height=800)\n", + "# If you are launching remotely, specify server_name and server_port\n", + "# EXAMPLE: `demo.launch(server_name='your server name', server_port='server port in int')`\n", + "# To learn more please refer to the Gradio docs: https://gradio.app/docs/" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ace", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "openvino_notebooks": { + "imageUrl": "https://raw.githubusercontent.com/ACE-Step/ACE-Step/main/assets/ACE-Step_framework.png", + "tags": { + "categories": [ + "Model Demos" + ], + "libraries": [], + "other": [ + "Stable Diffusion" + ], + "tasks": [ + "Audio Generation", + "Text-to-Audio", + "Audio-to-Audio" + ] + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/ace-step-music-generation/gradio_helper.py b/notebooks/ace-step-music-generation/gradio_helper.py new file mode 100644 index 00000000000..2129a39b233 --- /dev/null +++ b/notebooks/ace-step-music-generation/gradio_helper.py @@ -0,0 +1,851 @@ +import gradio as gr +import librosa +import copy +from acestep.ui.components import GENRE_PRESETS, TAG_DEFAULT, update_tags_from_preset, create_output_ui +import ipywidgets as widgets + + +LYRIC_DEFAULT = """[verse] +Neon lights they flicker bright +City hums in dead of night +Rhythms pulse through concrete veins +Lost in echoes of refrains +""" + + +def create_text2music_ui(gr, ov_pipeline, sample_data_func=None, lora_path=None): + with gr.Row(): + with gr.Column(): + with gr.Row(equal_height=True): + audio_duration = gr.Slider( + 0, + 90.0, + step=0.01, + value=10, + label="Audio Duration", + interactive=True, + info="The length of the generated audio in sec. 
Longer duration takes more time.", + scale=9, + ) + + with gr.Row(equal_height=True): + lora_enable = gr.Checkbox(label="Enable LoRA", value=False, info="Enable generation with Rap LoRA.", elem_id="lora_checkbox") + + def toggle_lora_in_pipeline(is_checked): + if is_checked: + ov_pipeline.load_lora(lora_path) + else: + ov_pipeline.load_lora("none") + + lora_enable.change(fn=toggle_lora_in_pipeline, inputs=[lora_enable]) + + with gr.Row(equal_height=True): + audio2audio_enable = gr.Checkbox( + label="Enable Audio2Audio", value=False, info="Enable Audio-to-Audio generation using a reference audio.", elem_id="audio2audio_checkbox" + ) + + ref_audio_input = gr.Audio( + type="filepath", label="Reference Audio (for Audio2Audio task)", visible=False, elem_id="ref_audio_input", show_download_button=True + ) + ref_audio_strength = gr.Slider( + label="Refer audio strength", + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.5, + elem_id="ref_audio_strength", + visible=False, + interactive=True, + ) + + def toggle_ref_audio_visibility(is_checked): + return ( + gr.update(visible=is_checked, elem_id="ref_audio_input"), + gr.update(visible=is_checked, elem_id="ref_audio_strength"), + ) + + audio2audio_enable.change( + fn=toggle_ref_audio_visibility, + inputs=[audio2audio_enable], + outputs=[ref_audio_input, ref_audio_strength], + ) + + with gr.Column(scale=2): + with gr.Group(): + gr.Markdown( + """
Supports tags, descriptions, and scenes. Use commas to separate different tags.
Tag and lyric examples are from the AI music generation community.
""" + ) + with gr.Row(): + genre_preset = gr.Dropdown( + choices=["Custom"] + list(GENRE_PRESETS.keys()), + value="Custom", + label="Preset", + scale=1, + ) + prompt = gr.Textbox( + lines=1, + label="Tags", + max_lines=4, + value=TAG_DEFAULT, + scale=9, + ) + + # Add the change event for the preset dropdown + genre_preset.change(fn=update_tags_from_preset, inputs=[genre_preset], outputs=[prompt]) + with gr.Group(): + gr.Markdown( + """
Supports lyric structure tags like [verse], [chorus], and [bridge] to separate different parts of the lyrics.
Use [instrumental] or [inst] to generate instrumental music. Genre structure tags are not supported in the lyrics.
""" + ) + lyrics = gr.Textbox( + lines=9, + label="Lyrics", + max_lines=13, + value=LYRIC_DEFAULT, + ) + + with gr.Accordion("Basic Settings", open=False): + infer_step = gr.Slider( + minimum=1, + maximum=200, + step=1, + value=20, + label="Infer Steps", + interactive=True, + ) + guidance_scale = gr.Slider( + minimum=0.0, + maximum=30.0, + step=0.1, + value=15.0, + label="Guidance Scale", + interactive=True, + info="When guidance_scale_lyric > 1 and guidance_scale_text > 1, the guidance scale will not be applied.", + ) + guidance_scale_text = gr.Slider( + minimum=0.0, + maximum=10.0, + step=0.1, + value=0.0, + label="Guidance Scale Text", + interactive=True, + info="Guidance scale for text condition. It can only apply to cfg. set guidance_scale_text=5.0, guidance_scale_lyric=1.5 for start", + ) + guidance_scale_lyric = gr.Slider( + minimum=0.0, + maximum=10.0, + step=0.1, + value=0.0, + label="Guidance Scale Lyric", + interactive=True, + ) + + manual_seeds = gr.Textbox( + label="manual seeds (default None)", + placeholder="1,2,3,4", + value=None, + info="Seed for the generation", + ) + + with gr.Accordion("Advanced Settings", open=False): + scheduler_type = gr.Radio( + ["euler", "heun"], + value="euler", + label="Scheduler Type", + elem_id="scheduler_type", + info="Scheduler type for the generation. euler is recommended. heun will take more time.", + ) + cfg_type = gr.Radio( + ["cfg", "apg", "cfg_star"], + value="apg", + label="CFG Type", + elem_id="cfg_type", + info="CFG type for the generation. apg is recommended. cfg and cfg_star are almost the same.", + ) + use_erg_lyric = gr.Checkbox( + label="use ERG for lyric", + value=True, + info="Use Entropy Rectifying Guidance for lyric. It will multiple a temperature to the attention to make a weaker lyric condition and make better diversity.", + ) + use_erg_diffusion = gr.Checkbox( + label="use ERG for diffusion", + value=True, + info="The same but apply to diffusion model's attention.", + ) + + omega_scale = gr.Slider( + minimum=-100.0, + maximum=100.0, + step=0.1, + value=10.0, + label="Granularity Scale", + interactive=True, + info="Granularity scale for the generation. Higher values can reduce artifacts", + ) + + guidance_interval = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.5, + label="Guidance Interval", + interactive=True, + info="Guidance interval for the generation. 0.5 means only apply guidance in the middle steps (0.25 * infer_steps to 0.75 * infer_steps)", + ) + guidance_interval_decay = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.0, + label="Guidance Interval Decay", + interactive=True, + info="Guidance interval decay for the generation. Guidance scale will decay from guidance_scale to min_guidance_scale in the interval. 
0.0 means no decay.", + ) + min_guidance_scale = gr.Slider( + minimum=0.0, + maximum=200.0, + step=0.1, + value=3.0, + label="Min Guidance Scale", + interactive=True, + info="Min guidance scale for guidance interval decay's end scale", + ) + oss_steps = gr.Textbox( + label="OSS Steps", + placeholder="16, 29, 52, 96, 129, 158, 172, 183, 189, 200", + value=None, + info="Optimal Steps for the generation.", + ) + + text2music_bnt = gr.Button("Generate", variant="primary") + + with gr.Column(): + outputs, input_params_json = create_output_ui() + with gr.Tab("retake"): + retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance") + retake_seeds = gr.Textbox(label="retake seeds (default None)", placeholder="", value=None) + retake_bnt = gr.Button("Retake", variant="primary") + retake_outputs, retake_input_params_json = create_output_ui("Retake") + + def retake_process_func(json_data, retake_variance, retake_seeds): + return ov_pipeline( + audio_duration=json_data["audio_duration"], + prompt=json_data["prompt"], + lyrics=json_data["lyrics"], + infer_step=json_data["infer_step"], + guidance_scale=json_data["guidance_scale"], + scheduler_type=json_data["scheduler_type"], + cfg_type=json_data["cfg_type"], + omega_scale=json_data["omega_scale"], + manual_seeds=json_data["actual_seeds"], + guidance_interval=json_data["guidance_interval"], + guidance_interval_decay=json_data["guidance_interval_decay"], + min_guidance_scale=json_data["min_guidance_scale"], + use_erg_tag=False, + use_erg_lyric=json_data["use_erg_lyric"], + use_erg_diffusion=json_data["use_erg_diffusion"], + oss_steps=", ".join(map(str, json_data["oss_steps"])), + guidance_scale_text=json_data["guidance_scale_text"] if "guidance_scale_text" in json_data else 0.0, + guidance_scale_lyric=json_data["guidance_scale_lyric"] if "guidance_scale_lyric" in json_data else 0.0, + retake_seeds=retake_seeds, + retake_variance=retake_variance, + task="retake", + ) + + retake_bnt.click( + fn=retake_process_func, + inputs=[ + input_params_json, + retake_variance, + retake_seeds, + ], + outputs=retake_outputs + [retake_input_params_json], + ) + with gr.Tab("repainting"): + retake_variance = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.2, label="variance") + retake_seeds = gr.Textbox(label="repaint seeds (default None)", placeholder="", value=None) + repaint_start = gr.Slider( + minimum=0.0, + maximum=90.0, + step=0.01, + value=0.0, + label="Repaint Start Time", + interactive=True, + ) + repaint_end = gr.Slider( + minimum=0.0, + maximum=90.0, + step=0.01, + value=10.0, + label="Repaint End Time", + interactive=True, + ) + repaint_source = gr.Radio( + ["text2music", "last_repaint", "upload"], + value="text2music", + label="Repaint Source", + elem_id="repaint_source", + ) + + repaint_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="repaint_source_audio_upload", + show_download_button=True, + ) + repaint_source.change( + fn=lambda x: gr.update(visible=x == "upload", elem_id="repaint_source_audio_upload"), + inputs=[repaint_source], + outputs=[repaint_source_audio_upload], + ) + + repaint_bnt = gr.Button("Repaint", variant="primary") + repaint_outputs, repaint_input_params_json = create_output_ui("Repaint") + + def repaint_process_func( + text2music_json_data, + repaint_json_data, + retake_variance, + retake_seeds, + repaint_start, + repaint_end, + repaint_source, + repaint_source_audio_upload, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + 
cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + ): + if repaint_source == "upload": + src_audio_path = repaint_source_audio_upload + audio_duration = librosa.get_duration(filename=src_audio_path) + json_data = {"audio_duration": audio_duration} + elif repaint_source == "text2music": + json_data = text2music_json_data + src_audio_path = json_data["audio_path"] + elif repaint_source == "last_repaint": + json_data = repaint_json_data + src_audio_path = json_data["audio_path"] + + return ov_pipeline( + audio_duration=json_data["audio_duration"], + prompt=prompt, + lyrics=lyrics, + infer_step=infer_step, + guidance_scale=guidance_scale, + scheduler_type=scheduler_type, + cfg_type=cfg_type, + omega_scale=omega_scale, + manual_seeds=manual_seeds, + guidance_interval=guidance_interval, + guidance_interval_decay=guidance_interval_decay, + min_guidance_scale=min_guidance_scale, + use_erg_tag=False, + use_erg_lyric=use_erg_lyric, + use_erg_diffusion=use_erg_diffusion, + oss_steps=oss_steps, + guidance_scale_text=guidance_scale_text, + guidance_scale_lyric=guidance_scale_lyric, + retake_seeds=retake_seeds, + retake_variance=retake_variance, + task="repaint", + repaint_start=repaint_start, + repaint_end=repaint_end, + src_audio_path=src_audio_path, + ) + + repaint_bnt.click( + fn=repaint_process_func, + inputs=[ + input_params_json, + repaint_input_params_json, + retake_variance, + retake_seeds, + repaint_start, + repaint_end, + repaint_source, + repaint_source_audio_upload, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + ], + outputs=repaint_outputs + [repaint_input_params_json], + ) + with gr.Tab("edit"): + edit_prompt = gr.Textbox(lines=2, label="Edit Tags", max_lines=4) + edit_lyrics = gr.Textbox(lines=9, label="Edit Lyrics", max_lines=13) + retake_seeds = gr.Textbox(label="edit seeds (default None)", placeholder="", value=None) + + edit_type = gr.Radio( + ["only_lyrics", "remix"], + value="only_lyrics", + label="Edit Type", + elem_id="edit_type", + info="`only_lyrics` will keep the whole song the same except lyrics difference. Make your diffrence smaller, e.g. 
one lyrc line change.\nremix can change the song melody and genre", + ) + edit_n_min = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=0.6, + label="edit_n_min", + interactive=True, + ) + edit_n_max = gr.Slider( + minimum=0.0, + maximum=1.0, + step=0.01, + value=1.0, + label="edit_n_max", + interactive=True, + ) + + def edit_type_change_func(edit_type): + if edit_type == "only_lyrics": + n_min = 0.6 + n_max = 1.0 + elif edit_type == "remix": + n_min = 0.2 + n_max = 0.4 + return n_min, n_max + + edit_type.change( + edit_type_change_func, + inputs=[edit_type], + outputs=[edit_n_min, edit_n_max], + ) + + edit_source = gr.Radio( + ["text2music", "last_edit", "upload"], + value="text2music", + label="Edit Source", + elem_id="edit_source", + ) + edit_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="edit_source_audio_upload", + show_download_button=True, + ) + edit_source.change( + fn=lambda x: gr.update(visible=x == "upload", elem_id="edit_source_audio_upload"), + inputs=[edit_source], + outputs=[edit_source_audio_upload], + ) + + edit_bnt = gr.Button("Edit", variant="primary") + edit_outputs, edit_input_params_json = create_output_ui("Edit") + + def edit_process_func( + text2music_json_data, + edit_input_params_json, + edit_source, + edit_source_audio_upload, + prompt, + lyrics, + edit_prompt, + edit_lyrics, + edit_n_min, + edit_n_max, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + retake_seeds, + ): + if edit_source == "upload": + src_audio_path = edit_source_audio_upload + audio_duration = librosa.get_duration(filename=src_audio_path) + json_data = {"audio_duration": audio_duration} + elif edit_source == "text2music": + json_data = text2music_json_data + src_audio_path = json_data["audio_path"] + elif edit_source == "last_edit": + json_data = edit_input_params_json + src_audio_path = json_data["audio_path"] + + if not edit_prompt: + edit_prompt = prompt + if not edit_lyrics: + edit_lyrics = lyrics + + return ov_pipeline( + audio_duration=json_data["audio_duration"], + prompt=prompt, + lyrics=lyrics, + infer_step=infer_step, + guidance_scale=guidance_scale, + scheduler_type=scheduler_type, + cfg_type=cfg_type, + omega_scale=omega_scale, + manual_seeds=manual_seeds, + guidance_interval=guidance_interval, + guidance_interval_decay=guidance_interval_decay, + min_guidance_scale=min_guidance_scale, + use_erg_tag=False, + use_erg_lyric=use_erg_lyric, + use_erg_diffusion=use_erg_diffusion, + oss_steps=oss_steps, + guidance_scale_text=guidance_scale_text, + guidance_scale_lyric=guidance_scale_lyric, + task="edit", + src_audio_path=src_audio_path, + edit_target_prompt=edit_prompt, + edit_target_lyrics=edit_lyrics, + edit_n_min=edit_n_min, + edit_n_max=edit_n_max, + retake_seeds=retake_seeds, + ) + + edit_bnt.click( + fn=edit_process_func, + inputs=[ + input_params_json, + edit_input_params_json, + edit_source, + edit_source_audio_upload, + prompt, + lyrics, + edit_prompt, + edit_lyrics, + edit_n_min, + edit_n_max, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + retake_seeds, + ], + outputs=edit_outputs + 
[edit_input_params_json], + ) + with gr.Tab("extend"): + extend_seeds = gr.Textbox(label="extend seeds (default None)", placeholder="", value=None) + left_extend_length = gr.Slider( + minimum=0.0, + maximum=20.0, + step=0.01, + value=0.0, + label="Left Extend Length", + interactive=True, + ) + right_extend_length = gr.Slider( + minimum=0.0, + maximum=20.0, + step=0.01, + value=5.0, + label="Right Extend Length", + interactive=True, + ) + extend_source = gr.Radio( + ["text2music", "last_extend", "upload"], + value="text2music", + label="Extend Source", + elem_id="extend_source", + ) + + extend_source_audio_upload = gr.Audio( + label="Upload Audio", + type="filepath", + visible=False, + elem_id="extend_source_audio_upload", + show_download_button=True, + ) + extend_source.change( + fn=lambda x: gr.update(visible=x == "upload", elem_id="extend_source_audio_upload"), + inputs=[extend_source], + outputs=[extend_source_audio_upload], + ) + + extend_bnt = gr.Button("Extend", variant="primary") + extend_outputs, extend_input_params_json = create_output_ui("Extend") + + def extend_process_func( + text2music_json_data, + extend_input_params_json, + extend_seeds, + left_extend_length, + right_extend_length, + extend_source, + extend_source_audio_upload, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + ): + if extend_source == "upload": + src_audio_path = extend_source_audio_upload + # get audio duration + audio_duration = librosa.get_duration(filename=src_audio_path) + json_data = {"audio_duration": audio_duration} + elif extend_source == "text2music": + json_data = text2music_json_data + src_audio_path = json_data["audio_path"] + elif extend_source == "last_extend": + json_data = extend_input_params_json + src_audio_path = json_data["audio_path"] + + repaint_start = -left_extend_length + repaint_end = json_data["audio_duration"] + right_extend_length + return ov_pipeline( + audio_duration=json_data["audio_duration"], + prompt=prompt, + lyrics=lyrics, + infer_step=infer_step, + guidance_scale=guidance_scale, + scheduler_type=scheduler_type, + cfg_type=cfg_type, + omega_scale=omega_scale, + manual_seeds=manual_seeds, + guidance_interval=guidance_interval, + guidance_interval_decay=guidance_interval_decay, + min_guidance_scale=min_guidance_scale, + use_erg_tag=False, + use_erg_lyric=use_erg_lyric, + use_erg_diffusion=use_erg_diffusion, + oss_steps=oss_steps, + guidance_scale_text=guidance_scale_text, + guidance_scale_lyric=guidance_scale_lyric, + retake_seeds=extend_seeds, + retake_variance=1.0, + task="extend", + repaint_start=repaint_start, + repaint_end=repaint_end, + src_audio_path=src_audio_path, + ) + + extend_bnt.click( + fn=extend_process_func, + inputs=[ + input_params_json, + extend_input_params_json, + extend_seeds, + left_extend_length, + right_extend_length, + extend_source, + extend_source_audio_upload, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + ], + outputs=extend_outputs + [extend_input_params_json], + ) + + def ov_pipeline_wrap( + audio_duration, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + 
cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + audio2audio_enable, + ref_audio_strength, + ref_audio_input, + ): + return ov_pipeline( + audio_duration=audio_duration, + prompt=prompt, + lyrics=lyrics, + infer_step=infer_step, + guidance_scale=guidance_scale, + scheduler_type=scheduler_type, + cfg_type=cfg_type, + omega_scale=omega_scale, + manual_seeds=manual_seeds, + guidance_interval=guidance_interval, + guidance_interval_decay=guidance_interval_decay, + min_guidance_scale=min_guidance_scale, + use_erg_tag=False, + use_erg_lyric=use_erg_lyric, + use_erg_diffusion=use_erg_diffusion, + oss_steps=oss_steps, + guidance_scale_text=guidance_scale_text, + guidance_scale_lyric=guidance_scale_lyric, + audio2audio_enable=audio2audio_enable, + ref_audio_strength=ref_audio_strength, + ref_audio_input=ref_audio_input, + ) + + text2music_bnt.click( + fn=ov_pipeline_wrap, + inputs=[ + audio_duration, + prompt, + lyrics, + infer_step, + guidance_scale, + scheduler_type, + cfg_type, + omega_scale, + manual_seeds, + guidance_interval, + guidance_interval_decay, + min_guidance_scale, + use_erg_lyric, + use_erg_diffusion, + oss_steps, + guidance_scale_text, + guidance_scale_lyric, + audio2audio_enable, + ref_audio_strength, + ref_audio_input, + ], + outputs=outputs + [input_params_json], + ) + + +def make_demo(pipeline, data_sampler): + with gr.Blocks( + title="ACE-Step Model with OpenVINO DEMO", + ) as demo: + gr.Markdown( + """ +

Music generation with ACE-Step model and OpenVINO

+ """ + ) + with gr.Tab("text2music"): + create_text2music_ui(gr=gr, ov_pipeline=pipeline, sample_data_func=data_sampler.sample) + return demo + + +def update_audio_widget(): + options = ["Retake", "Repainting", "Edit", "Extend"] + return widgets.ToggleButtons(options=options, description="Choose next operation with audio:", disabled=False, button_style="info", value="Repainting") + + +def setup_update_audio_widgets(update_audio_widget): + if update_audio_widget.value == "Retake": + task = "retake" + w = widgets.FloatSlider(value=0.2, min=0.0, max=1.0, step=0.01, description="Variance") + vbox = widgets.VBox([w]) + elif update_audio_widget.value == "Repainting": + task = "repaint" + w1 = widgets.FloatSlider(value=0.2, min=0.0, max=1.0, step=0.01, description="Variance") + w2 = widgets.FloatSlider(value=0, min=0, max=15, step=0.01, description="Repaint Start") + w3 = widgets.FloatSlider(value=4, min=0, max=15, step=0.01, description="Repaint End") + vbox = widgets.VBox([w1, w2, w3]) + elif update_audio_widget.value == "Edit": + task = "edit" + w1 = widgets.Textarea(value="", description="Edit lyric") + w2 = widgets.Textarea(value="classical, orchestral, strings, piano, 60 bpm, elegant, emotive, timeless, instrumental", description="Edit tags") + w3 = widgets.FloatSlider(value=0.2, min=0, max=1, step=0.01, description="edit_n_min") + w4 = widgets.FloatSlider(value=0.4, min=0, max=1, step=0.01, description="edit_n_max") + vbox = widgets.VBox([w1, w2, w3, w4]) + elif update_audio_widget.value == "Extend": + task = "extend" + w1 = widgets.FloatSlider(value=0.0, min=0, max=10, step=0.01, description="Left Extend Length") + w2 = widgets.FloatSlider(value=5.0, min=0, max=10, step=0.01, description="Right Extend Length") + vbox = widgets.VBox([w1, w2]) + + return vbox, task + + +def get_inputs_base_on_setup_widget(base_inputs, source_audio_path, setup_widgets, task): + extra_inputs = copy.deepcopy(base_inputs) + extra_inputs.update({"task": task, "src_audio_path": source_audio_path}) + if task == "retake": + extra_inputs["retake_variance"] = setup_widgets.children[0].value + del extra_inputs["src_audio_path"] + elif task == "repaint": + extra_inputs["retake_variance"] = setup_widgets.children[0].value + extra_inputs["repaint_start"] = setup_widgets.children[1].value + extra_inputs["repaint_end"] = setup_widgets.children[2].value + elif task == "edit": + extra_inputs["edit_target_lyrics"] = setup_widgets.children[0].value if setup_widgets.children[0].value else extra_inputs["lyrics"] + extra_inputs["edit_target_prompt"] = setup_widgets.children[1].value if setup_widgets.children[1].value else extra_inputs["prompt"] + extra_inputs["edit_n_min"] = setup_widgets.children[2].value + extra_inputs["edit_n_max"] = setup_widgets.children[3].value + elif task == "extend": + extra_inputs["repaint_start"] = -setup_widgets.children[0].value + extra_inputs["repaint_end"] = base_inputs["audio_duration"] + setup_widgets.children[1].value + + return extra_inputs + + +def get_model_compression_format_widgets(): + return widgets.Dropdown( + options=["FP16", "INT8", "INT4"], + value="FP16", + description="Model format:", + ) diff --git a/notebooks/ace-step-music-generation/ov_ace_helper.py b/notebooks/ace-step-music-generation/ov_ace_helper.py new file mode 100644 index 00000000000..c9aecba7144 --- /dev/null +++ b/notebooks/ace-step-music-generation/ov_ace_helper.py @@ -0,0 +1,978 @@ +import os +import gc +import math +import torch +import types +import torchaudio +import torchvision.transforms as transforms + +from 
tqdm import tqdm +from pathlib import Path +from loguru import logger +from diffusers.utils.torch_utils import randn_tensor +from typing import Dict, Optional, List, Union, Type +from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import retrieve_timesteps + +import nncf +import openvino as ov +from openvino.tools.ovc import convert_model +from openvino_tokenizers import convert_tokenizer +from openvino.frontend.pytorch.patch_model import __make_16bit_traceable + +from acestep.language_segmentation import LangSegment, language_filters +from acestep.models.lyrics_utils.lyric_tokenizer import VoiceBpeTokenizer + +from acestep.pipeline_ace_step import ACEStepPipeline +from acestep.models.ace_step_transformer import Transformer2DModelOutput +from acestep.music_dcae.music_dcae_pipeline import MusicDCAE +from acestep.schedulers.scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler +from acestep.schedulers.scheduling_flow_match_pingpong import FlowMatchPingPongScheduler +from acestep.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler +from acestep.apg_guidance import ( + apg_forward, + MomentumBuffer, + cfg_forward, + cfg_zero_star, + cfg_double_condition_forward, +) + +torch.set_float32_matmul_precision("high") +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +TOKENIZER_MODEL_NAME = "openvino_tokenizer.xml" +TEXT_ENCODER_MODEL_NAME = "ov_text_encoder_model.xml" +DCAE_ENCODER_MODEL_NAME = "ov_dcae_encoder_model.xml" +DCAE_DECODER_MODEL_NAME = "ov_dcae_decoder_model.xml" +VOCODER_DECODE_MODEL_NAME = "ov_vocoder_decode_model.xml" +VOCODER_MEL_TRANSFORM_MODEL_NAME = "ov_vocoder_mel_transform_model.xml" +TRANSFORMER_DECODER_MODEL_NAME = "ov_transformer_decoder_model.xml" +TRANSFORMER_ENCODER_MODEL_NAME = "ov_transformer_encoder_model.xml" + + +def cleanup_torchscript_cache(): + """ + Helper for removing cached model representation + """ + torch._C._jit_clear_class_registry() + torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + torch.jit._state._clear_class_state() + + +def ov_convert( + model_dir_path: str, + ov_model_name: str, + inputs: Dict, + orig_model: Type[torch.nn.Module], + model_name: str, + quantization_config: Dict = None, + force_convertion: bool = False, +): + try: + ov_model_path = Path(model_dir_path, ov_model_name) + if not ov_model_path.exists() or force_convertion: + print(f"⌛ Convert {model_name} model") + orig_model.eval() + __make_16bit_traceable(orig_model) + ov_model = convert_model(orig_model, example_input=inputs) + if quantization_config is not None: + print(f"⌛ Weights compression with {quantization_config['mode']} mode started") + ov_model = nncf.compress_weights(ov_model, **quantization_config) + print("✅ Weights compression finished") + ov.save_model(ov_model, ov_model_path) + + del ov_model + cleanup_torchscript_cache() + gc.collect() + print(f"✅ {model_name} model converted") + except Exception as e: + print(f"❌{model_name} model is not converted. 
Error: {e}") + + +def convert_transformer_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted", orig_checkpoint_path: str = "", quantization_config: Dict = None): + # Transformer Encoder model + def encode_with_temperature_wrap( + self, + encoder_text_hidden_states: torch.Tensor = None, + text_attention_mask: torch.LongTensor = None, + speaker_embeds: torch.FloatTensor = None, + lyric_token_idx: torch.LongTensor = None, + lyric_mask: torch.LongTensor = None, + tau: torch.FloatTensor = torch.Tensor([0.01]), + ): + handlers = [] + + def hook(module, input, output): + output[:] *= tau[0] + return output + + l_min = 4 + l_max = 6 + for i in range(l_min, l_max): + handler = self.lyric_encoder.encoders[i].self_attn.linear_q.register_forward_hook(hook) + handlers.append(handler) + + encoder_hidden_states, encoder_hidden_mask = self.encode( + encoder_text_hidden_states=encoder_text_hidden_states, + text_attention_mask=text_attention_mask, + speaker_embeds=speaker_embeds, + lyric_token_idx=lyric_token_idx, + lyric_mask=lyric_mask, + ) + + for hook in handlers: + hook.remove() + + return encoder_hidden_states, encoder_hidden_mask + + inputs = { + "encoder_text_hidden_states": torch.randn(size=(1, 15, 768), dtype=torch.float), + "text_attention_mask": torch.ones([1, 15], dtype=torch.int64), + "speaker_embeds": torch.zeros(size=(1, 512), dtype=torch.float), + "lyric_token_idx": torch.randint(10000, [1, 543], dtype=torch.int64), + "lyric_mask": torch.ones([1, 543], dtype=torch.int64), + "tau": torch.Tensor([0.01]), + } + transformer_encoder_model = pipeline.ace_step_transformer + transformer_encoder_erg_model = pipeline.ace_step_transformer + transformer_encoder_erg_model.forward = types.MethodType(encode_with_temperature_wrap, transformer_encoder_model) + ov_convert( + model_dir, + TRANSFORMER_ENCODER_MODEL_NAME, + inputs, + transformer_encoder_erg_model, + "Transformer Encoder with Entropy Rectifying Guidance", + quantization_config=quantization_config, + ) + + # Transformer Decoder model + def decode_with_temperature_wrap( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_mask: torch.Tensor, + timestep: torch.Tensor = None, + # ssl_hidden_states: List[torch.Tensor] = None, + output_length: int = 0, + # block_controlnet_hidden_states: Union[List[torch.Tensor], torch.Tensor] = None, + # controlnet_scale: Union[float, torch.Tensor] = 1.0, + tau: torch.FloatTensor = torch.Tensor([0.01]), + ): + handlers = [] + + def hook(module, input, output): + output[:] *= tau[0] + return output + + l_min = 5 + l_max = 10 + for i in range(l_min, l_max): + handler = self.transformer_blocks[i].attn.to_q.register_forward_hook(hook) + handlers.append(handler) + handler = self.transformer_blocks[i].cross_attn.to_q.register_forward_hook(hook) + handlers.append(handler) + + sample = self.decode( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_mask=encoder_hidden_mask, + output_length=output_length, + timestep=timestep, + ).sample + + for hook in handlers: + hook.remove() + + return sample + + inputs = { + "hidden_states": torch.randn(size=(1, 8, 16, 151), dtype=torch.float), + "attention_mask": torch.ones([1, 151], dtype=torch.int64), + "encoder_hidden_states": torch.randn(size=(1, 559, 2560), dtype=torch.float), + "encoder_hidden_mask": torch.ones([1, 559], dtype=torch.float), + "output_length": torch.tensor(151), + "timestep": torch.randn([1], 
dtype=torch.float), + "tau": torch.Tensor([0.01]), + } + transformer_decoder_erg_model = pipeline.ace_step_transformer + transformer_decoder_erg_model.forward = types.MethodType(decode_with_temperature_wrap, transformer_decoder_erg_model) + ov_convert( + model_dir, + TRANSFORMER_DECODER_MODEL_NAME, + inputs, + transformer_decoder_erg_model, + "Transformer Decoder with Entropy Rectifying Guidance", + quantization_config=quantization_config, + ) + + +def convert_models(pipeline: ACEStepPipeline, model_dir: str = "ov_converted_new", orig_checkpoint_path: str = "", quantization_config: Dict = None): + print(f"⌛ Conversion started. Be patient, it may takes some time.") + + if not pipeline.loaded or (orig_checkpoint_path and not Path(orig_checkpoint_path).exists()): + print("⌛ Load Original model checkpoints") + pipeline.load_checkpoint(orig_checkpoint_path) + print("✅ Original model checkpoints successfully loaded") + + # Tokenizer + ov_tokenizer_path = Path(model_dir, TOKENIZER_MODEL_NAME) + if not ov_tokenizer_path.exists(): + print(f"⌛ Convert Tokenizer") + if not ov_tokenizer_path.exists(): + ov_tokenizer = convert_tokenizer(pipeline.text_tokenizer, with_detokenizer=False) + ov.save_model(ov_tokenizer, Path(model_dir, TOKENIZER_MODEL_NAME)) + print(f"✅ Tokenizer is converted") + + # Text Encoder Model + inputs = { + "input_ids": torch.randint(1000, size=(1, 15), dtype=torch.int64), + "attention_mask": torch.ones([1, 15], dtype=torch.int64), + } + ov_convert(model_dir, TEXT_ENCODER_MODEL_NAME, inputs, pipeline.text_encoder_model, "UMT5 Encoder") + + # DCAE Encoder model + inputs = {"hidden_states": torch.randn([1, 2, 128, 1208], dtype=torch.float)} + ov_convert(model_dir, DCAE_ENCODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.encoder, "Sana's Deep Compression AutoEncoder") + + # DCAE Decoder model + inputs = {"hidden_states": torch.randn([1, 8, 16, 151], dtype=torch.float)} + ov_convert(model_dir, DCAE_DECODER_MODEL_NAME, inputs, pipeline.music_dcae.dcae.decoder, "Sana's Deep Compression AutoEncoder Decoder") + + # Vocoder Mel Transform model + inputs = {"x": torch.randn([2, 618496], dtype=torch.float)} + ov_convert(model_dir, VOCODER_MEL_TRANSFORM_MODEL_NAME, inputs, pipeline.music_dcae.vocoder.mel_transform, "Vocoder Mel Transform") + + # Vocoder Decoder model + inputs = {"mel": torch.randn([1, 128, 856], dtype=torch.float)} + ov_convert(model_dir, VOCODER_DECODE_MODEL_NAME, inputs, pipeline.music_dcae.vocoder, "Vocoder Decoder") + + # DiT + convert_transformer_models(pipeline, model_dir, orig_checkpoint_path, quantization_config) + + +class MusicDCAEWrapper(MusicDCAE): + def __init__(self, source_sample_rate=None): + torch.nn.Module.__init__(self) + self.dcae = None + self.vocoder = None + + if source_sample_rate is None: + source_sample_rate = 48000 + + self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100) + + self.transform = transforms.Compose( + [ + transforms.Normalize(0.5, 0.5), + ] + ) + self.min_mel_value = -11.0 + self.max_mel_value = 3.0 + self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000))) + self.mel_chunk_size = 1024 + self.time_dimention_multiple = 8 + self.latent_chunk_size = self.mel_chunk_size // self.time_dimention_multiple + self.scale_factor = 0.1786 + self.shift_factor = -1.9091 + + +class OVDCAECompiledModels(torch.nn.Module): + def __init__(self, compiled_model): + self.compiled_model = compiled_model + + def __call__(self, inputs): + if not self.compiled_model: + logger.error("OVDCAECompiledModels: compiled model is not defined") 
+ + output = self.compiled_model({"hidden_states": inputs.to(dtype=torch.float32)}) + return torch.from_numpy(output[0]) + + @classmethod + def from_pretrained(cls, ov_model_path, device, ov_core): + ov_dcae_model = ov_core.read_model(ov_model_path) + compiled_model = ov_core.compile_model(ov_dcae_model, device) + return cls(compiled_model) + + +class OVWrapperAutoencoderDC(torch.nn.Module): + def __init__(self, encoder, decoder): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + @classmethod + def from_pretrained(cls, ov_core, ov_models_path, device="CPU"): + encoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_ENCODER_MODEL_NAME), device, ov_core) + decoder = OVDCAECompiledModels.from_pretrained(Path(ov_models_path, DCAE_DECODER_MODEL_NAME), device, ov_core) + return cls(encoder, decoder) + + +class OVWrapperADaMoSHiFiGANV1(torch.nn.Module): + def __init__(self, encoder_compiled_model, mel_trnasform_compiled_model): + super().__init__() + self.decoder = encoder_compiled_model + self.mel_trnasform = mel_trnasform_compiled_model + + @classmethod + def from_pretrained(cls, ov_core, ov_models_path, device="CPU"): + ov_vocoder_decoder_model = ov_core.read_model(Path(ov_models_path, VOCODER_DECODE_MODEL_NAME)) + decoder = ov_core.compile_model(ov_vocoder_decoder_model, device) + ov_vocoder_mel_transform_model = ov_core.read_model(Path(ov_models_path, VOCODER_MEL_TRANSFORM_MODEL_NAME)) + mel_trnasform = ov_core.compile_model(ov_vocoder_mel_transform_model, device) + return cls(decoder, mel_trnasform) + + def decode(self, inputs): + output = self.decoder({"mel": inputs.to(dtype=torch.float32)}) + return torch.from_numpy(output[0]) + + def mel_transform(self, inputs): + output = self.mel_trnasform({"x": inputs.to(dtype=torch.float32)}) + return torch.from_numpy(output[0]) + + def forward(self, inputs): + return self.decode(inputs) + + +class OvWrapperACEStepTransformer2DModel(torch.nn.Module): + def __init__(self, encoder_model, decoder_model): + super().__init__() + self.ov_lyric_encoder_compiled = encoder_model + self.ov_decoder_compiled_model = decoder_model + + @classmethod + def from_pretrained(cls, ov_core, ov_models_path, device="CPU"): + ov_model_encoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_ENCODER_MODEL_NAME)) + compiled_model_encoder = ov_core.compile_model(ov_model_encoder, device) + + ov_model_decoder = ov_core.read_model(Path(ov_models_path, TRANSFORMER_DECODER_MODEL_NAME)) + compiled_model_decoder = ov_core.compile_model(ov_model_decoder, device) + return cls(compiled_model_encoder, compiled_model_decoder) + + def encode_with_temperature( + self, + encoder_text_hidden_states: Optional[torch.Tensor] = None, + text_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeds: Optional[torch.FloatTensor] = None, + lyric_token_idx: Optional[torch.LongTensor] = None, + lyric_mask: Optional[torch.LongTensor] = None, + tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]), + ): + output = None + if self.ov_lyric_encoder_compiled: + output = self.ov_lyric_encoder_compiled( + { + "encoder_text_hidden_states": encoder_text_hidden_states, + "text_attention_mask": text_attention_mask, + "speaker_embeds": speaker_embeds, + "lyric_token_idx": lyric_token_idx, + "lyric_mask": lyric_mask, + "tau": tau, + } + ) + return torch.from_numpy(output[0]), torch.from_numpy(output[1]) + + def decode_with_temperature( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor, + 
encoder_hidden_mask: torch.Tensor, + timestep: Optional[torch.Tensor], + ssl_hidden_states: Optional[List[torch.Tensor]] = None, + output_length: int = 0, + block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + controlnet_scale: Union[float, torch.Tensor] = 1.0, + return_dict: bool = True, + tau: Optional[torch.FloatTensor] = torch.Tensor([0.01]), + ): + output = None + if self.ov_decoder_compiled_model: + output = self.ov_decoder_compiled_model( + { + "hidden_states": hidden_states, + "attention_mask": attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_mask": encoder_hidden_mask, + "output_length": output_length, + "timestep": timestep, + "tau": tau, + } + ) + + sample = torch.from_numpy(output[0]) if output is not None else None + return sample + + def encode( + self, + encoder_text_hidden_states: Optional[torch.Tensor] = None, + text_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeds: Optional[torch.FloatTensor] = None, + lyric_token_idx: Optional[torch.LongTensor] = None, + lyric_mask: Optional[torch.LongTensor] = None, + ): + return self.encode_with_temperature( + encoder_text_hidden_states=encoder_text_hidden_states, + text_attention_mask=text_attention_mask, + speaker_embeds=speaker_embeds, + lyric_token_idx=lyric_token_idx, + lyric_mask=lyric_mask, + tau=torch.Tensor([1]), + ) + + def decode( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_mask: torch.Tensor, + timestep: Optional[torch.Tensor], + ssl_hidden_states: Optional[List[torch.Tensor]] = None, + output_length: int = 0, + block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + controlnet_scale: Union[float, torch.Tensor] = 1.0, + return_dict: bool = True, + ): + sample = self.decode_with_temperature( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_mask=encoder_hidden_mask, + timestep=timestep, + ssl_hidden_states=ssl_hidden_states, + output_length=output_length, + block_controlnet_hidden_states=block_controlnet_hidden_states, + controlnet_scale=controlnet_scale, + return_dict=return_dict, + tau=torch.Tensor([1]), + ) + + return Transformer2DModelOutput(sample, None) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + encoder_text_hidden_states: Optional[torch.Tensor] = None, + text_attention_mask: Optional[torch.LongTensor] = None, + speaker_embeds: Optional[torch.FloatTensor] = None, + lyric_token_idx: Optional[torch.LongTensor] = None, + lyric_mask: Optional[torch.LongTensor] = None, + timestep: Optional[torch.Tensor] = None, + ssl_hidden_states: Optional[List[torch.Tensor]] = None, + block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, + controlnet_scale: Union[float, torch.Tensor] = 1.0, + return_dict: bool = True, + ): + encoder_hidden_states, encoder_hidden_mask = self.encode( + encoder_text_hidden_states=encoder_text_hidden_states, + text_attention_mask=text_attention_mask, + speaker_embeds=speaker_embeds, + lyric_token_idx=lyric_token_idx, + lyric_mask=lyric_mask, + ) + + output_length = hidden_states.shape[-1] + + output = self.decode( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_mask=encoder_hidden_mask, + timestep=timestep, + ssl_hidden_states=ssl_hidden_states, + 
output_length=output_length, + block_controlnet_hidden_states=block_controlnet_hidden_states, + controlnet_scale=controlnet_scale, + return_dict=return_dict, + ) + + return output + + +class OVACEStepPipeline(ACEStepPipeline): + def __init__(self): + super().__init__(checkpoint_dir="", dtype="float32") + self.core = ov.Core() + + self.dcae_decoder = None + self.vocoder_encode = None + self.vocoder_decoder = None + self.transformer_encode = None + self.transformer_encode_with_temperature = None + self.transformer_decode = None + self.transformer_decode_with_temperature = None + + self.ace_step_transformer_origin = None + self.ace_step_transformer = None + self.music_dcae = None + self.text_tokenizer = None + self.text_encoder_model = None + + def get_checkpoint_path(self, checkpoint_dir, repo): + pass + + def load_checkpoint(self, checkpoint_dir=None, export_quantized_weights=False): + pass + + def load_models(self, ov_models_path: str = None, device: str = "CPU"): + self.loaded = True + if ov_models_path and Path(ov_models_path).exists: + ov_text_encoder_model = self.core.read_model(Path(ov_models_path, TEXT_ENCODER_MODEL_NAME)) + self.text_encoder_model = self.core.compile_model(ov_text_encoder_model, device) + + ov_text_tokenizer_path = self.core.read_model(Path(ov_models_path, TOKENIZER_MODEL_NAME)) + self.text_tokenizer = self.core.compile_model(ov_text_tokenizer_path, device) + + self.music_dcae = MusicDCAEWrapper() + self.music_dcae.dcae = OVWrapperAutoencoderDC.from_pretrained(self.core, ov_models_path, device) + self.music_dcae.vocoder = OVWrapperADaMoSHiFiGANV1.from_pretrained(self.core, ov_models_path, device) + + self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, ov_models_path, device) + else: + logger.error(f"Path is not exists: {ov_models_path}") + + lang_segment = LangSegment() + lang_segment.setfilters(language_filters.default) + self.lang_segment = lang_segment + self.lyric_tokenizer = VoiceBpeTokenizer() + + def load_quantized_checkpoint(self, checkpoint_dir=None): + pass + + def get_text_embeddings(self, texts, text_max_length=256): + inputs = self.text_tokenizer(texts) + inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]} + + last_hidden_states = self.text_encoder_model(inputs) + attention_mask = inputs["attention_mask"] + return torch.from_numpy(last_hidden_states[0]), torch.from_numpy(attention_mask) + + def get_text_embeddings_null(self, texts, text_max_length=256, tau=0.01, l_min=8, l_max=10): + inputs = self.text_tokenizer(texts) + inputs = {"attention_mask": inputs["attention_mask"], "input_ids": inputs["input_ids"]} + last_hidden_states = self.text_encoder_model(inputs) + return torch.from_numpy(last_hidden_states[0]) + + def text2music_diffusion_process( + self, + duration, + encoder_text_hidden_states, + text_attention_mask, + speaker_embds, + lyric_token_ids, + lyric_mask, + random_generators=None, + infer_steps=60, + guidance_scale=15.0, + omega_scale=10.0, + scheduler_type="euler", + cfg_type="apg", + zero_steps=1, + use_zero_init=True, + guidance_interval=0.5, + guidance_interval_decay=1.0, + min_guidance_scale=3.0, + oss_steps=[], + encoder_text_hidden_states_null=None, + use_erg_lyric=False, + use_erg_diffusion=False, + retake_random_generators=None, + retake_variance=0.5, + add_retake_noise=False, + guidance_scale_text=0.0, + guidance_scale_lyric=0.0, + repaint_start=0, + repaint_end=0, + src_latents=None, + audio2audio_enable=False, + ref_audio_strength=0.5, + ref_latents=None, + 
): + logger.info("cfg_type: {}, guidance_scale: {}, omega_scale: {}".format(cfg_type, guidance_scale, omega_scale)) + do_classifier_free_guidance = True + if guidance_scale == 0.0 or guidance_scale == 1.0: + do_classifier_free_guidance = False + + do_double_condition_guidance = False + if guidance_scale_text is not None and guidance_scale_text > 1.0 and guidance_scale_lyric is not None and guidance_scale_lyric > 1.0: + do_double_condition_guidance = True + logger.info( + "do_double_condition_guidance: {}, guidance_scale_text: {}, guidance_scale_lyric: {}".format( + do_double_condition_guidance, + guidance_scale_text, + guidance_scale_lyric, + ) + ) + + bsz = encoder_text_hidden_states.shape[0] + + if scheduler_type == "euler": + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=3.0, + ) + elif scheduler_type == "heun": + scheduler = FlowMatchHeunDiscreteScheduler( + num_train_timesteps=1000, + shift=3.0, + ) + elif scheduler_type == "pingpong": + scheduler = FlowMatchPingPongScheduler( + num_train_timesteps=1000, + shift=3.0, + ) + + frame_length = int(duration * 44100 / 512 / 8) + if src_latents is not None: + frame_length = src_latents.shape[-1] + + if ref_latents is not None: + frame_length = ref_latents.shape[-1] + + if len(oss_steps) > 0: + infer_steps = max(oss_steps) + scheduler.set_timesteps + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps=infer_steps, + device=self.device, + timesteps=None, + ) + new_timesteps = torch.zeros(len(oss_steps), dtype=self.dtype, device=self.device) + for idx in range(len(oss_steps)): + new_timesteps[idx] = timesteps[oss_steps[idx] - 1] + num_inference_steps = len(oss_steps) + sigmas = (new_timesteps / 1000).float().cpu().numpy() + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps=num_inference_steps, + device=self.device, + sigmas=sigmas, + ) + logger.info(f"oss_steps: {oss_steps}, num_inference_steps: {num_inference_steps} after remapping to timesteps {timesteps}") + else: + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, + num_inference_steps=infer_steps, + device=self.device, + timesteps=None, + ) + + target_latents = randn_tensor( + shape=(bsz, 8, 16, frame_length), + generator=random_generators, + device=self.device, + dtype=self.dtype, + ) + + is_repaint = False + is_extend = False + + if add_retake_noise: + n_min = int(infer_steps * (1 - retake_variance)) + retake_variance = torch.tensor(retake_variance * math.pi / 2).to(self.device).to(self.dtype) + retake_latents = randn_tensor( + shape=(bsz, 8, 16, frame_length), + generator=retake_random_generators, + device=self.device, + dtype=self.dtype, + ) + repaint_start_frame = int(repaint_start * 44100 / 512 / 8) + repaint_end_frame = int(repaint_end * 44100 / 512 / 8) + x0 = src_latents + # retake + is_repaint = repaint_end_frame - repaint_start_frame != frame_length + + is_extend = (repaint_start_frame < 0) or (repaint_end_frame > frame_length) + if is_extend: + is_repaint = True + + # TODO: train a mask aware repainting controlnet + # to make sure mean = 0, std = 1 + if not is_repaint: + target_latents = torch.cos(retake_variance) * target_latents + torch.sin(retake_variance) * retake_latents + elif not is_extend: + # if repaint_end_frame + repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) + repaint_mask[:, :, :, repaint_start_frame:repaint_end_frame] = 1.0 + repaint_noise = torch.cos(retake_variance) * target_latents + 
torch.sin(retake_variance) * retake_latents + repaint_noise = torch.where(repaint_mask == 1.0, repaint_noise, target_latents) + zt_edit = x0.clone() + z0 = repaint_noise + elif is_extend: + to_right_pad_gt_latents = None + to_left_pad_gt_latents = None + gt_latents = src_latents + src_latents_length = gt_latents.shape[-1] + max_infer_fame_length = int(240 * 44100 / 512 / 8) + left_pad_frame_length = 0 + right_pad_frame_length = 0 + right_trim_length = 0 + left_trim_length = 0 + if repaint_start_frame < 0: + left_pad_frame_length = abs(repaint_start_frame) + frame_length = left_pad_frame_length + gt_latents.shape[-1] + extend_gt_latents = torch.nn.functional.pad(gt_latents, (left_pad_frame_length, 0), "constant", 0) + if frame_length > max_infer_fame_length: + right_trim_length = frame_length - max_infer_fame_length + extend_gt_latents = extend_gt_latents[:, :, :, :max_infer_fame_length] + to_right_pad_gt_latents = extend_gt_latents[:, :, :, -right_trim_length:] + frame_length = max_infer_fame_length + repaint_start_frame = 0 + gt_latents = extend_gt_latents + + if repaint_end_frame > src_latents_length: + right_pad_frame_length = repaint_end_frame - gt_latents.shape[-1] + frame_length = gt_latents.shape[-1] + right_pad_frame_length + extend_gt_latents = torch.nn.functional.pad(gt_latents, (0, right_pad_frame_length), "constant", 0) + if frame_length > max_infer_fame_length: + left_trim_length = frame_length - max_infer_fame_length + extend_gt_latents = extend_gt_latents[:, :, :, -max_infer_fame_length:] + to_left_pad_gt_latents = extend_gt_latents[:, :, :, :left_trim_length] + frame_length = max_infer_fame_length + repaint_end_frame = frame_length + gt_latents = extend_gt_latents + + repaint_mask = torch.zeros((bsz, 8, 16, frame_length), device=self.device, dtype=self.dtype) + if left_pad_frame_length > 0: + repaint_mask[:, :, :, :left_pad_frame_length] = 1.0 + if right_pad_frame_length > 0: + repaint_mask[:, :, :, -right_pad_frame_length:] = 1.0 + x0 = gt_latents + padd_list = [] + if left_pad_frame_length > 0: + padd_list.append(retake_latents[:, :, :, :left_pad_frame_length]) + padd_list.append( + target_latents[ + :, + :, + :, + left_trim_length : target_latents.shape[-1] - right_trim_length, + ] + ) + if right_pad_frame_length > 0: + padd_list.append(retake_latents[:, :, :, -right_pad_frame_length:]) + target_latents = torch.cat(padd_list, dim=-1) + assert target_latents.shape[-1] == x0.shape[-1], f"{target_latents.shape=} {x0.shape=}" + zt_edit = x0.clone() + z0 = target_latents + + if audio2audio_enable and ref_latents is not None: + logger.info(f"audio2audio_enable: {audio2audio_enable}, ref_latents: {ref_latents.shape}") + target_latents, timesteps, scheduler, num_inference_steps = self.add_latents_noise( + gt_latents=ref_latents, + sigma_max=(1 - ref_audio_strength), + noise=target_latents, + scheduler_type=scheduler_type, + infer_steps=infer_steps, + ) + + attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype) + + # guidance interval + start_idx = int(num_inference_steps * ((1 - guidance_interval) / 2)) + end_idx = int(num_inference_steps * (guidance_interval / 2 + 0.5)) + logger.info(f"start_idx: {start_idx}, end_idx: {end_idx}, num_inference_steps: {num_inference_steps}") + + momentum_buffer = MomentumBuffer() + + # P(speaker, text, lyric) + encoder_hidden_states, encoder_hidden_mask = self.ace_step_transformer.encode( + encoder_text_hidden_states, + text_attention_mask, + speaker_embds, + lyric_token_ids, + lyric_mask, + ) + + if use_erg_lyric: + 
# P(null_speaker, text_weaker, lyric_weaker) + encoder_hidden_states_null, _ = self.ace_step_transformer.encode_with_temperature( + encoder_text_hidden_states=( + encoder_text_hidden_states_null if encoder_text_hidden_states_null is not None else torch.zeros_like(encoder_text_hidden_states) + ), + text_attention_mask=text_attention_mask, + speaker_embeds=torch.zeros_like(speaker_embds), + lyric_token_idx=lyric_token_ids, + lyric_mask=lyric_mask, + ) + else: + # P(null_speaker, null_text, null_lyric) + encoder_hidden_states_null, _ = self.ace_step_transformer.encode( + torch.zeros_like(encoder_text_hidden_states), + text_attention_mask, + torch.zeros_like(speaker_embds), + torch.zeros_like(lyric_token_ids), + lyric_mask, + ) + + encoder_hidden_states_no_lyric = None + if do_double_condition_guidance: + # P(null_speaker, text, lyric_weaker) + if use_erg_lyric: + encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode_with_temperature( + encoder_text_hidden_states=encoder_text_hidden_states, + text_attention_mask=text_attention_mask, + speaker_embeds=torch.zeros_like(speaker_embds), + lyric_token_idx=lyric_token_ids, + lyric_mask=lyric_mask, + ) + # P(null_speaker, text, no_lyric) + else: + encoder_hidden_states_no_lyric, _ = self.ace_step_transformer.encode( + encoder_text_hidden_states, + text_attention_mask, + torch.zeros_like(speaker_embds), + torch.zeros_like(lyric_token_ids), + lyric_mask, + ) + + for i, t in tqdm(enumerate(timesteps), total=num_inference_steps): + if is_repaint: + if i < n_min: + continue + elif i == n_min: + t_i = t / 1000 + zt_src = (1 - t_i) * x0 + (t_i) * z0 + target_latents = zt_edit + zt_src - x0 + logger.info(f"repaint start from {n_min} add {t_i} level of noise") + + # expand the latents if we are doing classifier free guidance + latents = target_latents + + is_in_guidance_interval = start_idx <= i < end_idx + if is_in_guidance_interval and do_classifier_free_guidance: + # compute current guidance scale + if guidance_interval_decay > 0: + # Linearly interpolate to calculate the current guidance scale + progress = (i - start_idx) / (end_idx - start_idx - 1) # 归一化到[0,1] + current_guidance_scale = guidance_scale - (guidance_scale - min_guidance_scale) * progress * guidance_interval_decay + else: + current_guidance_scale = guidance_scale + + latent_model_input = latents + timestep = t.expand(latent_model_input.shape[0]) + output_length = latent_model_input.shape[-1] + # P(x|speaker, text, lyric) + noise_pred_with_cond = self.ace_step_transformer.decode( + hidden_states=latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_mask=encoder_hidden_mask, + output_length=output_length, + timestep=timestep, + ).sample + + noise_pred_with_only_text_cond = None + if do_double_condition_guidance and encoder_hidden_states_no_lyric is not None: + noise_pred_with_only_text_cond = self.ace_step_transformer.decode( + hidden_states=latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states_no_lyric, + encoder_hidden_mask=encoder_hidden_mask, + output_length=output_length, + timestep=timestep, + ).sample + + if use_erg_diffusion: + noise_pred_uncond = self.ace_step_transformer.decode_with_temperature( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states_null, + encoder_hidden_mask=encoder_hidden_mask, + output_length=output_length, + attention_mask=attention_mask, + ) + else: + noise_pred_uncond = 
self.ace_step_transformer.decode( + hidden_states=latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states_null, + encoder_hidden_mask=encoder_hidden_mask, + output_length=output_length, + timestep=timestep, + ).sample + + if do_double_condition_guidance and noise_pred_with_only_text_cond is not None: + noise_pred = cfg_double_condition_forward( + cond_output=noise_pred_with_cond, + uncond_output=noise_pred_uncond, + only_text_cond_output=noise_pred_with_only_text_cond, + guidance_scale_text=guidance_scale_text, + guidance_scale_lyric=guidance_scale_lyric, + ) + + elif cfg_type == "apg": + noise_pred = apg_forward( + pred_cond=noise_pred_with_cond, + pred_uncond=noise_pred_uncond, + guidance_scale=current_guidance_scale, + momentum_buffer=momentum_buffer, + ) + elif cfg_type == "cfg": + noise_pred = cfg_forward( + cond_output=noise_pred_with_cond, + uncond_output=noise_pred_uncond, + cfg_strength=current_guidance_scale, + ) + elif cfg_type == "cfg_star": + noise_pred = cfg_zero_star( + noise_pred_with_cond=noise_pred_with_cond, + noise_pred_uncond=noise_pred_uncond, + guidance_scale=current_guidance_scale, + i=i, + zero_steps=zero_steps, + use_zero_init=use_zero_init, + ) + else: + latent_model_input = latents + timestep = t.expand(latent_model_input.shape[0]) + noise_pred = self.ace_step_transformer.decode( + hidden_states=latent_model_input, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_mask=encoder_hidden_mask, + output_length=latent_model_input.shape[-1], + timestep=timestep, + ).sample + + if is_repaint and i >= n_min: + t_i = t / 1000 + if i + 1 < len(timesteps): + t_im1 = (timesteps[i + 1]) / 1000 + else: + t_im1 = torch.zeros_like(t_i).to(self.device) + target_latents = target_latents.to(torch.float32) + prev_sample = target_latents + (t_im1 - t_i) * noise_pred + prev_sample = prev_sample.to(self.dtype) + target_latents = prev_sample + zt_src = (1 - t_im1) * x0 + (t_im1) * z0 + target_latents = torch.where(repaint_mask == 1.0, target_latents, zt_src) + else: + target_latents = scheduler.step( + model_output=noise_pred, + timestep=t, + sample=target_latents, + return_dict=False, + omega=omega_scale, + generator=random_generators[0], + )[0] + + if is_extend: + if to_right_pad_gt_latents is not None: + target_latents = torch.cat([target_latents, to_right_pad_gt_latents], dim=-1) + if to_left_pad_gt_latents is not None: + target_latents = torch.cat([to_right_pad_gt_latents, target_latents], dim=0) + return target_latents + + def load_lora(self, model_with_lora_path, device="CPU"): + if model_with_lora_path == "none": + if self.ace_step_transformer_origin: + self.ace_step_transformer = self.ace_step_transformer_origin + else: + self.ace_step_transformer_origin = self.ace_step_transformer + self.update_transformer_model(model_with_lora_path, device) + + def update_transformer_model(self, new_transformer_path, device="CPU"): + self.ace_step_transformer = OvWrapperACEStepTransformer2DModel.from_pretrained(self.core, new_transformer_path, device)
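
For readers who want to exercise the helpers above outside the Gradio demo, here is a minimal usage sketch; it is not part of this diff. It assumes the original ACE-Step checkpoints are available locally, that NNCF INT8 weight compression is an acceptable choice for `quantization_config`, and that `ACEStepPipeline.__call__` (inherited by `OVACEStepPipeline`) accepts the generation keyword arguments collected by the widget helpers (`prompt`, `lyrics`, `audio_duration`). Paths, prompt text, and lyrics are placeholders.

```python
# Illustrative sketch only -- not part of the diff above.
import nncf
from acestep.pipeline_ace_step import ACEStepPipeline
from ov_ace_helper import convert_models, OVACEStepPipeline

MODEL_DIR = "ov_converted"            # directory that will hold the OpenVINO IR files
CHECKPOINT_DIR = "ACE-Step-v1-3.5B"   # hypothetical path to the original checkpoints

# 1. One-time conversion of every sub-model (tokenizer, UMT5 text encoder, DCAE
#    encoder/decoder, vocoder and the DiT encoder/decoder) to OpenVINO IR, with
#    optional NNCF weight compression applied to the transformer parts.
torch_pipeline = ACEStepPipeline(checkpoint_dir=CHECKPOINT_DIR, dtype="float32")
convert_models(
    torch_pipeline,
    model_dir=MODEL_DIR,
    quantization_config={"mode": nncf.CompressWeightsMode.INT8_ASYM},
)

# 2. Load the converted models on the target device and generate audio.
ov_pipeline = OVACEStepPipeline()
ov_pipeline.load_models(ov_models_path=MODEL_DIR, device="CPU")

inputs = {
    "prompt": "classical, orchestral, strings, piano, 60 bpm, elegant, instrumental",
    "lyrics": "[verse] placeholder lyrics",   # placeholder text
    "audio_duration": 60,
}
result = ov_pipeline(**inputs)  # keyword set assumed to follow ACEStepPipeline.__call__

# 3. Optionally switch to a LoRA-adapted transformer whose encoder/decoder were
#    converted to the same IR layout; passing "none" restores the original models.
ov_pipeline.load_lora("path/to/converted_lora_ir", device="CPU")
```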