Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 26, 2024
1 parent 279bd60 commit d864cab
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,9 @@ class ObservationNorm(ObservationTransform):
as it is done for standardization. Default is `False`.
eps (float, optional): epsilon increment for the scale in the ``standard_normal`` case.
Defaults to ``1e-6`` if not recoverable directly from the scale dtype.
Examples:
>>> torch.set_default_tensor_type(torch.DoubleTensor)
>>> r = torch.randn(100, 3)*torch.randn(3) + torch.randn(3)
Expand Down Expand Up @@ -2495,6 +2498,7 @@ def __init__(
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
standard_normal: bool = False,
eps: float | None = None,
):
if in_keys is None:
raise RuntimeError(
Expand All @@ -2517,7 +2521,13 @@ def __init__(
if not isinstance(standard_normal, torch.Tensor):
standard_normal = torch.as_tensor(standard_normal)
self.register_buffer("standard_normal", standard_normal)
self.eps = 1e-6
self.eps = (
eps
if eps is not None
else torch.finfo(scale.dtype).eps
if isinstance(scale, torch.Tensor) and scale.dtype.is_floating_point
else 1e-6
)

if loc is not None and not isinstance(loc, torch.Tensor):
loc = torch.tensor(loc, dtype=torch.get_default_dtype())
Expand Down Expand Up @@ -2659,7 +2669,7 @@ def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
if self.standard_normal:
loc = self.loc
scale = self.scale
return (obs - loc) / (scale + torch.finfo(scale.dtype).eps)
return (obs - loc) / scale
else:
scale = self.scale
loc = self.loc
Expand Down

0 comments on commit d864cab

Please sign in to comment.