diff --git a/pruning_engine.py b/pruning_engine.py index d1497cc..1e5f18f 100755 --- a/pruning_engine.py +++ b/pruning_engine.py @@ -900,7 +900,7 @@ def set_momentum_zero_sgd(self, optimizer=None): if not self.prune_layers[layer]: continue for unit in range(len(self.pruning_gates[layer])): - if not self.pruning_gates[layer][unit]: + if self.pruning_gates[layer][unit]: continue if 'momentum_buffer' in optimizer.state[self.parameters[layer]].keys(): optimizer.state[self.parameters[layer]]['momentum_buffer'][unit] *= 0.0