From 4cdebb868824b1ffccab4cae22be8ccd23c192cc Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Fri, 21 Mar 2025 22:37:04 -0400 Subject: [PATCH 1/3] Fix std clamping in _update: clamp once, after sqrt --- torchrl/envs/transforms/transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f2d38c66121..53aa7b0c1ae 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6673,8 +6673,8 @@ def _update(self, key, value, N) -> torch.Tensor: ) mean = _sum / _count - std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt() - return (value - mean) / std.clamp_min(self.eps) + std = (_ssq / _count - mean.pow(2)).sqrt().clamp_min(self.eps) + return (value - mean) / std def to_observation_norm(self) -> Compose | ObservationNorm: """Converts VecNorm into an ObservationNorm class that can be used at inference time. From 43d3cf5e7c015adc933b43c600f2aab7e9b4f9a4 Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Fri, 21 Mar 2025 22:44:54 -0400 Subject: [PATCH 2/3] Apply the same std clamping fix in _get_loc_scale --- torchrl/envs/transforms/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 53aa7b0c1ae..8a6b647ba15 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6725,7 +6725,7 @@ def _get_loc_scale(self, loc_only=False, scale_only=False): _ssq = self._td.get(_append_last(key, "_ssq")) _count = self._td.get(_append_last(key, "_count")) loc[key] = _sum / _count - scale[key] = (_ssq / _count - loc[key].pow(2)).clamp_min(self.eps).sqrt() + scale[key] = (_ssq / _count - loc[key].pow(2)).sqrt().clamp_min(self.eps) if not scale_only: loc = TensorDict(loc) else: From 3636bf92b9c53b00b71dc95c73d7b0785ac3b5c6 Mon Sep 17 00:00:00 2001 From: Erica Lin Date: Mon, 24 Mar 2025 11:19:20 -0400 Subject: [PATCH 3/3] Clamp variance at eps**2 before sqrt for numerical stability --- 
torchrl/envs/transforms/transforms.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8a6b647ba15..cacb1bf3607 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6673,7 +6673,7 @@ def _update(self, key, value, N) -> torch.Tensor: ) mean = _sum / _count - std = (_ssq / _count - mean.pow(2)).sqrt().clamp_min(self.eps) + std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps**2).sqrt() return (value - mean) / std def to_observation_norm(self) -> Compose | ObservationNorm: @@ -6725,7 +6725,9 @@ def _get_loc_scale(self, loc_only=False, scale_only=False): _ssq = self._td.get(_append_last(key, "_ssq")) _count = self._td.get(_append_last(key, "_count")) loc[key] = _sum / _count - scale[key] = (_ssq / _count - loc[key].pow(2)).sqrt().clamp_min(self.eps) + scale[key] = ( + (_ssq / _count - loc[key].pow(2)).clamp_min(self.eps**2).sqrt() + ) if not scale_only: loc = TensorDict(loc) else: