Skip to content

Commit dbc8e2e

Browse files
committed
[Feature] Custom conversion tool for gym specs
ghstack-source-id: d38bb02 Pull Request resolved: #2726
1 parent 5fd5092 commit dbc8e2e

File tree

5 files changed

+330
-119
lines changed

5 files changed

+330
-119
lines changed

docs/source/reference/envs.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ in the relevant functions:
11171117
>>> print(env2._env.env.env)
11181118
<gym.envs.classic_control.pendulum.PendulumEnv at 0x1629916a0>
11191119

1120-
We can see that the two libraries modify the value returned by :func:`~.gym.gym_backend()`
1120+
We can see that the two libraries modify the value returned by :func:`~torchrl.envs.gym.gym_backend()`
11211121
which can be further used to indicate which library needs to be used for
11221122
the current computation. :class:`~.gym.set_gym_backend` is also a decorator:
11231123
we can use it to tell to a specific function what gym backend needs to be used
@@ -1188,3 +1188,4 @@ the following function will return ``1`` when queried:
11881188
VmasWrapper
11891189
gym_backend
11901190
set_gym_backend
1191+
register_gym_spec_conversion

test/test_libs.py

+35
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
Composite,
8080
MultiCategorical,
8181
MultiOneHot,
82+
NonTensor,
8283
OneHot,
8384
ReplayBuffer,
8485
ReplayBufferEnsemble,
@@ -119,6 +120,7 @@
119120
GymWrapper,
120121
MOGymEnv,
121122
MOGymWrapper,
123+
register_gym_spec_conversion,
122124
set_gym_backend,
123125
)
124126
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
@@ -337,6 +339,39 @@ def test_gym_spec_cast(self, categorical):
337339
assert spec == recon
338340
assert recon.shape == spec.shape
339341

342+
def test_gym_new_spec_reg(self):
343+
Space = gym_backend("spaces").Space
344+
345+
class MySpaceParent(Space):
346+
...
347+
348+
s_parent = MySpaceParent()
349+
350+
class MySpaceChild(MySpaceParent):
351+
...
352+
353+
# We intentionally register first the child then the parent
354+
@register_gym_spec_conversion(MySpaceChild)
355+
def convert_myspace_child(spec, **kwargs):
356+
return NonTensor((), example_data="child")
357+
358+
@register_gym_spec_conversion(MySpaceParent)
359+
def convert_myspace_parent(spec, **kwargs):
360+
return NonTensor((), example_data="parent")
361+
362+
s_child = MySpaceChild()
363+
assert _gym_to_torchrl_spec_transform(s_parent).example_data == "parent"
364+
assert _gym_to_torchrl_spec_transform(s_child).example_data == "child"
365+
366+
class NoConversionSpace(Space):
367+
...
368+
369+
s_no_conv = NoConversionSpace()
370+
with pytest.raises(
371+
KeyError, match="No conversion tool could be found with the gym space"
372+
):
373+
_gym_to_torchrl_spec_transform(s_no_conv)
374+
340375
@pytest.mark.parametrize("order", ["tuple_seq"])
341376
@implement_for("gym")
342377
def test_gym_spec_cast_tuple_sequential(self, order):

torchrl/envs/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
OpenSpielWrapper,
3333
PettingZooEnv,
3434
PettingZooWrapper,
35+
register_gym_spec_conversion,
3536
RoboHiveEnv,
3637
set_gym_backend,
3738
SMACv2Env,

torchrl/envs/libs/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
GymWrapper,
1313
MOGymEnv,
1414
MOGymWrapper,
15+
register_gym_spec_conversion,
1516
set_gym_backend,
1617
)
1718
from .habitat import HabitatEnv

0 commit comments

Comments
 (0)