diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index f2d38c66121..cacb1bf3607 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6673,8 +6673,8 @@ def _update(self, key, value, N) -> torch.Tensor: ) mean = _sum / _count - std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt() - return (value - mean) / std.clamp_min(self.eps) + std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps**2).sqrt() + return (value - mean) / std def to_observation_norm(self) -> Compose | ObservationNorm: """Converts VecNorm into an ObservationNorm class that can be used at inference time. @@ -6725,7 +6725,9 @@ def _get_loc_scale(self, loc_only=False, scale_only=False): _ssq = self._td.get(_append_last(key, "_ssq")) _count = self._td.get(_append_last(key, "_count")) loc[key] = _sum / _count - scale[key] = (_ssq / _count - loc[key].pow(2)).clamp_min(self.eps).sqrt() + scale[key] = ( + (_ssq / _count - loc[key].pow(2)).clamp_min(self.eps**2).sqrt() + ) if not scale_only: loc = TensorDict(loc) else: