From 6330b6cfd63f854b1fbac297e41e7316161374b5 Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sat, 3 Dec 2022 20:16:06 -0800 Subject: [PATCH 01/15] new istft vocoder --- uberduck_ml_dev/vocoders/istftnet.py | 633 +++++++++++++++++++++++++++ 1 file changed, 633 insertions(+) create mode 100644 uberduck_ml_dev/vocoders/istftnet.py diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py new file mode 100644 index 00000000..a934d5ee --- /dev/null +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -0,0 +1,633 @@ +__all__ = [ + "iSTFTNetGenerator", + "ResBlock1", + "ResBlock2", + "Generator", + "DiscriminatorP", + "MultiPeriodDiscriminator", + "DiscriminatorS", + "MultiScaleDiscriminator", + "feature_loss", + "discriminator_loss", + "generator_loss", + "LRELU_SLOPE", + "AttrDict", + "build_env", + "init_weights", + "apply_weight_norm", + "get_padding", +] + + + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from utils import init_weights, get_padding +import numpy as np +from torch.autograd import Variable +from scipy.signal import get_window +from librosa.util import pad_center, tiny +import librosa.util as librosa_util +import glob +import os +import shutil +import matplotlib +import torch +from torch.nn.utils import weight_norm +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +LRELU_SLOPE = 0.1 + + +""" +BSD 3-Clause License +Copyright (c) 2017, Prem Seetharaman +All rights reserved. +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. 
+ Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + n_frames : int > 0 + The number of analysis frames + hop_length : int > 0 + The number of samples to advance between frames + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + n_fft : int > 0 + The length of each analysis frame. + dtype : np.dtype + The data type of the output + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])]) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + + if window is not None: + assert(filter_length >= win_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + 
stride=self.hop_length,
+            padding=0)
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32)
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0])
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False)
+            window_sum = window_sum.to(inverse_transform.device) if magnitude.is_cuda else window_sum
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+
+class TorchSTFT(torch.nn.Module):
+    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
+        super().__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
+
+    def transform(self, input_data):
+        forward_transform = torch.stft(
+            input_data,
+            self.filter_length, self.hop_length, self.win_length, window=self.window,
+            return_complex=True)
+
+        return torch.abs(forward_transform), torch.angle(forward_transform)
+
+    def inverse(self, magnitude, phase):
+        inverse_transform = torch.istft(
+            magnitude * torch.exp(phase * 1j),
+            self.filter_length, self.hop_length, self.win_length, window=self.window)
+
+        return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+
+class iSTFTNetGenerator(nn.Module):
+    def __init__(self, config, checkpoint, cudnn_enabled=False):
+        super().__init__()
+        self.config = config
+        self.checkpoint = checkpoint
+        self.stft =
+        self.device = "cuda" if torch.cuda.is_available() and cudnn_enabled else "cpu"
+        self.vocoder, self.stft = self.load_checkpoint().eval()
+        self.vocoder.remove_weight_norm()
+
+    @torch.no_grad()
+    def load_checkpoint(self):
+        h = self.load_config()
+        vocoder = Generator(h)
+        stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
+        vocoder.load_state_dict(
+            torch.load(
+                self.checkpoint,
+                map_location="cuda" if self.device == "cuda" else "cpu",
+            )["generator"]
+        )
+        if self.device == "cuda":
+            vocoder = vocoder.cuda()
+        return vocoder, stft
+
+    @torch.no_grad()
+    def load_config(self):
+        with open(self.config) as f:
+            h = AttrDict(json.load(f))
+        return h
+
+    def forward(self, mel, max_wav_value=32768):
+        return self.infer(mel, max_wav_value=max_wav_value)
+
+    @torch.no_grad()
+    def infer(self, mel, max_wav_value=32768):
+        spec, phase = self.vocoder.generator(x)
+        y_g_hat = self.stft.inverse(spec, phase)
+        audio = (
+            y_g_hat.cpu().squeeze().clamp(-1, 1).numpy()
+            * max_wav_value
+        ).astype(np.int16)
+        return audio
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, 
self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, 
self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.post_n_fft = h.gen_istft_n_fft + self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = torch.nn.ReflectionPad1d((1, 0)) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.reflection_pad(x) + x = self.conv_post(x) + spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :]) + phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :]) + + return spec, phase + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + 
fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i-1](y) + y_hat = self.meanpools[i-1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss*2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1-dr)**2) + g_loss = torch.mean(dg**2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1-dg)**2) + gen_losses.append(l) + loss += l + + return loss, gen_losses From 2f7d72b9e097d6718a3bae4ee5a77539db84b3c9 Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sat, 3 Dec 2022 20:16:29 -0800 Subject: [PATCH 02/15] credit --- uberduck_ml_dev/vocoders/istftnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index a934d5ee..cd320b01 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -18,7 +18,7 @@ "get_padding", ] - +""" from https://github.com/rishikksh20/iSTFTNet-pytorch """ import torch import torch.nn.functional as F From 438620a2eb9188361f46af932fe12e73f326402b Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sat, 3 Dec 2022 21:43:28 -0800 Subject: [PATCH 03/15] fix device --- uberduck_ml_dev/vocoders/istftnet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index cd320b01..25191e1d 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ 
b/uberduck_ml_dev/vocoders/istftnet.py @@ -219,7 +219,7 @@ def forward(self, input_data): class TorchSTFT(torch.nn.Module): - def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'): + def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann', device="cpu"): super().__init__() self.filter_length = filter_length self.hop_length = hop_length @@ -237,7 +237,7 @@ def transform(self, input_data): def inverse(self, magnitude, phase): inverse_transform = torch.istft( magnitude * torch.exp(phase * 1j), - self.filter_length, self.hop_length, self.win_length, window=self.window) + self.filter_length, self.hop_length, self.win_length, window=self.window.to(self.device)) return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation @@ -252,7 +252,6 @@ def __init__(self, config, checkpoint, cudnn_enabled=False): super().__init__() self.config = config self.checkpoint = checkpoint - self.stft = self.device = "cuda" if torch.cuda.is_available() and cudnn_enabled else "cpu" self.vocoder, self.stft = self.load_checkpoint().eval() self.vocoder.remove_weight_norm() @@ -261,7 +260,7 @@ def __init__(self, config, checkpoint, cudnn_enabled=False): def load_checkpoint(self): h = self.load_config() vocoder = Generator(h) - stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device) + stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft, device=device).to(device) vocoder.load_state_dict( torch.load( self.checkpoint, From f6ac618be9dff8ee44115c9e4178b5fccdba92e5 Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sun, 4 Dec 2022 23:56:24 -0800 Subject: [PATCH 04/15] Update istftnet.py --- uberduck_ml_dev/vocoders/istftnet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index 25191e1d..bf6d1db9 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -25,7 +25,6 @@ import torch.nn as nn from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm -from utils import init_weights, get_padding import numpy as np from torch.autograd import Variable from scipy.signal import get_window From 52c682d3d663edf623709429593ee1435911a31e Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sun, 4 Dec 2022 23:57:43 -0800 Subject: [PATCH 05/15] Update istftnet.py --- uberduck_ml_dev/vocoders/istftnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index bf6d1db9..eabcf8da 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -38,6 +38,7 @@ from torch.nn.utils import weight_norm matplotlib.use("Agg") import matplotlib.pylab as plt +import json LRELU_SLOPE = 0.1 From f9331704cec2d121e89f5bce1e4298dbdfe0ec7b Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Sun, 4 Dec 2022 23:58:53 -0800 Subject: [PATCH 06/15] Update istftnet.py --- uberduck_ml_dev/vocoders/istftnet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index eabcf8da..ba2e19ae 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -253,14 +253,14 @@ def __init__(self, config, checkpoint, 
cudnn_enabled=False): self.config = config self.checkpoint = checkpoint self.device = "cuda" if torch.cuda.is_available() and cudnn_enabled else "cpu" - self.vocoder, self.stft = self.load_checkpoint().eval() + self.vocoder, self.stft = self.load_checkpoint() self.vocoder.remove_weight_norm() @torch.no_grad() def load_checkpoint(self): h = self.load_config() vocoder = Generator(h) - stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft, device=device).to(device) + stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft, device=self.device).to(self.device) vocoder.load_state_dict( torch.load( self.checkpoint, @@ -269,7 +269,7 @@ def load_checkpoint(self): ) if self.device == "cuda": vocoder = vocoder.cuda() - return vocoder, stft + return vocoder.eval(), stft @torch.no_grad() def load_config(self): From e1d83b7ae9437eb5cbb45e590af77a3904f93479 Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 00:02:32 -0800 Subject: [PATCH 07/15] Update istftnet.py --- uberduck_ml_dev/vocoders/istftnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index ba2e19ae..3e669dba 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -282,7 +282,7 @@ def forward(self, mel, max_wav_value=32768): @torch.no_grad() def infer(self, mel, max_wav_value=32768): - spec, phase = self.vocoder.generator(x) + spec, phase = self.vocoder(x) y_g_hat = self.stft.inverse(spec, phase) audio = ( y_g_hat.cpu().squeeze().clamp(-1, 1).numpy() From 930dc7ab285effb26fb5f258c8030d18646829cf Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 00:04:32 -0800 Subject: [PATCH 08/15] pls work --- uberduck_ml_dev/vocoders/istftnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index 3e669dba..06a8f5cd 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -282,7 +282,7 @@ def forward(self, mel, max_wav_value=32768): @torch.no_grad() def infer(self, mel, max_wav_value=32768): - spec, phase = self.vocoder(x) + spec, phase = self.vocoder(mel) y_g_hat = self.stft.inverse(spec, phase) audio = ( y_g_hat.cpu().squeeze().clamp(-1, 1).numpy() From 4bc5a65a9bfd5c235ddb262691c8429ebffaa606 Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 00:06:45 -0800 Subject: [PATCH 09/15] Update istftnet.py --- uberduck_ml_dev/vocoders/istftnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/uberduck_ml_dev/vocoders/istftnet.py b/uberduck_ml_dev/vocoders/istftnet.py index 06a8f5cd..af1041af 100644 --- a/uberduck_ml_dev/vocoders/istftnet.py +++ b/uberduck_ml_dev/vocoders/istftnet.py @@ -225,6 +225,7 @@ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='ha self.hop_length = hop_length self.win_length = win_length self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32)) + self.device = device def transform(self, input_data): forward_transform = torch.stft( From c345aec0f76313ef52bcf0c8be392f59f98a0a3e Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 00:18:25 -0800 Subject: [PATCH 10/15] Update denoiser.py --- uberduck_ml_dev/utils/denoiser.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git 
a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py
index 09939664..f31365fb 100644
--- a/uberduck_ml_dev/utils/denoiser.py
+++ b/uberduck_ml_dev/utils/denoiser.py
@@ -1,15 +1,12 @@
 """
-Removes bias from HiFi-Gan and Avocodo (typically heard as noise in the audio)
+Removes bias from vocoders (typically heard as metallic noise in the audio)
 Usage:
 from denoiser import Denoiser
-denoiser = Denoiser(HIFIGANGENERATOR, mode="normal") # Experiment with modes "normal" and "zeros"
+denoiser = Denoiser(VOCODERGENERATOR, mode="normal") # Experiment with modes "normal" and "zeros"

 # Inference Vocoder
-audio = hifigan.vocoder.forward(output[1][:1])
-
-audio = audio.squeeze()
-audio = audio * 32768.0
+audio = VOCODERGENERATOR.vocoder.forward(output[1][:1])

 # Denoise
 audio_denoised = denoiser(audio.view(1, -1), strength=15)[:, 0] # Change strength if needed
@@ -22,6 +19,7 @@
 import sys
 import torch
 from ..models.common import STFT
+from ..vocoders.istftnet import iSTFTNetGenerator


 class Denoiser(torch.nn.Module):
@@ -46,11 +44,18 @@ def __init__(
             raise Exception("Mode {} if not supported".format(mode))

         with torch.no_grad():
-            bias_audio = (
-                hifigan.vocoder.forward(mel_input.to(hifigan.device))
-                .view(1, -1)
-                .float()
-            )
+            if isinstance(hifigan, iSTFTNetGenerator):
+                bias_audio = (
+                    hifigan(mel_input.to(hifigan.device))
+                    .view(1, -1)
+                    .float()
+                )
+            else:
+                bias_audio = (
+                    hifigan.vocoder.forward(mel_input.to(hifigan.device))
+                    .view(1, -1)
+                    .float()
+                )
             bias_spec, _ = self.stft.transform(bias_audio.cpu())
             self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])

From 53107b05911883c480ceb7b8fea53ea8ef5d458f Mon Sep 17 00:00:00 2001
From: johnpaulbin
Date: Mon, 5 Dec 2022 00:21:52 -0800
Subject: [PATCH 11/15] Update denoiser.py

---
 uberduck_ml_dev/utils/denoiser.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py
index f31365fb..66613dda 100644
--- a/uberduck_ml_dev/utils/denoiser.py
+++ b/uberduck_ml_dev/utils/denoiser.py
@@ -1,12 +1,15 @@
 """
-Removes bias from vocoders (typically heard as metallic noise in the audio)
+Removes bias from HiFi-Gan and Avocodo (typically heard as noise in the audio)
 Usage:
 from denoiser import Denoiser
-denoiser = Denoiser(VOCODERGENERATOR, mode="normal") # Experiment with modes "normal" and "zeros"
+denoiser = Denoiser(HIFIGANGENERATOR, mode="normal") # Experiment with modes "normal" and "zeros"

 # Inference Vocoder
-audio = VOCODERGENERATOR.vocoder.forward(output[1][:1])
+audio = hifigan.vocoder.forward(output[1][:1])
+
+audio = audio.squeeze()
+audio = audio * 32768.0

 # Denoise
 audio_denoised = denoiser(audio.view(1, -1), strength=15)[:, 0] # Change strength if needed
@@ -45,8 +48,10 @@ def __init__(

         with torch.no_grad():
             if isinstance(hifigan, iSTFTNetGenerator):
+                spec, phase = hifigan.vocoder(mel_input.to(hifigan.device))
+                y_g_hat = self.stft.inverse(spec, phase)
                 bias_audio = (
-                    hifigan(mel_input.to(hifigan.device))
+                    y_g_hat
                     .view(1, -1)
                     .float()
                 )

From c9ac30705c2adb30ce47c1851455b1b411493398 Mon Sep 17 00:00:00 2001
From: johnpaulbin
Date: Mon, 5 Dec 2022 00:28:03 -0800
Subject: [PATCH 12/15] Update denoiser.py

---
 uberduck_ml_dev/utils/denoiser.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py
index 66613dda..8f634adf 100644
--- a/uberduck_ml_dev/utils/denoiser.py
+++ b/uberduck_ml_dev/utils/denoiser.py
@@ -22,7 +22,7 @@
 import sys
import torch from ..models.common import STFT -from ..vocoders.istftnet import iSTFTNetGenerator +from ..vocoders.istftnet import iSTFTNetGenerator, TorchSTFT class Denoiser(torch.nn.Module): @@ -48,6 +48,7 @@ def __init__( with torch.no_grad(): if isinstance(hifigan, iSTFTNetGenerator): + self.stft = TorchSTFT(filter_length=filter_length, hop_length=int(filter_length / n_overlap), win_length=win_length, device=hifigan.device).to(hifigan.device) spec, phase = hifigan.vocoder(mel_input.to(hifigan.device)) y_g_hat = self.stft.inverse(spec, phase) bias_audio = ( From 48415675982a8970d351eef533c0f05a55d51f5f Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 07:30:18 -0800 Subject: [PATCH 13/15] Update denoiser.py --- uberduck_ml_dev/utils/denoiser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py index 8f634adf..f546f86d 100644 --- a/uberduck_ml_dev/utils/denoiser.py +++ b/uberduck_ml_dev/utils/denoiser.py @@ -48,7 +48,7 @@ def __init__( with torch.no_grad(): if isinstance(hifigan, iSTFTNetGenerator): - self.stft = TorchSTFT(filter_length=filter_length, hop_length=int(filter_length / n_overlap), win_length=win_length, device=hifigan.device).to(hifigan.device) + self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device=hifigan.device).to(hifigan.device) spec, phase = hifigan.vocoder(mel_input.to(hifigan.device)) y_g_hat = self.stft.inverse(spec, phase) bias_audio = ( From c82068b221ebc062b2a5db08aff30f4a4f13cb7c Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 07:40:46 -0800 Subject: [PATCH 14/15] Update denoiser.py --- uberduck_ml_dev/utils/denoiser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py index f546f86d..ef7cdda4 100644 --- a/uberduck_ml_dev/utils/denoiser.py +++ b/uberduck_ml_dev/utils/denoiser.py @@ -48,8 +48,8 @@ def __init__( with torch.no_grad(): if isinstance(hifigan, iSTFTNetGenerator): - self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device=hifigan.device).to(hifigan.device) - spec, phase = hifigan.vocoder(mel_input.to(hifigan.device)) + self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device="cpu").to("cpu") + spec, phase = hifigan.vocoder(mel_input.cpu()) y_g_hat = self.stft.inverse(spec, phase) bias_audio = ( y_g_hat From 5ebe30fbe728e21e4ebe5f045bc2c026d4c4878a Mon Sep 17 00:00:00 2001 From: johnpaulbin Date: Mon, 5 Dec 2022 07:42:33 -0800 Subject: [PATCH 15/15] Update denoiser.py --- uberduck_ml_dev/utils/denoiser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uberduck_ml_dev/utils/denoiser.py b/uberduck_ml_dev/utils/denoiser.py index ef7cdda4..48a14b18 100644 --- a/uberduck_ml_dev/utils/denoiser.py +++ b/uberduck_ml_dev/utils/denoiser.py @@ -49,8 +49,8 @@ def __init__( with torch.no_grad(): if isinstance(hifigan, iSTFTNetGenerator): self.stft = TorchSTFT(filter_length=16, hop_length=4, win_length=16, device="cpu").to("cpu") - spec, phase = hifigan.vocoder(mel_input.cpu()) - y_g_hat = self.stft.inverse(spec, phase) + spec, phase = hifigan.vocoder(mel_input.to(hifigan.device)) + y_g_hat = self.stft.inverse(spec.cpu(), phase.cpu()) bias_audio = ( y_g_hat .view(1, -1)
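
Usage sketch (not part of the patches): a minimal end-to-end example of how the vocoder added in patches 1-9 and the denoiser integration from patches 10-15 are meant to be driven once the whole series is applied. The config and checkpoint paths, the checkpoint filename, and the random mel tensor below are illustrative assumptions; the class names, constructor arguments, and the mode/strength values come from the code above.

    import torch
    from uberduck_ml_dev.vocoders.istftnet import iSTFTNetGenerator
    from uberduck_ml_dev.utils.denoiser import Denoiser

    # The JSON config must define the fields load_checkpoint() reads,
    # e.g. gen_istft_n_fft, gen_istft_hop_size, upsample_rates,
    # upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilation_sizes.
    vocoder = iSTFTNetGenerator(
        config="config.json",      # hypothetical path
        checkpoint="g_00100000",   # hypothetical checkpoint with a "generator" key
        cudnn_enabled=torch.cuda.is_available(),
    )

    # Stand-in for acoustic-model output: (batch, 80 mel bins, frames).
    mel = torch.randn(1, 80, 200, device=vocoder.device)

    # forward()/infer() run Generator -> TorchSTFT.inverse and return
    # int16 numpy audio already scaled by max_wav_value.
    audio = vocoder(mel)

    # Optional bias removal (patches 10-15); mode and strength follow
    # the usage shown in the denoiser docstring.
    denoiser = Denoiser(vocoder, mode="normal")
    audio_denoised = denoiser(
        torch.from_numpy(audio).float().view(1, -1), strength=15
    )[:, 0]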