
Commit 528986b

[Feature] batch_size, reward, done, attention_key in LLMEnv
ghstack-source-id: 90e0ff9
Pull Request resolved: #2824
1 parent 53065cf commit 528986b
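For orientation, the kwargs exercised by the new test in test/test_env.py (shown below) suggest the following usage sketch of the added from_dataloader arguments. This is an illustration inferred from the test, not documented API: the dataloader is a placeholder and the import paths are assumptions; note that the test expects a ValueError mentioning "str2str" when str2str=True is combined with assign_reward/assign_done.

# Hypothetical sketch of the new LLMEnv.from_dataloader keyword arguments
# (assign_reward, assign_done, repeats), inferred from the test added below.
# The dataloader and import paths are assumptions, not part of this commit.
import torch
from torchrl.envs import LLMEnv  # import path assumed
from torchrl.envs.transforms import StepCounter

dataloader = ...  # placeholder: any dataloader yielding tokenized "observation" data

env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    data_keys=["observation"],
    str2str=False,        # tensor (token) observations; str2str=True raises here per the test
    repeats=3,            # each prompt is repeated 3 times
    assign_reward=True,   # env exposes a reward entry (new in this commit)
    assign_done=True,     # env exposes done/terminated entries (new in this commit)
)
env.append_transform(StepCounter(max_steps=10))  # possible because done entries now exist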


11 files changed: +258 −51 lines


test/test_env.py

+75
@@ -4861,6 +4861,81 @@ def policy(td):
             r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
         ).any()
 
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True])
+    @pytest.mark.parametrize("device", [None])
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("repeats", [3])
+    @pytest.mark.parametrize(
+        "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
+    )
+    def test_done_and_reward(
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        repeats,
+        assign_reward,
+        assign_done,
+    ):
+        with pytest.raises(
+            ValueError, match="str2str"
+        ) if str2str else contextlib.nullcontext():
+            if str2str:
+                kwargs = {
+                    "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                    "data_keys": ["observation"],
+                    "example_data": "a string!",
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            else:
+                if stack_method is None:
+                    stack_method = as_padded_tensor
+                kwargs = {
+                    "dataloader": self.DummyTensorDataLoader(
+                        padding=True, batch_size=batch_size
+                    ),
+                    "data_keys": ["observation"],
+                    "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                    "stack_method": stack_method,
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            kwargs.update({"str2str": str2str, "device": device})
+            env = LLMEnv.from_dataloader(**kwargs)
+            # We want to make sure that transforms that rely on the done state work appropriately
+            env.append_transform(StepCounter(max_steps=10))
+
+            def policy(td):
+                td["action"] = torch.ones(
+                    td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
+                )
+                return td
+
+            if batched:
+                r = env.rollout(
+                    100,
+                    policy,
+                    tensordict=TensorDict(batch_size=[3]),
+                    break_when_any_done=False,
+                )
+            else:
+                r = env.rollout(100, policy, break_when_any_done=False)
+            if assign_done:
+                assert "terminated" in r
+                assert "done" in r
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
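As a reading aid, the checks at the end of the test can be restated and slightly extended as follows. This is a continuation sketch that reuses the env and policy defined in the diff above; the nested-key layout (flags repeated under "next") follows standard TorchRL rollout conventions and is an assumption here, not something this commit asserts.

# Continuation sketch, reusing `env` and `policy` from the test above.
# Assumption: done/terminated flags follow the usual root + ("next", ...) layout.
r = env.rollout(100, policy, break_when_any_done=False)
assert "done" in r and "terminated" in r     # created because assign_done=True
assert ("next", "done") in r.keys(True)      # post-step flags live under "next"
# StepCounter(max_steps=10) truncates each trajectory, and the rollout keeps
# going past every truncation because break_when_any_done=False.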

torchrl/data/map/tdstorage.py

-1
@@ -128,7 +128,6 @@ def __init__(
         self.in_keys = query_module.in_keys
         if out_keys is not None:
             self.out_keys = out_keys
-            assert not self._has_lazy_out_keys()
 
         self.query_module = query_module
         self.index_key = query_module.index_key

torchrl/data/postprocs/postprocs.py

+4-3
@@ -11,8 +11,6 @@
 from tensordict.utils import expand_right
 from torch import nn
 
-from torchrl.objectives.value.functional import reward2go
-
 
 def _get_reward(
     gamma: float,
@@ -367,13 +365,16 @@ def __init__(
         time_dim: int = 2,
         discount: float = 1.0,
     ):
+        from torchrl.objectives.value.functional import reward2go
+
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
             reward_key_out = reward_key
         self.out_keys = [unravel_key(reward_key_out)]
         self.time_dim = time_dim
         self.discount = discount
+        self.reward2go = reward2go
 
     def forward(self, tensordict):
         # Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
                 f"reward and done state are expected to have the same shape. Got reard.shape={reward.shape} "
                 f"and done.shape={done.shape}."
             )
-        reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
+        reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
         tensordict.set(("next", self.out_keys[0]), reward)
         return tensordict
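This change moves the reward2go import from module scope into __init__ and stores the bound function on the instance. A plausible reason, not stated in the commit, is breaking an import cycle between torchrl.data.postprocs and torchrl.objectives at import time. A generic, self-contained sketch of that deferred-import pattern (class and method names here are illustrative, not torchrl API):

# Illustrative deferred-import pattern (names are hypothetical):
# the potentially circular dependency is only imported when an instance is built.
class LazyRewardToGo:
    def __init__(self, discount: float = 1.0):
        # Imported here rather than at module level, so importing this module
        # never triggers the (possibly circular) torchrl.objectives import chain.
        from torchrl.objectives.value.functional import reward2go

        self.discount = discount
        self.reward2go = reward2go  # cached for later calls

    def forward(self, reward, done):
        # Same call shape as the patched forward() above.
        return self.reward2go(reward, done, time_dim=-2, gamma=self.discount)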

torchrl/envs/common.py

+6-2
@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
             if reset_value is not None:
                 for done_key in done_key_group:
                     done_val = tensordict_reset.get(done_key)
-                    if done_val[reset_value].any() and not self._allow_done_after_reset:
+                    if (
+                        done_val.any()
+                        and done_val[reset_value].any()
+                        and not self._allow_done_after_reset
+                    ):
                         raise RuntimeError(
                             f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                         )
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         any_done = self.any_done(tensordict)
         if any_done:
-            return self.reset(tensordict, select_reset_only=True)
+            tensordict = self.reset(tensordict, select_reset_only=True)
         return tensordict
 
     def empty_cache(self):
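The _reset_check_done hunk adds a done_val.any() guard in front of the masked lookup. One reading, hedged because the commit does not explain it, is that the cheap all-False test short-circuits the indexing entirely when nothing is done after reset. A toy illustration of the equivalence:

# Toy illustration of the added short-circuit (assumption about intent: skip the
# masked lookup when no done flag is set at all after reset).
import torch

done_val = torch.zeros(4, 1, dtype=torch.bool)   # no sub-env reports done
reset_value = torch.ones(4, dtype=torch.bool)    # every sub-env was just reset

# Previous check: always pays for the masked indexing.
old = done_val[reset_value].any()
# New check: the cheap all-False test short-circuits before indexing.
new = done_val.any() and done_val[reset_value].any()

assert bool(old) is False and bool(new) is False  # same outcome, less work when all-False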
