[Feature] MultiAction transform
ghstack-source-id: f488730e7a03f3aa5f41dbeccc7fe28c0a2db8c5
Pull Request resolved: #2779
vmoens committed Feb 12, 2025
1 parent eb6e105 commit be8ca50
Showing 9 changed files with 431 additions and 45 deletions.
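
For orientation, here is a condensed usage sketch distilled from TestMultiAction.test_single_trans_env_check further down. It is a sketch under assumptions, not part of the commit: it assumes MultiAction() unrolls the extra leading action dimension one base-env step at a time, that the transform is re-exported from torchrl.envs alongside the other transforms, and that CountingEnv is the test helper from test/mocking_classes.py (not part of the installed package).

import torch
from torchrl.envs import Compose, MultiAction, StepCounter, TransformedEnv
from mocking_classes import CountingEnv  # test-suite helper, assumed importable

env = TransformedEnv(
    CountingEnv(max_steps=10),
    Compose(
        StepCounter(step_count_key="before_count"),  # closest to the base env: counts every unrolled sub-step
        MultiAction(),                                # splits the stacked actions into consecutive base steps
        StepCounter(step_count_key="after_count"),    # counts macro-steps only
    ),
)

def policy(td):
    # three actions per macro-step
    td["action"] = torch.ones(3, 1)
    return td

rollout = env.rollout(10, policy)
# mirrors the test assertions: 4 macro-steps of 3 actions each reach max_steps=10
assert rollout["action"].shape == (4, 3, 1)
assert rollout["before_count"].max() == 9
assert rollout["after_count"].max() == 3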
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -977,6 +977,7 @@ to be able to create this other composition:
InitTracker
KLRewardTransform
LineariseRewards
MultiAction
NoopResetEnv
ObservationNorm
ObservationTransform
109 changes: 94 additions & 15 deletions test/mocking_classes.py
@@ -358,13 +358,11 @@ def _step(self, tensordict):
leading_batch_size = tensordict.shape if tensordict is not None else []
self.counter += 1
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
n = (
torch.full(
[*leading_batch_size, *self.observation_spec["observation"].shape],
self.counter,
)
.to(self.device)
.to(torch.get_default_dtype())
n = torch.full(
[*leading_batch_size, *self.observation_spec["observation"].shape],
self.counter,
device=self.device,
dtype=torch.get_default_dtype(),
)
done = self.counter >= self.max_val
done = torch.full(
@@ -391,13 +389,11 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
else:
leading_batch_size = tensordict.shape if tensordict is not None else []

n = (
torch.full(
[*leading_batch_size, *self.observation_spec["observation"].shape],
self.counter,
)
.to(self.device)
.to(torch.get_default_dtype())
n = torch.full(
[*leading_batch_size, *self.observation_spec["observation"].shape],
self.counter,
device=self.device,
dtype=torch.get_default_dtype(),
)
done = self.counter >= self.max_val
done = torch.full(
@@ -417,7 +413,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:


class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
"""Mocks an env whose batch_size does not define the size of the output tensordict.
"""Mocks an env which batch_size does not define the size of the output tensordict.
The size of the output tensordict is defined by the input tensordict itself.
@@ -433,6 +429,89 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, _batch_locked=False, **kwargs)


class StateLessCountingEnv(EnvBase):
def __init__(self):
self.observation_spec = Composite(
count=Unbounded((1,), dtype=torch.int32),
max_count=Unbounded((1,), dtype=torch.int32),
)
self.full_action_spec = Composite(
action=Unbounded((1,), dtype=torch.int32),
)
self.full_done_spec = Composite(
done=Unbounded((1,), dtype=torch.bool),
terminated=Unbounded((1,), dtype=torch.bool),
truncated=Unbounded((1,), dtype=torch.bool),
)
self.reward_spec = Composite(reward=Unbounded((1,), dtype=torch.float))
super().__init__()
self._batch_locked = False

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

max_count = None
count = None
if tensordict is not None:
max_count = tensordict.get("max_count")
count = tensordict.get("count")
tensordict = TensorDict(
batch_size=tensordict.batch_size, device=tensordict.device
)
shape = tensordict.batch_size
else:
shape = ()
tensordict = TensorDict(device=self.device)
tensordict.update(
TensorDict(
count=torch.zeros(
(
*shape,
1,
),
dtype=torch.int32,
)
if count is None
else count,
max_count=torch.randint(
10,
20,
(
*shape,
1,
),
dtype=torch.int32,
)
if max_count is None
else max_count,
**self.done_spec.zero(shape),
**self.full_reward_spec.zero(shape),
)
)
return tensordict

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict["action"]
count = tensordict["count"] + action
terminated = done = count >= tensordict["max_count"]
truncated = torch.zeros_like(done)
return TensorDict(
count=count,
max_count=tensordict["max_count"],
done=done,
terminated=terminated,
truncated=truncated,
reward=self.reward_spec.zero(tensordict.shape),
batch_size=tensordict.batch_size,
device=tensordict.device,
)

def _set_seed(self, seed: Optional[int]):
...


class DiscreteActionVecMockEnv(_MockEnv):
@classmethod
def __new__(
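The new StateLessCountingEnv above is batch-unlocked: its own batch_size does not constrain the batch of the tensordicts it receives, so a single reset can be expanded into a batch of parallel rollouts. Below is a condensed sketch of how TestMultiAction.test_transform_env (later in this commit) exercises it; hedged: the class is a test helper from test/mocking_classes.py, and the rollout and assertion are lifted from that test.

import torch
from mocking_classes import StateLessCountingEnv  # test-suite helper, not in the installed package

env = StateLessCountingEnv()

def policy(td):
    # one increment per step, shaped to the (possibly expanded) batch
    td["action"] = torch.ones(td.shape + (1,))
    return td

# batch-unlocked: expand a single reset into a batch of 4 rollouts
rollout = env.rollout(
    10,
    policy=policy,
    tensordict=env.reset().expand(4),
    auto_reset=False,
    break_when_any_done=False,
)
# each of the 4 rows counts 0..9, as asserted in the test
assert (rollout["count"] == torch.arange(10).expand(4, 10).view(4, 10, 1)).all()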
195 changes: 195 additions & 0 deletions test/test_transforms.py
@@ -77,6 +77,7 @@
Hash,
InitTracker,
LineariseRewards,
MultiAction,
MultiStepTransform,
NoopResetEnv,
ObservationNorm,
@@ -156,6 +157,7 @@
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
StateLessCountingEnv,
)
else:
from _utils_internal import ( # noqa
@@ -184,6 +186,7 @@
MultiKeyCountingEnv,
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
StateLessCountingEnv,
)

IS_WIN = platform == "win32"
@@ -13664,6 +13667,198 @@ def test_transform_inverse(self):
return


class TestMultiAction(TransformBase):
@pytest.mark.parametrize("bwad", [False, True])
def test_single_trans_env_check(self, bwad):
base_env = CountingEnv(max_steps=10)
env = TransformedEnv(
base_env,
Compose(
StepCounter(step_count_key="before_count"),
MultiAction(),
StepCounter(step_count_key="after_count"),
),
)
env.check_env_specs()

def policy(td):
# 3 actions per step
td["action"] = torch.ones(3, 1)
return td

r = env.rollout(10, policy)
assert r["action"].shape == (4, 3, 1)
assert r["next", "done"].any()
assert r["next", "done"][-1].all()
assert (r["observation"][0] == 0).all()
assert (r["next", "observation"][0] == 3).all()
assert (r["next", "observation"][-1] == 11).all()
# Check that before_count is incremented but not after_count
assert r["before_count"].max() == 9
assert r["after_count"].max() == 3

def _batched_trans_env_check(self, cls, bwad, within):
if within:

def make_env(i):
base_env = CountingEnv(max_steps=i)
env = TransformedEnv(
base_env,
Compose(
StepCounter(step_count_key="before_count"),
MultiAction(),
StepCounter(step_count_key="after_count"),
),
)
return env

env = cls(2, [partial(make_env, i=10), partial(make_env, i=20)])
else:
base_env = cls(
2,
[
partial(CountingEnv, max_steps=10),
partial(CountingEnv, max_steps=20),
],
)
env = TransformedEnv(
base_env,
Compose(
StepCounter(step_count_key="before_count"),
MultiAction(),
StepCounter(step_count_key="after_count"),
),
)

try:
env.check_env_specs()

def policy(td):
# 3 actions per step
td["action"] = torch.ones(2, 3, 1)
return td

r = env.rollout(10, policy, break_when_any_done=bwad)
# r0
r0 = r[0]
if bwad:
assert r["action"].shape == (2, 4, 3, 1)
else:
assert r["action"].shape == (2, 10, 3, 1)
assert r0["next", "done"].any()
if bwad:
assert r0["next", "done"][-1].all()
else:
assert r0["next", "done"].sum() == 2

assert (r0["observation"][0] == 0).all()
assert (r0["next", "observation"][0] == 3).all()
if bwad:
assert (r0["next", "observation"][-1] == 11).all()
else:
assert (r0["next", "observation"][-1] == 6).all(), r0[
"next", "observation"
]
# Check that before_count is incremented but not after_count
assert r0["before_count"].max() == 9
assert r0["after_count"].max() == 3
# r1
r1 = r[1]
if bwad:
assert not r1["next", "done"].any()
else:
assert r1["next", "done"].any()
assert r1["next", "done"].sum() == 1
assert (r1["observation"][0] == 0).all()
assert (r1["next", "observation"][0] == 3).all()
if bwad:
# r0 cannot go above 11 but r1 can - so we see a 12 because one more step was done
assert (r1["next", "observation"][-1] == 12).all()
else:
assert (r1["next", "observation"][-1] == 9).all()
# Check that before_count is incremented but not after_count
if bwad:
assert r1["before_count"].max() == 9
assert r1["after_count"].max() == 3
else:
assert r1["before_count"].max() == 18
assert r1["after_count"].max() == 6
finally:
env.close()

@pytest.mark.parametrize("bwad", [False, True])
def test_serial_trans_env_check(self, bwad):
self._batched_trans_env_check(SerialEnv, bwad, within=True)

@pytest.mark.parametrize("bwad", [False, True])
def test_parallel_trans_env_check(self, bwad):
self._batched_trans_env_check(
partial(ParallelEnv, mp_start_method=mp_ctx), bwad, within=True
)

@pytest.mark.parametrize("bwad", [False, True])
def test_trans_serial_env_check(self, bwad):
self._batched_trans_env_check(SerialEnv, bwad, within=False)

@pytest.mark.parametrize("bwad", [True, False])
@pytest.mark.parametrize("buffers", [True, False])
def test_trans_parallel_env_check(self, bwad, buffers):
self._batched_trans_env_check(
partial(ParallelEnv, use_buffers=buffers, mp_start_method=mp_ctx),
bwad,
within=False,
)

def test_transform_no_env(self):
...

def test_transform_compose(self):
...

@pytest.mark.parametrize("bwad", [True, False])
def test_transform_env(self, bwad):
# tests stateless (batch-unlocked) envs
torch.manual_seed(0)
env = StateLessCountingEnv()

def policy(td):
td["action"] = torch.ones(td.shape + (1,))
return td

r = env.rollout(
10,
tensordict=env.reset().expand(4),
auto_reset=False,
break_when_any_done=False,
policy=policy,
)
assert (r["count"] == torch.arange(10).expand(4, 10).view(4, 10, 1)).all()
td_reset = env.reset().expand(4).clone()
td_reset["max_count"] = torch.arange(4, 8).view(4, 1)
env = TransformedEnv(env, MultiAction())

def policy(td):
td["action"] = torch.ones(td.shape + (3,) + (1,))
return td

r = env.rollout(
20,
policy=policy,
auto_reset=False,
tensordict=td_reset,
break_when_any_done=bwad,
)

def test_transform_model(self):
...

def test_transform_rb(self):
return

def test_transform_inverse(self):
return


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)