Skip to content

TensorFlow v2 & linting #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ EfficientNets rely on AutoML and compound scaling to achieve superior performanc
<table border="0">
<tr>
<td>
<img src="https://raw.githubusercontent.com/tensorflow/tpu/master/models/official/efficientnet/g3doc/params.png" width="100%" />
<img src="https://raw.githubusercontent.com/tensorflow/tpu/master/models/official/efficientnet/g3doc/params.png" alt="Params" width="100%" />
</td>
<td>
<img src="https://raw.githubusercontent.com/tensorflow/tpu/master/models/official/efficientnet/g3doc/flops.png", width="90%" />
<img src="https://raw.githubusercontent.com/tensorflow/tpu/master/models/official/efficientnet/g3doc/flops.png" alt="Flops" width="90%" />
</td>
</tr>
</table>
Expand Down
2 changes: 2 additions & 0 deletions efficientnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_submodules_from_kwargs(kwargs):

def inject_keras_modules(func):
import keras

@functools.wraps(func)
def wrapper(*args, **kwargs):
kwargs['backend'] = keras.backend
Expand All @@ -48,6 +49,7 @@ def wrapper(*args, **kwargs):

def inject_tfkeras_modules(func):
import tensorflow.keras as tfkeras

@functools.wraps(func)
def wrapper(*args, **kwargs):
kwargs['backend'] = tfkeras.backend
Expand Down
6 changes: 2 additions & 4 deletions efficientnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@
from __future__ import print_function

import os
import json
import math
import string
import collections
import numpy as np

from six.moves import xrange
from keras_applications.imagenet_utils import _obtain_input_shape
from keras_applications.imagenet_utils import decode_predictions
from keras_applications.imagenet_utils import preprocess_input as _preprocess_input

from . import get_submodules_from_kwargs
Expand Down Expand Up @@ -139,6 +136,7 @@ def preprocess_input(x, **kwargs):

def get_swish(**kwargs):
backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

def swish(x):
"""Swish activation function: x * sigmoid(x).
Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
Expand All @@ -153,7 +151,7 @@ def swish(x):
pass

return x * backend.sigmoid(x)
return swish
return swish


def get_dropout(**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion efficientnet/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

from skimage.transform import resize

MAP_INTERPOLATION_TO_ORDER = {
Expand Down
20 changes: 9 additions & 11 deletions scripts/load_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import argparse
import sys

import numpy as np

import tensorflow as tf
import efficientnet.keras
from keras.layers import BatchNormalization, Conv2D, Dense
Expand Down Expand Up @@ -93,14 +91,14 @@ def convert_tensorflow_model(
""" Loads and saves a TensorFlow model. """
image_files = [example_img]
eval_ckpt_driver = eval_ckpt_main.EvalCkptDriver(model_name)
with tf.Graph().as_default(), tf.Session() as sess:
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
images, _ = eval_ckpt_driver.build_dataset(
image_files, [0] * len(image_files), False
)
eval_ckpt_driver.build_model(images, is_training=False)
sess.run(tf.global_variables_initializer())
sess.run(tf.compat.v1.global_variables_initializer())
eval_ckpt_driver.restore_model(sess, model_ckpt)
global_variables = tf.global_variables()
global_variables = tf.compat.v1.global_variables()
weights = dict()
for variable in global_variables:
try:
Expand Down Expand Up @@ -149,15 +147,15 @@ def convert_tensorflow_model(
default="true",
help="Whether to include metadata in the serialized Keras model",
)
args = parser.parse_args()
cli_args = parser.parse_args()

sys.path.append(args.source)
sys.path.append(cli_args.source)
import eval_ckpt_main

true_values = ("yes", "true", "t", "1", "y")
convert_tensorflow_model(
model_name=args.model_name,
model_ckpt=args.tf_checkpoint,
output_file=args.output_file,
weights_only=args.weights_only in true_values,
model_name=cli_args.model_name,
model_ckpt=cli_args.tf_checkpoint,
output_file=cli_args.output_file,
weights_only=cli_args.weights_only in true_values,
)
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2019 The TensorFlow Authors, Pavel Yakubovskiy. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,8 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Note: To use the 'upload' functionality of this file, you must:
# $ pip install twine
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def test_models_result(args):


if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__])