diff --git a/ding/torch_utils/optimizer_helper.py b/ding/torch_utils/optimizer_helper.py index b4fc89e208..ea3f7b0a73 100644 --- a/ding/torch_utils/optimizer_helper.py +++ b/ding/torch_utils/optimizer_helper.py @@ -495,7 +495,7 @@ def _state_init(self, p, momentum, centered): # wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0 else: state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ - if self.defaults['capturable'] else torch.tensor(0.) + if ('capturable' in self.defaults and self.defaults['capturable']) else torch.tensor(0.) state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device) state['square_avg'] = torch.zeros_like(p.data, device=p.data.device) if momentum: