diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py index bf9ed5406..9362e61de 100644 --- a/alf/algorithms/algorithm.py +++ b/alf/algorithms/algorithm.py @@ -1050,7 +1050,8 @@ def _save_to_state_dict(self, destination, prefix, visited=None): visited.add(param) destination[prefix + name] = param.detach() for name, buf in self._buffers.items(): - if buf is not None and buf not in visited: + if (buf is not None and buf not in visited + and name not in self._non_persistent_buffers_set): visited.add(buf) destination[prefix + name] = buf.detach() @@ -1105,7 +1106,11 @@ def _load_from_state_dict(self, local_name_params = itertools.chain(self._parameters.items(), self._buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} + local_state = { + k: v + for k, v in local_name_params + if v is not None and k not in self._non_persistent_buffers_set + } for name, param in local_state.items(): if param is not None and param not in visited: