2626from tensorflow .python .ops import control_flow_ops
2727from tensorflow .python .ops import math_ops
2828from tensorflow .python .ops import nn_ops
29+ from tensorflow .python .ops import state_ops
2930from tensorflow .python .ops import summary_ops_v2
3031from tensorflow .python .ops import variables
3132from tensorflow .python .summary import summary as summary_ops_v1
3233from tensorflow_model_optimization .python .core .sparsity .keras import pruning_utils
3334
34-
3535class Pruning (object ):
3636 """Implementation of magnitude-based weight pruning."""
3737
@@ -55,15 +55,9 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
5555 self ._block_pooling_type = block_pooling_type
5656 self ._validate_block ()
5757
58- # List of tensorflow assignments ops for new masks and thresholds
59- self ._assign_ops = []
60-
6158 # Training step
6259 self ._step_fn = training_step_fn
6360
64- # List of tensorflow assignment ops for the weights
65- self ._weight_assign_ops = []
66-
6761 self ._validate_block ()
6862
6963 def _validate_block (self ):
@@ -73,9 +67,6 @@ def _validate_block(self):
7367 raise ValueError ('Block Sparsity can only be used for layers which '
7468 'have 2-dimensional weights.' )
7569
76- def get_weight_sparsity (self ):
77- return [math_ops .reduce_mean (weight ) for weight , _ , _ in self ._pruning_vars ]
78-
7970 def _update_mask (self , weights ):
8071 """Updates the mask for a given weight tensor.
8172
@@ -161,69 +152,99 @@ def _maybe_update_block_mask(self, weights):
161152 return new_threshold , array_ops .reshape (sliced_mask ,
162153 array_ops .shape (weights ))
163154
164- def _get_assign_ops (self ):
165- """Gather the assign ops for assigning updated masks and threshold."""
166- # Make sure the assignment ops have not already been added to the list
167- if self ._assign_ops :
168- raise ValueError (
169- 'Assign op list not empty. _get_assign_ops() called twice?' )
170-
171- for weight , mask , threshold in self ._pruning_vars :
172- is_partitioned = isinstance (weight , variables .PartitionedVariable )
173- weight_as_tensor = weight
174- if is_partitioned :
175- weight_as_tensor = weight .as_tensor ()
176-
177- new_threshold , new_mask = self ._maybe_update_block_mask (weight_as_tensor )
178- self ._assign_ops .append (
179- pruning_utils .variable_assign (threshold , new_threshold ))
180-
181- self ._assign_ops .append (
182- pruning_utils .partitioned_variable_assign (mask , new_mask )
183- if is_partitioned else pruning_utils .variable_assign (mask , new_mask ))
184-
185155 def _get_weight_assign_ops (self ):
186156 """Gather the assign ops for assigning weights<=weights*mask."""
187- if self ._weight_assign_ops :
188- raise ValueError (
189- 'Assign op list not empty. _get_weight_assign_ops() called twice?' )
190-
191- for weight , mask , _ in self ._pruning_vars :
192- is_partitioned = isinstance (weight , variables .PartitionedVariable )
193- masked_weight = math_ops .multiply (weight , mask )
194- self ._weight_assign_ops .append (
195- pruning_utils .partitioned_variable_assign (weight , masked_weight )
196- if is_partitioned else pruning_utils
197- .variable_assign (weight , masked_weight ))
198-
199- def weight_mask_op (self ):
200- if tf .executing_eagerly () or not self ._weight_assign_ops :
201- self ._weight_assign_ops = []
202- self ._get_weight_assign_ops ()
203-
204- with ops .control_dependencies (self ._weight_assign_ops ):
205- return control_flow_ops .no_op ('mask_weights' )
206157
207- def mask_update_op (self ):
208- self ._assign_ops = []
209- self ._get_assign_ops ()
158+ def update_fn (distribution , values_and_vars ):
159+ # TODO(yunluli): Need this ReduceOp because the weight is created by the
160+ # layer wrapped, so we don't have control of its aggregation policy. May
161+ # be able to optimize this when distribution strategy supports easier
162+ # update to mirrored variables in replica context.
163+ reduced_values = distribution .extended .batch_reduce_to (
164+ tf .distribute .ReduceOp .MEAN , values_and_vars )
165+ var_list = [v for _ , v in values_and_vars ]
166+ values_and_vars = zip (reduced_values , var_list )
167+
168+ def update_var (variable , reduced_value ):
169+ return state_ops .assign (variable , reduced_value )
170+
171+ update_ops = []
172+ for value , var in values_and_vars :
173+ update_ops .append (
174+ distribution .extended .update (var , update_var , args = (value ,)))
175+
176+ return control_flow_ops .group (update_ops )
177+
178+ assign_ops = []
179+
180+ if tf .distribute .get_replica_context ():
181+ values_and_vars = []
182+ for weight , mask , _ in self ._pruning_vars :
183+ masked_weight = math_ops .multiply (weight , mask )
184+ values_and_vars .append ((masked_weight , weight ))
185+ assign_ops .append (tf .distribute .get_replica_context ().merge_call (
186+ update_fn , args = (values_and_vars ,)))
187+ else :
188+ for weight , mask , _ in self ._pruning_vars :
189+ masked_weight = math_ops .multiply (weight , mask )
190+ assign_ops .append (state_ops .assign (weight , masked_weight ))
191+
192+ return assign_ops
210193
211- with ops . control_dependencies (self . _assign_ops ):
212- return control_flow_ops .no_op ( 'mask_update' )
194+ def weight_mask_op (self ):
195+ return control_flow_ops .group ( self . _get_weight_assign_ops () )
213196
214197 def conditional_mask_update (self ):
215198 """Returns an op to updates masks as per the pruning schedule."""
216199
217200 def maybe_update_masks ():
218201 return self ._pruning_schedule (self ._step_fn ())[0 ]
219202
220- def mask_update_op ():
221- return self .mask_update_op ()
222-
223- def no_op ():
203+ def no_update ():
224204 return control_flow_ops .no_op ()
225205
226- return control_flow_ops .cond (maybe_update_masks (), mask_update_op , no_op )
206+ def mask_update ():
207+ """Updates mask without distribution strategy."""
208+
209+ def update ():
210+ assign_ops = []
211+
212+ for weight , mask , threshold in self ._pruning_vars :
213+ new_threshold , new_mask = self ._maybe_update_block_mask (weight )
214+ assign_ops .append (state_ops .assign (threshold , new_threshold ))
215+ assign_ops .append (state_ops .assign (mask , new_mask ))
216+
217+ return control_flow_ops .group (assign_ops )
218+
219+ return control_flow_ops .cond (maybe_update_masks (), update , no_update )
220+
221+ def mask_update_distributed (distribution ):
222+ """Updates mask with distribution strategy."""
223+
224+ def update (var , value ):
225+ return state_ops .assign (var , value )
226+
227+ def update_distributed ():
228+ """Gather distributed update ops."""
229+ assign_ops = []
230+
231+ for weight , mask , threshold in self ._pruning_vars :
232+ new_threshold , new_mask = self ._maybe_update_block_mask (weight )
233+ assign_ops .append (
234+ distribution .extended .update (mask , update , (new_mask ,)))
235+ assign_ops .append (
236+ distribution .extended .update (threshold , update , (new_threshold ,)))
237+
238+ return control_flow_ops .group (assign_ops )
239+
240+ return control_flow_ops .cond (maybe_update_masks (), update_distributed ,
241+ no_update )
242+
243+ if tf .distribute .get_replica_context ():
244+ return tf .distribute .get_replica_context ().merge_call (
245+ mask_update_distributed )
246+ else :
247+ return mask_update ()
227248
228249 def add_pruning_summaries (self ):
229250 """Adds summaries of weight sparsities and thresholds."""
0 commit comments