From f90a4df5caba729d6c1ffdfb0c5b639bca509cfd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 16:33:21 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- test/test_env.py | 76 +++++++++++ torchrl/data/map/tdstorage.py | 1 - torchrl/data/postprocs/postprocs.py | 5 +- torchrl/envs/common.py | 8 +- torchrl/envs/custom/llm.py | 176 ++++++++++++++++++++++---- torchrl/envs/libs/openspiel.py | 2 - torchrl/envs/libs/unity_mlagents.py | 1 - torchrl/envs/transforms/rlhf.py | 2 + torchrl/envs/transforms/transforms.py | 2 +- torchrl/envs/utils.py | 6 +- torchrl/objectives/common.py | 1 - 11 files changed, 243 insertions(+), 37 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index b2c26914ca6..6f962538f95 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4798,6 +4798,82 @@ def policy(td): r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] ).any() + @pytest.mark.parametrize( + "str2str,stack_method", + [ + [True, None], + [False, "as_padded_tensor"], + ], + ) + @pytest.mark.parametrize("batched", [True]) + @pytest.mark.parametrize("device", [None]) + @pytest.mark.parametrize("batch_size", [4]) + @pytest.mark.parametrize("repeats", [3]) + @pytest.mark.parametrize( + "assign_reward,assign_done", [[True, False], [True, True], [False, True]] + ) + def test_done_and_reward( + self, + str2str, + batched, + stack_method, + device, + batch_size, + repeats, + assign_reward, + assign_done, + ): + with pytest.raises( + ValueError, match="str2str" + ) if str2str else contextlib.nullcontext(): + if str2str: + kwargs = { + "dataloader": self.DummyDataLoader(batch_size=batch_size), + "data_keys": ["observation"], + "example_data": "a string!", + "repeats": repeats, + "assign_reward": assign_reward, + "assign_done": assign_done, + } + else: + if stack_method is None: + stack_method = as_padded_tensor + kwargs = { + "dataloader": self.DummyTensorDataLoader( + padding=True, batch_size=batch_size + ), + "data_keys": ["observation"], + "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], + "stack_method": stack_method, + "repeats": repeats, + "assign_reward": assign_reward, + "assign_done": assign_done, + } + kwargs.update({"str2str": str2str, "device": device}) + env = LLMEnv.from_dataloader(**kwargs) + # We want to make sure that transforms that rely on the done state work appropriately + env.append_transform(StepCounter(max_steps=10)) + + def policy(td): + td["action"] = torch.ones( + td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 + ) + return td + + if batched: + r = env.rollout( + 100, + policy, + tensordict=TensorDict(batch_size=[3]), + break_when_any_done=False, + ) + else: + r = env.rollout(100, policy, break_when_any_done=False) + if assign_done: + assert "terminated" in r + assert "done" in r + print(r) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index f1464308144..1e8472260da 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -128,7 +128,6 @@ def __init__( self.in_keys = query_module.in_keys if out_keys is not None: self.out_keys = out_keys - assert not self._has_lazy_out_keys() self.query_module = query_module self.index_key = query_module.index_key diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 1868deb2c12..53d283dbdad 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -11,7 +11,6 @@ from 
tensordict.utils import expand_right
 from torch import nn
 
-from torchrl.objectives.value.functional import reward2go
 
 
 def _get_reward(
@@ -367,6 +366,7 @@ def __init__(
         time_dim: int = 2,
         discount: float = 1.0,
     ):
+        from torchrl.objectives.value.functional import reward2go
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
@@ -374,6 +374,7 @@ def __init__(
         self.out_keys = [unravel_key(reward_key_out)]
         self.time_dim = time_dim
         self.discount = discount
+        self.reward2go = reward2go
 
     def forward(self, tensordict):
         # Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
                 f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
                 f"and done.shape={done.shape}."
             )
-        reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
+        reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
         tensordict.set(("next", self.out_keys[0]), reward)
         return tensordict
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index c7377a84d9e..cdfbe5c19e3 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
             if reset_value is not None:
                 for done_key in done_key_group:
                     done_val = tensordict_reset.get(done_key)
-                    if done_val[reset_value].any() and not self._allow_done_after_reset:
+                    if (
+                        done_val.any()
+                        and done_val[reset_value].any()
+                        and not self._allow_done_after_reset
+                    ):
                         raise RuntimeError(
                             f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                         )
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         any_done = self.any_done(tensordict)
         if any_done:
-            return self.reset(tensordict, select_reset_only=True)
+            tensordict = self.reset(tensordict, select_reset_only=True)
         return tensordict
 
     def empty_cache(self):
diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py
index fc94db38216..7282a7f60b2 100644
--- a/torchrl/envs/custom/llm.py
+++ b/torchrl/envs/custom/llm.py
@@ -39,7 +39,7 @@ class LLMEnv(EnvBase):
     Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
     :meth:`~from_dataloader`
 
-    Args:
+    Keyword Args:
        observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults
            to ``"observation"``.
        action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``.
@@ -47,6 +47,12 @@ class LLMEnv(EnvBase):
        device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
        vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
            unbounded vocabulary. Defaults to ``None``.
+       no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past
+           observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed
+           to be the plain output of the LLM, including the input tokens / strings.
+       assign_reward (bool, optional): If ``True``, a zero-valued reward with the same shape as the action (one value per token) is written at each step. Defaults to ``False``.
+       assign_done (bool, optional): If ``True``, per-token ``done`` / ``terminated`` entries are written under the ``"tokens"`` sub-tensordict at each step, alongside the aggregated entries. Defaults to ``False``.
+       reward_key (NestedKey, optional): The key where the reward is written when ``assign_reward=True``. Defaults to ``"reward"``.
 
     .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.
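For orientation, here is a rough usage sketch of the keyword arguments documented above, in token (non-string) mode. It is only an illustration based on the specs and the `_step` logic added in this patch (the `LLMEnv` API is still in flux in this stack), not an excerpt from the diff, and the exact behavior may differ in the final version:

    import torch
    from tensordict import TensorDict
    from torchrl.envs.custom.llm import LLMEnv

    # Token mode: observations and actions are int64 token tensors.
    env = LLMEnv(
        token_key="observation",
        str2str=False,
        assign_reward=True,  # writes a zero reward shaped like the action tokens at each step
        assign_done=True,    # writes per-token done / terminated entries under ("tokens", ...)
    )
    # _reset expects the observation to be present already; it is normally
    # injected by a DataLoadingPrimer appended to the environment.
    td = env.reset(TensorDict({"observation": torch.randint(100, (5,))}, []))
    td["action"] = torch.randint(100, (3,))
    td = env.step(td)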
@@ -58,39 +64,58 @@ class LLMEnv(EnvBase): def __init__( self, *, - observation_key: NestedKey = "observation", + token_key: NestedKey = "observation", + attention_key: NestedKey | None = None, action_key: NestedKey = "action", str2str: bool = False, device: torch.device | None = None, vocab_size: int | None = None, + no_stack: bool = False, + assign_reward: bool = False, + assign_done: bool = False, + reward_key: NestedKey = "reward", + batch_size: int | None = None, ) -> None: - super().__init__(device=device) - self._batch_locked = False + if batch_size is None: + self._batch_locked = False + else: + self._batch_locked = True + super().__init__(device=device, batch_size=() if batch_size is None else (batch_size,)) self.str2str = str2str self.vocab_size = vocab_size - self.observation_key = unravel_key(observation_key) + self.observation_key = unravel_key(token_key) + self.attention_key = unravel_key(attention_key) + self.no_stack = no_stack + self.assign_reward = assign_reward + self.assign_done = assign_done + # self.action_key = unravel_key(action_key) if str2str: - self.observation_spec = Composite( + self.full_observation_spec_unbatched = Composite( { - observation_key: NonTensor( + token_key: NonTensor( example_data="a string", batched=True, shape=() ) } ) - self.action_spec = Composite( + self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} ) else: if vocab_size is None: - self.observation_spec = Composite( - { - observation_key: Unbounded( + observation_spec = { + token_key: Unbounded( shape=(-1,), dtype=torch.int64, device=device ) } + if attention_key is not None: + observation_spec[attention_key] = Unbounded( + shape=(-1,), dtype=torch.int64, device=device + ) + self.full_observation_spec_unbatched = Composite( + observation_spec ) - self.action_spec = Composite( + self.full_action_spec_unbatched = Composite( { action_key: Unbounded( shape=(-1,), dtype=torch.int64, device=device @@ -98,9 +123,9 @@ def __init__( } ) else: - self.observation_spec = Composite( + self.full_observation_spec_unbatched = Composite( { - observation_key: Bounded( + token_key: Bounded( shape=(-1,), dtype=torch.int64, low=0, @@ -109,7 +134,7 @@ def __init__( ) } ) - self.action_spec = Composite( + self.full_action_spec_unbatched = Composite( { action_key: Bounded( shape=(-1,), @@ -120,22 +145,54 @@ def __init__( ) } ) - self.full_done_spec = Composite( - done=Unbounded(shape=(1,), dtype=torch.bool), - truncated=Unbounded(shape=(1,), dtype=torch.bool), - terminated=Unbounded(shape=(1,), dtype=torch.bool), + STR2STR_ERR = ValueError( + "str2str cannot be True when either of assign_reward / assign_done are True. " + "Tokens are required to compute the reward shape." 
+        )
+        if self.assign_reward:
+            if self.str2str:
+                raise STR2STR_ERR
+            self.full_reward_spec_unbatched = Composite(
+                {reward_key: Unbounded(shape=(-1,), device=device)}
+            )
+        else:
+            self.full_reward_spec_unbatched = Composite(device=device)
+
+        if not self.assign_done:
+            # Use single done
+            self.full_done_spec_unbatched = Composite(
+                done=Unbounded(shape=(-1,), dtype=torch.bool),
+                terminated=Unbounded(shape=(-1,), dtype=torch.bool),
+            )
+        elif self.str2str:
+            raise STR2STR_ERR
+        else:
+            # Use token-level done entries alongside the aggregated done
+            self.full_done_spec_unbatched = Composite(
+                tokens=Composite(
+                    done=Unbounded(shape=(-1,), dtype=torch.bool),
+                    terminated=Unbounded(shape=(-1,), dtype=torch.bool),
+                ),
+                done=Unbounded(shape=(1,), dtype=torch.bool),
+                terminated=Unbounded(shape=(1,), dtype=torch.bool),
+            )
 
     @classmethod
     def from_dataloader(
         cls,
         dataloader: DataLoader,
         *,
-        observation_key: NestedKey = "observation",
+        token_key: NestedKey = "observation",
+        attention_key: NestedKey | None = None,
         action_key: NestedKey = "action",
         str2str: bool = False,
         device: torch.device | None = None,
         vocab_size: int | None = None,
+        no_stack: bool = False,
+        batch_size: int | None = None,
+        assign_reward: bool = False,
+        assign_done: bool = False,
+        reward_key: NestedKey = "reward",
         primers: Composite | None = None,
         data_keys: list[NestedKey] | None = None,
         data_specs: list[TensorSpec] | None = None,
@@ -150,13 +207,21 @@ def from_dataloader(
 
         Args:
             dataloader (DataLoader): The dataloader to load data from.
-            observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults
+            token_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults
                to ``"observation"``.
+            attention_key (NestedKey | None, optional): The key in the tensordict where the attention mask is stored. Defaults to ``None`` (no attention mask).
            action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``.
            str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``.
            device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
            vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
                unbounded vocabulary. Defaults to ``None``.
+            no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past
+                observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed
+                to be the plain output of the LLM, including the input tokens / strings.
+            assign_reward (bool, optional): If ``True``, a zero-valued reward with the same shape as the action (one value per token) is written at each step. Defaults to ``False``.
+            assign_done (bool, optional): If ``True``, per-token ``done`` / ``terminated`` entries are written under the ``"tokens"`` sub-tensordict at each step, alongside the aggregated entries. Defaults to ``False``.
+            reward_key (NestedKey, optional): The key where the reward is written when ``assign_reward=True``. Defaults to ``"reward"``.
+            batch_size (int, optional): The batch size of the environment. If ``None`` (default), the environment is not batch-locked.
            primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to ``None``.
            data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed
                ``observation_key`` will be populated with the data.
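The new test_done_and_reward test earlier in this patch exercises these arguments through `from_dataloader`; a condensed sketch of that pattern follows. The dataloader, `as_padded_tensor` and `policy` here stand in for the test utilities defined in test_env.py, so this is illustrative rather than copy-pasteable:

    env = LLMEnv.from_dataloader(
        dataloader=DummyTensorDataLoader(padding=True, batch_size=4),  # test helper, assumed
        data_keys=["observation"],
        data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
        stack_method=as_padded_tensor,
        repeats=3,
        str2str=False,
        assign_reward=True,
        assign_done=True,
    )
    # Transforms relying on the done state (e.g. StepCounter) can then be appended.
    env.append_transform(StepCounter(max_steps=10))
    rollout = env.rollout(100, policy, break_when_any_done=False)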
@@ -178,7 +243,7 @@ def from_dataloader(
         primer = DataLoadingPrimer(
             dataloader=dataloader,
             primers=primers,
-            data_keys=data_keys if data_keys is not None else [observation_key],
+            data_keys=data_keys if data_keys is not None else [token_key],
             data_specs=data_specs,
             example_data=example_data,
             stack_method=stack_method,
@@ -187,9 +252,15 @@ def from_dataloader(
         env = LLMEnv(
             str2str=str2str,
             device=device,
-            observation_key=observation_key,
+            token_key=token_key,
+            attention_key=attention_key,
             action_key=action_key,
             vocab_size=vocab_size,
+            no_stack=no_stack,
+            assign_reward=assign_reward,
+            assign_done=assign_done,
+            reward_key=reward_key,
+            batch_size=batch_size,
         )
         return env.append_transform(primer)
 
@@ -205,6 +276,59 @@ def _step(
         self,
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
+        next_td = tensordict.empty()
+        self._make_next_obs(tensordict, next_td)
+        self._maybe_make_reward(tensordict, next_td)
+        self._maybe_make_done(tensordict, next_td)
+        return next_td
+
+    def _maybe_make_reward(
+        self, tensordict: TensorDictBase, next_td: TensorDictBase
+    ) -> TensorDictBase:
+        if self.assign_reward:
+            next_td.set(
+                self.reward_key,
+                torch.zeros_like(
+                    tensordict.get(self.action_key), dtype=self.reward_spec.dtype
+                ),
+            )
+        return next_td
+
+    def _maybe_make_done(
+        self, tensordict: TensorDictBase, next_td: TensorDictBase
+    ) -> TensorDictBase:
+        if self.assign_done:
+            action = tensordict.get(self.action_key)
+            if action is None:
+                done = torch.zeros(
+                    tensordict.shape + (1,), dtype=torch.bool, device=self.device
+                )
+            else:
+                done = torch.zeros_like(action, dtype=torch.bool)
+            next_td.set(("tokens", "terminated"), done)
+            next_td.set(("tokens", "done"), done.clone())
+            next_td.set(
+                "done", next_td.get(("tokens", "done")).any(-1, keepdim=True)
+            )
+            next_td.set(
+                "terminated",
+                next_td.get(("tokens", "terminated")).any(-1, keepdim=True),
+            )
+        return next_td
+
+    def _make_next_obs(
+        self, tensordict: TensorDictBase, nex_td: TensorDictBase
+    ) -> TensorDictBase:
+        if self.no_stack:
+            action = tensordict.get(self.action_key)
+            nex_td.set(self.observation_key, action)
+            if self.attention_key is not None:
+                attention_mask = tensordict.get(self.attention_key)
+                n = action.shape[-1] - attention_mask.shape[-1]
+                attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1)
+                nex_td.set(self.attention_key, attention_mask)
+            return nex_td
+
         # Cat action entry with prev obs
         if self.str2str:
             obs = tensordict[self.observation_key]
@@ -256,10 +380,11 @@ def _step(
                 "Failed to cat action and observation tensors. Check that str2str argument is correctly "
                 f"set in {type(self).__name__}."
             )
-        return tensordict.empty().set(self.observation_key, observation)
+        return nex_td.set(self.observation_key, observation)
 
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         # We should have an observation by this time, if not raise an exception
+        print('tensordict', tensordict)
         if tensordict is None or self.observation_key not in tensordict.keys(
             isinstance(self.observation_key, tuple)
         ):
            raise KeyError(
                f"Observation key {self.observation_key} is not defined. Make sure a TensorDictPrimer (eg, "
                f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
) - return tensordict.copy() + td_reset = tensordict.copy() + return self._maybe_make_done(tensordict, td_reset) def _set_seed(self, seed: int | None): return seed diff --git a/torchrl/envs/libs/openspiel.py b/torchrl/envs/libs/openspiel.py index 7dbfeac07a8..d55059dc4fa 100644 --- a/torchrl/envs/libs/openspiel.py +++ b/torchrl/envs/libs/openspiel.py @@ -470,8 +470,6 @@ def _step_sequential(self, tensordict: TensorDictBase): agent_index_in_group = agents.index(agent) break - assert agent_group is not None - action_tensor = tensordict[agent_group, "action"][agent_index_in_group] action = self._get_action_from_tensor(action_tensor) self._env.apply_action(action) diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 149c20606ac..ab3b568cd87 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -142,7 +142,6 @@ def _collect_agents(self, env): # Sometimes in an MLAgents environment, an agent may # show up in both the decision steps and the terminal # steps. When that happens, just skip the duplicate. - assert is_terminal continue agent_name_to_behavior_map[agent_name] = behavior agent_name_to_group_id_map[agent_name] = group_id diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index b4c8712299e..85bdd1f4bbc 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -412,6 +412,7 @@ def __init__( single_default_value=True, call_before_env_reset=True, ) + self._reset_key = "_reset" if self.use_buffer: self._queue = deque() @@ -460,6 +461,7 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): raise ValueError( f"Unrecognized data type: {type(data)} with keys {self.data_keys}." ) + print('out', out) if self.use_buffer: if not out.ndim: out = out.unsqueeze(0) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 4930393f4b5..732fb203619 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6134,7 +6134,7 @@ def __init__( @property def reset_key(self): - reset_key = self.__dict__.get("_reset_key", None) + reset_key = self.__dict__.get("_reset_key") if reset_key is None: if self.parent is None: raise RuntimeError( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index b374a4594e8..e2d2d7bfe93 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -200,7 +200,7 @@ def _repr_key_list_as_tree(key_list): """Represents the keys as a tree to facilitate iteration.""" if not key_list: return {} - key_dict = {key: torch.zeros(()) for key in key_list} + key_dict = {key: torch.zeros((0,)) for key in key_list} td = TensorDict(key_dict, batch_size=torch.Size([])) return tree_map(lambda x: None, td.to_dict()) @@ -1390,6 +1390,7 @@ def _update_during_reset( if not reset_keys: return tensordict.update(tensordict_reset) roots = set() + print("reset_keys", reset_keys) for reset_key in reset_keys: # get the node of the reset key if isinstance(reset_key, tuple): @@ -1400,11 +1401,12 @@ def _update_during_reset( node = tensordict.get(node_key) reset_key_tuple = reset_key else: - node_reset = tensordict_reset + node_reset = tensordict_reset.exclude(reset_key) node = tensordict reset_key_tuple = (reset_key,) # get the reset signal reset = tensordict.pop(reset_key, None) + print("reset popped", reset) # check if this reset should be ignored -- this happens whenever the # root node has already been updated diff --git a/torchrl/objectives/common.py 
b/torchrl/objectives/common.py index 50fc7ee7fba..2a74fac8f23 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -412,7 +412,6 @@ def _compare_and_expand(param): params.set(key, parameter.data) setattr(self, param_name, params) - assert getattr(self, param_name) is params, getattr(self, param_name) # Set the module in the __dict__ directly to avoid listing its params # A deepcopy with meta device could be used but that assumes that the model is copyable! From f16655fdcca70a1d87f2b9aad9caec84570ef987 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 6 Mar 2025 14:29:56 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torchrl/envs/custom/llm.py | 10 +++++++++- torchrl/envs/transforms/transforms.py | 11 +++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 7282a7f60b2..d118ae32c98 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -325,7 +325,15 @@ def _make_next_obs( if self.attention_key is not None: attention_mask = tensordict.get(self.attention_key) n = action.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1) + if n > 0: + # It can happen that there's only one action (eg rand_action) + attention_mask = torch.cat( + [ + attention_mask, + attention_mask.new_ones(attention_mask.shape[:-1] + (n,)), + ], + -1, + ) nex_td.set(self.attention_key, attention_mask) return nex_td diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 732fb203619..0fffe6ccd7e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -6361,10 +6361,13 @@ def _reset_func( def __repr__(self) -> str: class_name = self.__class__.__name__ - default_value = { - key: value if isinstance(value, float) else "Callable" - for key, value in self.default_value.items() - } + if callable(self.default_value): + default_value = self.default_value + else: + default_value = { + key: value if isinstance(value, float) else "Callable" + for key, value in self.default_value.items() + } return f"{class_name}(primers={self.primers}, default_value={default_value}, random={self.random})"
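As a side note on the second commit: the `n > 0` guard above protects the attention-mask extension when the action is not longer than the current mask (e.g. a single random action). A self-contained illustration of that shape arithmetic, with made-up sizes rather than values from the diff:

    import torch

    action = torch.ones(2, 7, dtype=torch.int64)           # prompt tokens plus newly generated tokens
    attention_mask = torch.ones(2, 5, dtype=torch.int64)   # mask covering the prompt tokens only
    n = action.shape[-1] - attention_mask.shape[-1]         # number of new tokens to mark as visible
    if n > 0:
        # Extend the mask with ones so it matches the new observation length.
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1
        )
    assert attention_mask.shape == action.shape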