From 580ea65b8a439304d432b167bf6f64ab1bd04cbf Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Fri, 13 Dec 2024 16:47:19 +0800 Subject: [PATCH] fix(nyz): fix rmsprop bug in torch 1.13.1 --- ding/torch_utils/optimizer_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ding/torch_utils/optimizer_helper.py b/ding/torch_utils/optimizer_helper.py index b4fc89e208..ea3f7b0a73 100644 --- a/ding/torch_utils/optimizer_helper.py +++ b/ding/torch_utils/optimizer_helper.py @@ -495,7 +495,7 @@ def _state_init(self, p, momentum, centered): # wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0 else: state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ - if self.defaults['capturable'] else torch.tensor(0.) + if ('capturable' in self.defaults and self.defaults['capturable']) else torch.tensor(0.) state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device) state['square_avg'] = torch.zeros_like(p.data, device=p.data.device) if momentum: