Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further updates to make it work with TensorFlow>=2.0 and librosa>=0.7 #411

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@
import os

import librosa
import soundfile
import numpy as np
import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf

from wavenet import WaveNetModel, mu_law_decode, mu_law_encode, audio_reader

tf.disable_v2_behavior()


SAMPLES = 16000
TEMPERATURE = 1.0
LOGDIR = './logdir'
Expand Down Expand Up @@ -112,7 +117,7 @@ def _ensure_positive_float(f):

def write_wav(waveform, sample_rate, filename):
y = np.array(waveform)
librosa.output.write_wav(filename, y, sample_rate)
soundfile.write(filename, y, sample_rate)
print('Updated wav file at {}'.format(filename))


Expand Down
3 changes: 2 additions & 1 deletion test/test_causal_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for the causal_conv op."""

import numpy as np
import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf

from wavenet import time_to_batch, batch_to_time, causal_conv

Expand Down
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
import sys
import time

import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.python.client import timeline

from wavenet import WaveNetModel, AudioReader, optimizer_factory

tf.disable_v2_behavior()


BATCH_SIZE = 1
DATA_DIRECTORY = './VCTK-Corpus'
LOGDIR_ROOT = './logdir'
Expand Down
5 changes: 3 additions & 2 deletions wavenet/audio_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import librosa
import numpy as np
import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf

FILE_PATTERN = r'p([0-9]+)_([0-9]+)\.wav'

Expand Down Expand Up @@ -65,7 +66,7 @@ def trim_silence(audio, threshold, frame_length=2048):
'''Removes silence at the beginning and end of a sample.'''
if audio.size < frame_length:
frame_length = audio.size
energy = librosa.feature.rmse(audio, frame_length=frame_length)
energy = librosa.feature.rms(y=audio, frame_length=frame_length)
frames = np.nonzero(energy > threshold)
indices = librosa.core.frames_to_samples(frames)[1]

Expand Down
6 changes: 4 additions & 2 deletions wavenet/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np
import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf
import tensorflow.keras as keras

from .ops import causal_conv, mu_law_encode


def create_variable(name, shape):
'''Create a convolution filter variable with the specified name and shape,
and initialize it using Xavier initialition.'''
initializer = tf.contrib.layers.xavier_initializer_conv2d()
initializer = keras.initializers.glorot_normal()
variable = tf.Variable(initializer(shape=shape), name=name)
return variable

Expand Down
3 changes: 2 additions & 1 deletion wavenet/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division

import tensorflow as tf
#import tensorflow as tf
import tensorflow.compat.v1 as tf


def create_adam_optimizer(learning_rate, momentum):
Expand Down