@@ -215,8 +215,8 @@ def decode(self,
215215 def _validate_and_expand_encode_input (self , x ):
216216 """Validates the input to encode and modifies it if necessary."""
217217 if x .shape .ndims not in [1 , 2 ]:
218- raise ValueError (
219- 'Number of dimensions must be 1 or 2. Shape of x: %s' % x .shape )
218+ raise ValueError ('Number of dimensions must be 1 or 2. Shape of x: %s' %
219+ x .shape )
220220 if x .shape .ndims == 1 :
221221 # The input to the fast_walsh_hadamard_transform must have 2 dimensions.
222222 x = tf .expand_dims (x , 0 )
@@ -262,7 +262,7 @@ class UniformQuantizationEncodingStage(encoding_stage.EncodingStageInterface):
262262 # otherwise be numerically unstable for float32 values.
263263 _ALLOWED_BITS_ARG = list (range (1 , 17 ))
264264
265- def __init__ (self , bits = 8 , min_max = None , stochastic = True ):
265+ def __init__ (self , bits = 8 , min_max = None , stochastic = True , ** kwargs ):
266266 """Initializer for the UniformQuantizationEncodingStage.
267267
268268 Args:
@@ -275,6 +275,7 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
275275 stochastic: A Python bool, whether to use stochastic or deterministic
276276 rounding. If `True`, the encoding is randomized and on expectation
277277 unbiased. If `False`, the encoding is deterministic.
278+ **kwargs: Keyword arguments.
278279
279280 Raises:
280281 ValueError: The inputs do not satisfy the above constraints.
@@ -300,6 +301,8 @@ def __init__(self, bits=8, min_max=None, stochastic=True):
300301 if not isinstance (stochastic , bool ):
301302 raise TypeError ('The stochastic argument must be a bool.' )
302303 self ._stochastic = stochastic
304+ self ._force_random_op_after_clipping = kwargs .get (
305+ 'reduce_memory_use_by_forcing_random_op_after_clipping' , False )
303306
304307 @property
305308 def name (self ):
@@ -350,8 +353,18 @@ def encode(self, x, encode_params):
350353 x = tf .compat .v1 .div_no_nan (x - min_x , max_x - min_x ) * max_value
351354 if self ._stochastic : # Randomized rounding.
352355 floored_x = tf .floor (x )
353- bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
354- bernoulli = bernoulli < (x - floored_x )
356+ residuals_x = x - floored_x
357+ # Add graph dependencies to tensor `x` to ensure that the randomized
358+ # rounding variables are not created before `x` is scaled above. This
359+ # prevents TF from preallocating the tensor before it will actually be
360+ # used, reducing memory pressure (especially important for mobile
361+ # deployments).
362+ if self ._force_random_op_after_clipping :
363+ with tf .control_dependencies ([x ]):
364+ bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
365+ else :
366+ bernoulli = tf .random .uniform (tf .shape (x ), dtype = x .dtype )
367+ bernoulli = bernoulli < residuals_x
355368 quantized_x = floored_x + tf .cast (bernoulli , x .dtype )
356369 else : # Deterministic rounding.
357370 quantized_x = tf .round (x )