Skip to content

Commit f8d1574

Browse files
sueannthunterdb
authored andcommitted
Add ResNet50 to DeepImageFeaturizer (#57)
* resnet50 works * clean up * clearer comments
1 parent 3f668d9 commit f8d1574

File tree

4 files changed

+75
-4
lines changed

4 files changed

+75
-4
lines changed

python/sparkdl/transformers/keras_applications.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from abc import ABCMeta, abstractmethod
1717

1818
import keras.backend as K
19-
from keras.applications import inception_v3, xception
19+
from keras.applications import inception_v3, xception, resnet50
20+
import numpy as np
2021
import tensorflow as tf
2122

2223
from sparkdl.transformers.utils import (imageInputPlaceholder, InceptionV3Constants)
@@ -104,9 +105,41 @@ def inputShape(self):
104105
def _testKerasModel(self, include_top):
105106
return xception.Xception(weights="imagenet", include_top=include_top)
106107

108+
class ResNet50Model(KerasApplicationModel):
109+
def preprocess(self, inputImage):
110+
return _imagenet_preprocess_input(inputImage, self.inputShape())
111+
112+
def model(self, preprocessed, featurize):
113+
return resnet50.ResNet50(input_tensor=preprocessed, weights="imagenet",
114+
include_top=(not featurize))
115+
116+
def inputShape(self):
117+
return (224, 224)
118+
119+
def _testKerasModel(self, include_top):
120+
return resnet50.ResNet50(weights="imagenet", include_top=include_top)
121+
122+
def _imagenet_preprocess_input(x, input_shape):
123+
"""
124+
For ResNet50, VGG models. For InceptionV3 and Xception it's okay to use the
125+
keras version (e.g. InceptionV3.preprocess_input) as the code path they hit
126+
works okay with tf.Tensor inputs. The following was translated to tf ops from
127+
https://github.com/fchollet/keras/blob/fb4a0849cf4dc2965af86510f02ec46abab1a6a4/keras/applications/imagenet_utils.py#L52
128+
It's a possibility to change the implementation in keras to look like the
129+
following, but not doing it for now.
130+
"""
131+
# 'RGB'->'BGR'
132+
x = x[..., ::-1]
133+
# Zero-center by mean pixel
134+
mean = np.ones(input_shape + (3,), dtype=np.float32)
135+
mean[..., 0] = 103.939
136+
mean[..., 1] = 116.779
137+
mean[..., 2] = 123.68
138+
return x - mean
107139

108140
KERAS_APPLICATION_MODELS = {
109141
"InceptionV3": InceptionV3Model,
110-
"Xception": XceptionModel
142+
"Xception": XceptionModel,
143+
"ResNet50": ResNet50Model,
111144
}
112145

python/sparkdl/transformers/named_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sparkdl.transformers.tf_image import TFImageTransformer
3030

3131

32-
SUPPORTED_MODELS = ["InceptionV3", "Xception"]
32+
SUPPORTED_MODELS = ["InceptionV3", "Xception", "ResNet50"]
3333

3434

3535
class DeepImagePredictor(Transformer, HasInputCol, HasOutputCol):
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2017 Databricks, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
from .named_image_test import NamedImageTransformerBaseTestCase
17+
18+
class NamedImageTransformerResNet50Test(NamedImageTransformerBaseTestCase):
19+
20+
__test__ = True
21+
name = "ResNet50"

python/tests/transformers/named_image_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515

1616
import numpy as np
17+
from keras.applications import resnet50
1718
import tensorflow as tf
1819

1920
from pyspark.ml import Pipeline
@@ -29,10 +30,26 @@
2930
from .image_utils import getSampleImageDF, getSampleImageList
3031

3132

32-
class GetKerasApplicationModelTestCase(SparkDLTestCase):
33+
class KerasApplicationModelTestCase(SparkDLTestCase):
3334
def test_getKerasApplicationModelError(self):
3435
self.assertRaises(ValueError, keras_apps.getKerasApplicationModel, "NotAModelABC")
3536

37+
def test_imagenet_preprocess_input(self):
38+
# compare our tf implementation to the np implementation in keras
39+
image = np.zeros((256, 256, 3))
40+
41+
sess = tf.Session()
42+
with sess.as_default():
43+
x = tf.placeholder(tf.float32, shape=[256, 256, 3])
44+
processed = keras_apps._imagenet_preprocess_input(x, (256, 256)),
45+
sparkdl_preprocessed_input = sess.run(processed, {x: image})
46+
47+
keras_preprocessed_input = resnet50.preprocess_input(np.expand_dims(image, axis=0))
48+
49+
# NOTE: precision errors occur for decimal > 5
50+
np.testing.assert_array_almost_equal(sparkdl_preprocessed_input, keras_preprocessed_input,
51+
decimal=5)
52+
3653

3754
class NamedImageTransformerBaseTestCase(SparkDLTestCase):
3855
"""

0 commit comments

Comments
 (0)