
[BugFix] Fix batch_locked check in check_env_specs + error message callable #2817

Merged (5 commits, Mar 3, 2025)

Changes from all commits
39 changes: 22 additions & 17 deletions torchrl/envs/utils.py
@@ -14,7 +14,7 @@
import re
import warnings
from enum import Enum
-from typing import Any, Dict, List
+from typing import Any

import torch

@@ -329,9 +329,9 @@ def step_mdp(
exclude_reward: bool = True,
exclude_done: bool = False,
exclude_action: bool = True,
-    reward_keys: NestedKey | List[NestedKey] = "reward",
-    done_keys: NestedKey | List[NestedKey] = "done",
-    action_keys: NestedKey | List[NestedKey] = "action",
+    reward_keys: NestedKey | list[NestedKey] = "reward",
+    done_keys: NestedKey | list[NestedKey] = "done",
+    action_keys: NestedKey | list[NestedKey] = "action",
) -> TensorDictBase:
"""Creates a new tensordict that reflects a step in time of the input tensordict.

@@ -680,8 +680,8 @@ def _per_level_env_check(data0, data1, check_dtype):


def check_env_specs(
-    env,
-    return_contiguous=True,
+    env: torchrl.envs.EnvBase,  # noqa
+    return_contiguous: bool | None = None,
check_dtype=True,
seed: int | None = None,
tensordict: TensorDictBase | None = None,
@@ -700,7 +700,7 @@
env (EnvBase): the env for which the specs have to be checked against data.
return_contiguous (bool, optional): if ``True``, the random rollout will be called with
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
-            of inputs/outputs). Defaults to True.
+            of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs).
check_dtype (bool, optional): if False, dtype checks will be skipped.
Defaults to True.
seed (int, optional): for reproducibility, a seed can be set.
@@ -718,6 +718,8 @@
of an experiment and as such should be kept out of training scripts.

"""
+    if return_contiguous is None:
+        return_contiguous = not env._has_dynamic_specs
if break_when_any_done == "both":
check_env_specs(
env,
@@ -746,7 +748,7 @@
)

fake_tensordict = env.fake_tensordict()
-    if not env._batch_locked and tensordict is not None:
+    if not env.batch_locked and tensordict is not None:
shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape)
fake_tensordict = fake_tensordict.expand(shape)
tensordict = tensordict.expand(shape)
@@ -786,10 +788,13 @@ def check_env_specs(
- List of keys present in fake but not in real: {fake_tensordict_keys-real_tensordict_keys}.
"""
)
-    zeroing_err_msg = (
-        "zeroing the two tensordicts did not make them identical. "
-        f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
-    )
+
+    def zeroing_err_msg():
+        return (
+            "zeroing the two tensordicts did not make them identical. "
+            f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}"
+        )
+
from torchrl.envs.common import _has_dynamic_specs

if _has_dynamic_specs(env.specs):
@@ -799,7 +804,7 @@
):
fake = fake.apply(lambda x, y: x.expand_as(y), real)
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
-                raise AssertionError(zeroing_err_msg)
+                raise AssertionError(zeroing_err_msg())

# Checks shapes and eventually dtypes of keys at all nesting levels
_per_level_env_check(fake, real, check_dtype=check_dtype)
@@ -809,7 +814,7 @@
torch.zeros_like(fake_tensordict_select)
!= torch.zeros_like(real_tensordict_select)
).any():
-            raise AssertionError(zeroing_err_msg)
+            raise AssertionError(zeroing_err_msg())

# Checks shapes and eventually dtypes of keys at all nesting levels
_per_level_env_check(
@@ -1028,14 +1033,14 @@ class MarlGroupMapType(Enum):
ALL_IN_ONE_GROUP = 1
ONE_GROUP_PER_AGENT = 2

-    def get_group_map(self, agent_names: List[str]):
+    def get_group_map(self, agent_names: list[str]):
if self == MarlGroupMapType.ALL_IN_ONE_GROUP:
return {"agents": agent_names}
elif self == MarlGroupMapType.ONE_GROUP_PER_AGENT:
return {agent_name: [agent_name] for agent_name in agent_names}


-def check_marl_grouping(group_map: Dict[str, List[str]], agent_names: List[str]):
+def check_marl_grouping(group_map: dict[str, list[str]], agent_names: list[str]):
"""Check MARL group map.

Performs checks on the group map of a marl environment to assess its validity.
@@ -1379,7 +1384,7 @@ def skim_through(td, reset=reset):
def _update_during_reset(
tensordict_reset: TensorDictBase,
tensordict: TensorDictBase,
-    reset_keys: List[NestedKey],
+    reset_keys: list[NestedKey],
):
"""Updates the input tensordict with the reset data, based on the reset keys."""
if not reset_keys:
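
For reference, here is a minimal usage sketch of the patched `check_env_specs`. It is not part of the diff: the Gymnasium backend and the `Pendulum-v1` env are assumptions chosen for illustration, and the comments only restate what this PR changes.

```python
# Minimal sketch, assuming gymnasium and the Pendulum-v1 env are available.
from torchrl.envs import GymEnv
from torchrl.envs.utils import check_env_specs

env = GymEnv("Pendulum-v1")

# With this PR, return_contiguous defaults to None and is resolved from the
# env's dynamic specs, and the batch-size broadcasting branch is guarded by
# the public env.batch_locked property instead of the private env._batch_locked.
check_env_specs(env, seed=0)
```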
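
The `zeroing_err_msg` change follows a general pattern: defer building an expensive failure message (here, rendering two full tensordicts) until the assertion actually fails. A generic sketch of that pattern, with made-up `compare`/`mismatch_msg` names, could look like this:

```python
# Generic sketch of the lazy error-message pattern applied in this PR.
def compare(fake, real) -> None:
    def mismatch_msg() -> str:
        # The f-string, which may render large objects, is only evaluated
        # when the check below actually fails.
        return f"objects differ:\nFake=\n{fake}\nReal=\n{real}"

    if fake != real:
        raise AssertionError(mismatch_msg())


compare({"a": 1}, {"a": 1})   # passes; the message is never built
# compare({"a": 1}, {"a": 2})  # would raise with the lazily built message
```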