From f5c22ac8c1f7fd038ec12551637734dd26e55225 Mon Sep 17 00:00:00 2001 From: The MorphNet Team Date: Tue, 7 Jan 2020 14:06:00 -0800 Subject: [PATCH] Internal build changes. PiperOrigin-RevId: 288568069 --- morph_net/tools/configurable_ops.py | 45 +++++++++++++++-------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/morph_net/tools/configurable_ops.py b/morph_net/tools/configurable_ops.py index 17d33dd..78a750f 100644 --- a/morph_net/tools/configurable_ops.py +++ b/morph_net/tools/configurable_ops.py @@ -50,6 +50,9 @@ import tensorflow as tf +from tensorflow.contrib import framework as contrib_framework +from tensorflow.contrib import layers as contrib_layers + gfile = tf.gfile # Aliase needed for mock. VANISHED = 0.0 @@ -80,14 +83,14 @@ class FallbackRule(Enum): DEFAULT_FUNCTION_DICT = { - 'fully_connected': tf.contrib.layers.fully_connected, - 'conv2d': tf.contrib.layers.conv2d, - 'separable_conv2d': tf.contrib.layers.separable_conv2d, + 'fully_connected': contrib_layers.fully_connected, + 'conv2d': contrib_layers.conv2d, + 'separable_conv2d': contrib_layers.separable_conv2d, 'concat': tf.concat, 'add_n': tf.add_n, - 'avg_pool2d': tf.contrib.layers.avg_pool2d, - 'max_pool2d': tf.contrib.layers.max_pool2d, - 'batch_norm': tf.contrib.layers.batch_norm, + 'avg_pool2d': contrib_layers.avg_pool2d, + 'max_pool2d': contrib_layers.max_pool2d, + 'batch_norm': contrib_layers.batch_norm, } # Maps function names to the suffix of the name of the regularized ops. @@ -164,13 +167,13 @@ def parameterization(self): """Returns the parameterization dict mapping op names to num_outputs.""" return self._parameterization - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def conv2d(self, *args, **kwargs): """Masks num_outputs from the function pointed to by 'conv2d'. The object's parameterization has precedence over the given NUM_OUTPUTS argument. The resolution of the op names uses - tf.contrib.framework.get_name_scope() and kwargs['scope']. + contrib_framework.get_name_scope() and kwargs['scope']. Args: *args: Arguments for the operation. @@ -187,13 +190,13 @@ def conv2d(self, *args, **kwargs): fn, suffix = self._get_function_and_suffix('conv2d') return self._mask(fn, suffix, *args, **kwargs) - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def fully_connected(self, *args, **kwargs): """Masks NUM_OUTPUTS from the function pointed to by 'fully_connected'. The object's parameterization has precedence over the given NUM_OUTPUTS argument. The resolution of the op names uses - tf.contrib.framework.get_name_scope() and kwargs['scope']. + contrib_framework.get_name_scope() and kwargs['scope']. Args: *args: Arguments for the operation. @@ -214,13 +217,13 @@ def fully_connected(self, *args, **kwargs): fn, suffix = self._get_function_and_suffix('fully_connected') return self._mask(fn, suffix, *args, **kwargs) - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def separable_conv2d(self, *args, **kwargs): """Masks NUM_OUTPUTS from the function pointed to by 'separable_conv2d'. The object's parameterization has precedence over the given NUM_OUTPUTS argument. The resolution of the op names uses - tf.contrib.framework.get_name_scope() and kwargs['scope']. + contrib_framework.get_name_scope() and kwargs['scope']. Args: *args: Arguments for the operation. @@ -251,7 +254,7 @@ def _mask(self, function, suffix, *args, **kwargs): The object's parameterization has precedence over the given NUM_OUTPUTS argument. The resolution of the op names uses - `tf.contrib.framework.get_name_scope()` and `kwargs['scope']`. + `contrib_framework.get_name_scope()` and `kwargs['scope']`. The NUM_OUTPUTS argument is assumed to be either in **kwargs or held in *args[1]. @@ -284,7 +287,7 @@ def _mask(self, function, suffix, *args, **kwargs): # Support for tf.contrib.layers and tf.layers API. op_scope = kwargs.get('scope') or kwargs.get('name') - current_scope = tf.contrib.framework.get_name_scope() or '' + current_scope = contrib_framework.get_name_scope() or '' if current_scope and not current_scope.endswith('/'): current_scope += '/' op_name = ''.join([current_scope, op_scope, '/', suffix]) @@ -320,17 +323,17 @@ def concat(self, *args, **kwargs): def add_n(self, *args, **kwargs): return self._pass_through_mask_list('add_n', 'inputs', *args, **kwargs) - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def avg_pool2d(self, *args, **kwargs): return self._pass_through_mask( self._function_dict['avg_pool2d'], *args, **kwargs) - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def max_pool2d(self, *args, **kwargs): return self._pass_through_mask( self._function_dict['max_pool2d'], *args, **kwargs) - @tf.contrib.framework.add_arg_scope + @contrib_framework.add_arg_scope def batch_norm(self, *args, **kwargs): return self._pass_through_mask( self._function_dict['batch_norm'], *args, **kwargs) @@ -432,8 +435,8 @@ def hijack_module_functions(configurable_ops, module): example_module.py ``` - conv2d = tr.contrib.layers.conv2d - fully_connected = tr.contrib.layers.fully_connected + conv2d = tr.contrib_layers.conv2d + fully_connected = tr.contrib_layers.fully_connected def build_layer(inputs): return conv2d(inputs, 64, 3, scope='demo') @@ -444,7 +447,7 @@ def build_layer(inputs): So after a call to `hijack_module_functions(configurable_ops, example_module)` the call `example_module.build_layer(net)` will under the hood use - `configurable_ops.conv2d` rather than `tf.contrib.layers.conv2d`. + `configurable_ops.conv2d` rather than `contrib_layers.conv2d`. Note: This function could be unsafe as it depends on aliases defined in a possibly external module. In addition, a function in that module that calls @@ -452,7 +455,7 @@ def build_layer(inputs): ``` def build_layer_not_affected(inputs): - return tf.contrib.layers.conv2d(inputs, 64, 3, scope='bad') + return contrib_layers.conv2d(inputs, 64, 3, scope='bad') ``` Args: