Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 26, 2025
1 parent 9ad9fcc commit 5eb8086
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 37 deletions.
75 changes: 63 additions & 12 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -731,29 +731,80 @@ pixels or states etc).
Forward and inverse transforms
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Transforms also have an ``inv`` method that is called before
the action is applied in reverse order over the composed transform chain:
this allows to apply transforms to data in the environment before the action is taken
in the environment. The keys to be included in this inverse transform are passed through the
``"in_keys_inv"`` keyword argument:
Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse
order over the composed transform chain. This allows applying transforms to data in the environment before the action is
taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"`
keyword argument, and the out-keys default to these values in most cases:

.. code-block::
:caption: Inverse transform
>>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step
The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
of the transform. In constrast, the user inputs and outputs to and from the transform are to be considered as the
outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
are part of the outside world. The transform changes these names to make them match the names of the inner, base
environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys``
are mapped to the corresponding ``out_keys``.
The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features.

Understanding Transform Keys
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world
(e.g., your policy):

- `in_keys` refers to the base environment's perspective (inner = `base_env` of the
:class:`~torchrl.envs.TransformedEnv`).
- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.).

For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized
observation, while the base environment outputs a regular observation.

Similarly, for inverse keys:

- `in_keys_inv` refers to entries as seen by the base environment.
- `out_keys_inv` refers to entries as seen or produced by the policy.

The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The
transform changes these names to match the names of the inner, base environment using the `in_keys_inv`.
The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding
`out_keys`.

.. figure:: /_static/img/rename_transform.png

Rename transform logic

Transforming Tensors and Specs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When transforming actual tensors (coming from the policy), the process is schematically represented as:

>>> for t in reversed(self.transform):
... td = t.inv(td)

This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy
is properly transformed.

For transforming the action spec, the process should go from innermost to outermost (similar to observation specs):

>>> def transform_action_spec(self, action_spec):
... for t in self.transform:
... action_spec = t.transform_action_spec(action_spec)
... return action_spec

A pseudocode for a single transform_action_spec could be:

>>> def transform_action_spec(self, action_spec):
... return spec_from_random_values(self._apply_transform(action_spec.rand()))

This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call
`_inv_apply_transform` but `_apply_transform` on purpose!

Exposing Specs to the Outside World
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states.
For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued
tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.


Cloning transforms
Expand Down
25 changes: 11 additions & 14 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ def test_transform_rb(self, rbclass):
ClipTransform(
in_keys=["observation", "reward"],
out_keys=["obs_clip", "reward_clip"],
in_keys_inv=["input"],
out_keys_inv=["input_clip"],
in_keys_inv=["input_clip"],
out_keys_inv=["input"],
low=-0.1,
high=0.1,
)
Expand Down Expand Up @@ -2509,20 +2509,17 @@ def test_transform_rb(self, rbclass):
assert ("next", "observation") in td.keys(True)

def test_transform_inverse(self):
return
env = CountingEnv()
env = env.append_transform(
Hash(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=["action_hash"],
with pytest.raises(TypeError):
env = env.append_transform(
Hash(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=["action_hash"],
)
)
)
assert "action_hash" in env.action_keys
r = env.rollout(3)
env.check_env_specs()
assert "action_hash" in r
assert isinstance(r[0]["action_hash"], torch.Tensor)


class TestTokenizer(TransformBase):
Expand Down
19 changes: 8 additions & 11 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,23 @@ def new_fun(self, input_spec):
in_keys_inv = self.in_keys_inv
out_keys_inv = self.out_keys_inv
for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv):
in_key = unravel_key(in_key)
out_key = unravel_key(out_key)
# if in_key != out_key:
# # we only change the input spec if the key is the same
# continue
if in_key in action_spec.keys(True, True):
action_spec[out_key] = function(self, action_spec[in_key].clone())
if in_key != out_key:
del action_spec[in_key]
elif in_key in state_spec.keys(True, True):
state_spec[out_key] = function(self, state_spec[in_key].clone())
if in_key != out_key:
del state_spec[in_key]
elif in_key in input_spec.keys(False, True):
input_spec[out_key] = function(self, input_spec[in_key].clone())
# else:
# raise RuntimeError(f"Couldn't find key '{in_key}' in input spec {input_spec}")
if in_key != out_key:
del input_spec[in_key]
if skip:
return input_spec
return Composite(
Expand Down Expand Up @@ -4857,19 +4863,14 @@ class Hash(UnaryTransform):
[torchrl][INFO] check_env_specs succeeded!
"""

_repertoire: Dict[Tuple[int], Any]

def __init__(
self,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey],
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
*,
hash_fn: Callable = None,
seed: Any | None = None,
use_raw_nontensor: bool = False,
repertoire: Dict[Tuple[int], Any] | None = None,
):
if hash_fn is None:
hash_fn = Hash.reproducible_hash
Expand All @@ -4879,13 +4880,9 @@ def __init__(
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
fn=self.call_hash_fn,
use_raw_nontensor=use_raw_nontensor,
)
if in_keys_inv is not None:
self._repertoire = repertoire if repertoire is not None else {}

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
inputs = tensordict.select(*self.in_keys_inv).detach().cpu()
Expand Down

0 comments on commit 5eb8086

Please sign in to comment.