Description
Describe the bug
For a custom environment with NonTensorData, calling tensordict, tensordict_ = step_and_maybe_reset(tensordict)
changes the (next, observation) entry of the input tensordict
(unexpected) as well as the observation entry of tensordict_,
which has been partially reset (expected).
To Reproduce
This environment is hard-coded for batch_size = (2,).
The observation space is just a string for simplicity.
_step always returns ["B", "Z"] as the next observation, with the first batch entry in a done state and the second not.
_reset always returns ["A", "C"] as the initial observation after a reset.
(The action is ignored and only included to comply with the spec.)
from typing import Optional
from torchrl.data import CompositeSpec, NonTensorSpec, BinaryDiscreteTensorSpec
from torchrl.envs import EnvBase
from tensordict import TensorDictBase, TensorDict, NonTensorData, NonTensorStack
import torch
class CustomEnv(EnvBase):
    # Custom environment
    def __init__(
        self,
        *,
        device=None,
        batch_size: Optional[torch.Size] = torch.Size([2]),
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        assert batch_size == (2,)  # hardcoded for minimal example
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )
        self.observation_spec = CompositeSpec(
            observation=NonTensorSpec(shape=batch_size), shape=batch_size
        )
        self.action_spec = NonTensorSpec(shape=batch_size)
        self.reward_spec: BinaryDiscreteTensorSpec = BinaryDiscreteTensorSpec(
            n=1, dtype=torch.int8, shape=torch.Size([2, 1])
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # First batch entry is done, second is not
        done = torch.tensor([True, False], dtype=torch.bool)
        next_observation = NonTensorStack(
            NonTensorData("B"), NonTensorData("Z"), batch_size=(2,)
        )
        return TensorDict(
            {"observation": next_observation, "done": done, "reward": torch.ones((2,))},
            batch_size=(2,),
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        return TensorDict(
            {
                "observation": NonTensorStack(
                    NonTensorData("A"), NonTensorData("C"), batch_size=(2,)
                )
            },
            batch_size=(2,),
        )

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        # The action is a dummy value; it is ignored by _step
        action = NonTensorStack(NonTensorData("+"), NonTensorData("+"), batch_size=(2,))
        if tensordict is None:
            tensordict = TensorDict({}, batch_size=self.batch_size)
        tensordict["action"] = action
        return tensordict

    def _set_seed(self, seed: Optional[int]):
        pass


env = CustomEnv()
td = env.reset()
env.rand_action(td)
out_td, reset_td = env.step_and_maybe_reset(td)
assert out_td is td
assert torch.equal(td["next", "done"], torch.tensor([[True], [False]]))
observation = "observation"
next_observation = ("next", observation)
print(f"{td[next_observation]=}")
print(f"{reset_td[observation]=}")
td[next_observation]=['A', 'Z']
reset_td[observation]=['A', 'Z']
Expected behavior
After taking one step by executing out_td, reset_td = env.step_and_maybe_reset(td), we expect the step results stored in td, especially td["next", "observation"], to be left unchanged by the reset, and reset_td to have its observation reset for the first batch entry but not the second. Specifically, we expect td["next", "observation"] = ["B", "Z"] and reset_td["observation"] = ["A", "Z"].
However, td["next", "observation"] and reset_td["observation"] are both ["A", "Z"].
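The expected semantics can be sketched in plain Python, independent of torchrl. This is only an illustration of the behavior described above, not how torchrl implements it; the helper name partial_reset is hypothetical. The key point is that the merged observation is built as a new object, so the step output is never overwritten.

```python
from typing import List, Sequence


def partial_reset(
    next_obs: Sequence[str], reset_obs: Sequence[str], done: Sequence[bool]
) -> List[str]:
    """Hypothetical helper: take the reset value where done, else the next value."""
    # Build a fresh list so that next_obs is never modified in place.
    return [r if d else n for n, r, d in zip(next_obs, reset_obs, done)]


next_obs = ["B", "Z"]   # what _step returns
reset_obs = ["A", "C"]  # what _reset returns
done = [True, False]    # only the first batch entry is done

merged = partial_reset(next_obs, reset_obs, done)
print(f"{next_obs=}")  # stays ['B', 'Z'], like the expected td["next", "observation"]
print(f"{merged=}")    # ['A', 'Z'], like the expected reset_td["observation"]
```

In the reported behavior, the equivalent of next_obs ends up holding the merged values as well, which is what one would see if the partial reset were written into the step output in place.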
System info
The library was installed using pip requirements. We use the nightly-release.
tensordict-nightly>=2024.6.19
torch >= 2.4.0.dev
torchrl-nightly>=2024.6.23
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.6.23 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux
Additional context
The problem occurs only for partial resets (not all batch entries are done) and is likely related to pytorch/tensordict#837.
Interestingly, using the latest releases (0.4.0 1.26.4 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux), I get a different wrong result:
td[next_observation]=['B', 'Z']
reset_td[observation]=['B', 'Z']
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)