[BUG] EnvBase.step_and_maybe_reset(td) modifies the ('next','observation') data too on partial reset with NonTensorStack #2257

Closed
jkrude opened this issue Jul 1, 2024 · 1 comment · Fixed by #2260
Labels: bug (Something isn't working)

jkrude commented Jul 1, 2024

Describe the bug

For a custom environment with NonTensorData, calling tensordict, tensordict_ = step_and_maybe_reset(tensordict) changes both the ("next", "observation") entry of the input tensordict (unexpected) and the "observation" entry of tensordict_, which has been partially reset (expected).

To Reproduce

This environment is hard-coded for batch_size = (2,).
The observation space is just a string for simplicity.
_step always returns ["B", "Z"] as the next observation, with the first batch entry done and the second not done.
_reset always returns ["A", "C"] as the initial observation after a reset.
(The action is ignored and only included to comply with the spec.)

from typing import Optional

from torchrl.data import CompositeSpec, NonTensorSpec, BinaryDiscreteTensorSpec
from torchrl.envs import EnvBase
from tensordict import TensorDictBase, TensorDict, NonTensorData, NonTensorStack
import torch


class CustomEnv(EnvBase):
    # Custom environment

    def __init__(
        self,
        *,
        device=None,
        batch_size: Optional[torch.Size] = torch.Size([2]),
        run_type_checks: bool = False,
        allow_done_after_reset: bool = False,
    ):
        assert batch_size == (2,)  # hardcoded for minimal example
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
            allow_done_after_reset=allow_done_after_reset,
        )

        self.observation_spec = CompositeSpec(
            observation=NonTensorSpec(shape=batch_size), shape=batch_size
        )
        self.action_spec = NonTensorSpec(shape=batch_size)
        self.reward_spec: BinaryDiscreteTensorSpec = BinaryDiscreteTensorSpec(
            n=1, dtype=torch.int8, shape=torch.Size([2, 1])
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        done = torch.tensor([True, False], dtype=torch.bool)
        next_observation = NonTensorStack(
            NonTensorData("B"), NonTensorData("Z"), batch_size=(2,)
        )
        return TensorDict(
            {"observation": next_observation, "done": done, "reward": torch.ones((2,))},
            batch_size=(2,),
        )

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        return TensorDict(
            {
                "observation": NonTensorStack(
                    NonTensorData("A"), NonTensorData("C"), batch_size=(2,)
                )
            },
            batch_size=(2,),
        )

    def rand_action(self, tensordict: Optional[TensorDictBase] = None):
        action = NonTensorStack(NonTensorData("+"), NonTensorData("+"), batch_size=(2,))
        if tensordict is None:
            tensordict = TensorDict({}, batch_size=self.batch_size)
        tensordict["action"] = action
        return tensordict

    def _set_seed(self, seed: Optional[int]):
        pass


env = CustomEnv()
td = env.reset()
env.rand_action(td)
out_td, reset_td = env.step_and_maybe_reset(td)
assert out_td is td
assert torch.equal(td["next", "done"], torch.tensor([[True], [False]]))
observation = "observation"
next_observation = ("next", observation)

print(f"{td[next_observation]=}")
print(f"{reset_td[observation]=}")

Output:

td[next_observation]=['A', 'Z']
reset_td[observation]=['A', 'Z']

Expected behavior

After taking one step with out_td, reset_td = env.step_and_maybe_reset(td), we expect the ("next", "observation") entry of td to be left untouched by the reset, while reset_td["observation"] is reset in the first batch entry but not in the second. Specifically, we expect td["next", "observation"] = ["B", "Z"] and reset_td["observation"] = ["A", "Z"].

However, td["next", "observation"] and reset_td["observation"] are both ["A", "Z"].
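To make the expected semantics concrete, here is a minimal pure-Python sketch that models each NonTensorStack as a plain list of strings (an assumption for illustration; this is not TorchRL code). The helpers `merge_partial_reset` and `merge_partial_reset_in_place` are hypothetical. The second one only illustrates how writing the reset values into the step output in place, rather than into a copy, would produce the symptom reported above; it is not a claim about where the actual bug lives.

```python
from typing import List, Sequence


def merge_partial_reset(
    next_obs: Sequence[str], reset_obs: Sequence[str], done: Sequence[bool]
) -> List[str]:
    """Expected behavior: build a new list and leave next_obs untouched."""
    # Take the reset value for done entries, keep the step value otherwise.
    return [r if d else n for n, r, d in zip(next_obs, reset_obs, done)]


def merge_partial_reset_in_place(
    next_obs: List[str], reset_obs: Sequence[str], done: Sequence[bool]
) -> List[str]:
    """Hypothetical buggy variant: overwrites next_obs and returns it."""
    for i, d in enumerate(done):
        if d:
            next_obs[i] = reset_obs[i]
    return next_obs  # same object as the input, so both "views" change


step_obs = ["B", "Z"]      # what _step returns
reset_values = ["A", "C"]  # what _reset returns
done = [True, False]       # only the first batch entry is done

good = merge_partial_reset(step_obs, reset_values, done)
print(step_obs, good)      # ['B', 'Z'] ['A', 'Z']

bad = merge_partial_reset_in_place(step_obs, reset_values, done)
print(step_obs, bad)       # ['A', 'Z'] ['A', 'Z']
```

The copy-based merge matches the expected td["next", "observation"] = ["B", "Z"] and reset_td["observation"] = ["A", "Z"]; the aliasing variant matches what the nightly build prints.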

System info

The libraries were installed with pip from a requirements file, using the nightly releases:

tensordict-nightly>=2024.6.19
torch >= 2.4.0.dev
torchrl-nightly>=2024.6.23
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
2024.6.23 2.0.0 3.11.5 (main, Sep 11 2023, 13:54:46) [GCC 11.2.0] linux

Additional context

The problem occurs only for partial resets (when not all batch entries are done) and is likely related to pytorch/tensordict#837.
Interestingly, with the latest stable releases (torchrl 0.4.0, numpy 1.26.4, Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] linux), I get a different wrong result:

td[next_observation]=['B', 'Z']
reset_td[observation]=['B', 'Z']
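The two builds fail in different ways: on nightly the step output is overwritten by the reset values, while on the stable release the reset values are not applied at all. A small hypothetical helper like `classify_outcome` (not part of TorchRL or tensordict) could be dropped into the reproduction script to tell these cases apart; it assumes the observations have been converted to plain Python lists of strings.

```python
from typing import Sequence

EXPECTED_NEXT = ["B", "Z"]   # td["next", "observation"] should stay the raw step output
EXPECTED_RESET = ["A", "Z"]  # reset_td["observation"] is reset only where done


def classify_outcome(next_obs: Sequence[str], reset_obs: Sequence[str]) -> str:
    """Hypothetical helper: label which failure mode a run exhibits."""
    nxt, rst = list(next_obs), list(reset_obs)
    if nxt == EXPECTED_NEXT and rst == EXPECTED_RESET:
        return "correct"
    if nxt == rst == EXPECTED_RESET:
        # The reset leaked back into the ("next", "observation") entry.
        return "step output overwritten by reset"
    if nxt == rst == EXPECTED_NEXT:
        # The partial reset was dropped entirely.
        return "reset values not applied"
    return "other"


print(classify_outcome(["A", "Z"], ["A", "Z"]))  # step output overwritten by reset
print(classify_outcome(["B", "Z"], ["B", "Z"]))  # reset values not applied
print(classify_outcome(["B", "Z"], ["A", "Z"]))  # correct
```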

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
jkrude added the bug (Something isn't working) label Jul 1, 2024
vmoens (Contributor) commented Jul 1, 2024

On it! Thanks for reporting

vmoens linked a pull request Jul 1, 2024 that will close this issue