Skip to content

Commit 028d4e2

Browse files
committed
[Feature] DataLoadingPrimer.repeat
ghstack-source-id: c17a24a4594db737cae51e8897d215295aa52d03 Pull Request resolved: #2822
1 parent 73c7b0a commit 028d4e2

File tree

4 files changed

+129
-9
lines changed

4 files changed

+129
-9
lines changed

test/test_env.py

+98
Original file line number | Diff line number | Diff line change
@@ -4735,6 +4735,104 @@ def policy(td):
47354735
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
47364736
assert r.ndim == 1
47374737

4738+
@pytest.mark.parametrize(
    "str2str,stack_method",
    [
        [True, None],
        [False, "as_padded_tensor"],
        # TODO: a bit experimental, fails with check_env_specs
        # [False, "as_nested_tensor"],
        [False, None],
    ],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("repeats", [3])
def test_llm_from_dataloader_repeats(
    self, str2str, batched, stack_method, device, batch_size, repeats
):
    """Check that ``LLMEnv.from_dataloader(..., repeats=N)`` replays each sample N times in a row.

    With ``repeats=3`` and a ``StepCounter(max_steps=3)`` truncating every
    trajectory after 3 steps, consecutive resets should yield the same
    observation 3 times before moving on to the next dataloader sample.
    """
    # Build the env kwargs for the string pipeline (str2str=True) or the
    # tensor pipeline (str2str=False); both forward `repeats` to the primer.
    if str2str:
        kwargs = {
            "dataloader": self.DummyDataLoader(batch_size=batch_size),
            "data_keys": ["observation"],
            "example_data": "a string!",
            "repeats": repeats,
        }
    else:
        if stack_method is None:
            # Default tensor stacking strategy when none was parametrized.
            stack_method = as_padded_tensor
        kwargs = {
            "dataloader": self.DummyTensorDataLoader(
                padding=True, batch_size=batch_size
            ),
            "data_keys": ["observation"],
            "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
            "stack_method": stack_method,
            "repeats": repeats,
        }
    kwargs.update({"str2str": str2str, "device": device})
    env = LLMEnv.from_dataloader(**kwargs)
    # The `repeats` argument must have been propagated to the data-loading transform.
    assert env.transform.repeats == repeats

    # Truncate every trajectory after `max_steps` steps so a long rollout
    # contains several back-to-back resets.
    max_steps = 3
    env.append_transform(StepCounter(max_steps=max_steps))

    def policy(td):
        # Dummy policy: emit a constant "<nothing>" string action (str2str)
        # or a ones tensor action, matching the tensordict's batch shape.
        if str2str:
            if not td.shape:
                td["action"] = "<nothing>"
            else:
                td["action"] = NonTensorStack(
                    *["<nothing>" for _ in range(td.shape[0])]
                )
        else:
            td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
        return td

    if batched:
        # Batched rollout over 3 parallel trajectories; keep rolling through
        # resets so several samples (and their repeats) are consumed.
        r = env.rollout(
            100,
            policy,
            tensordict=TensorDict(batch_size=[3]),
            break_when_any_done=False,
        )
    else:
        r = env.rollout(100, policy, break_when_any_done=False)
    # check that r at reset is always the same
    # Every `max_steps`-th step along the time dim is a post-reset frame.
    r_reset = r[..., ::max_steps]
    if not batched:
        # Unbatched: resets 0, 1, 2 must replay the same sample (repeats=3);
        # reset 3 starts the next sample, hence must differ.
        if str2str:
            assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
            assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
            assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
        else:
            # Tensor observations: compare elementwise; `!=` only needs to
            # hold somewhere, since padding can make some entries coincide.
            assert (
                r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
            ).all()
            assert (
                r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
            ).all()
            assert (
                r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
            ).any()
    else:
        # When batched, each block contains the 3 reset packs
        # i.e. the repeats are spread across the batch dim: rows 0..2 of the
        # first reset column share one sample; the next time column differs.
        if str2str:
            assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
            assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
            assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
        else:
            assert (
                r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
            ).all()
            assert (
                r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
            ).all()
            assert (
                r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
            ).any()
4835+
47384836

47394837
if __name__ == "__main__":
47404838
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/envs/custom/llm.py

+5
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,7 @@ def from_dataloader(
142142
example_data: Any = None,
143143
stack_method: Callable[[Any], Any]
144144
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
145+
repeats: int | None = None,
145146
) -> LLMEnv:
146147
"""Creates an LLMEnv instance from a dataloader.
147148
@@ -165,6 +166,9 @@ def from_dataloader(
165166
example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``.
166167
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
167168
method to use for stacking the data. Defaults to ``None``.
169+
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
170+
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
171+
samples (rather than an advantage module).
168172
169173
Returns:
170174
LLMEnv: The created LLMEnv instance.
@@ -178,6 +182,7 @@ def from_dataloader(
178182
data_specs=data_specs,
179183
example_data=example_data,
180184
stack_method=stack_method,
185+
repeats=repeats,
181186
)
182187
env = LLMEnv(
183188
str2str=str2str,

torchrl/envs/transforms/llm.py

+22-7
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,9 @@ class DataLoadingPrimer(TensorDictPrimer):
103103
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
104104
tensordict returned by the transform will be automatically determined assuming that there is a single batch
105105
dimension.
106+
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
107+
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
108+
samples (rather than an advantage module).
106109
107110
Attributes:
108111
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -359,15 +362,21 @@ def __init__(
359362
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
360363
use_buffer: bool | None = None,
361364
auto_batch_size: bool = True,
365+
repeats: int | None = None,
362366
):
363367
self.dataloader = dataloader
364-
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
368+
if repeats is None:
369+
repeats = 0
370+
self.repeats = repeats
371+
if (
372+
getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None
373+
) or repeats > 0:
365374
use_buffer = True
366375

367376
self.use_buffer = use_buffer
368377
# No auto_batch_size if we know we have a single element
369378
self.auto_batch_size = auto_batch_size and (
370-
getattr(dataloader, "dataloader", 1) > 0
379+
getattr(dataloader, "batch_size", 1) > 0
371380
)
372381
self.endless_dataloader = self._endless_iter(self.dataloader)
373382
if primers is None:
@@ -420,11 +429,13 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
420429
if not reset.any():
421430
raise RuntimeError("reset must have at least one True value.")
422431
if reset.ndim > 0:
423-
return self.stack_method(
424-
[self._load_from_dataloader() for i in range(reset.sum())]
425-
)
432+
loaded = [self._load_from_dataloader() for i in range(reset.sum())]
433+
return self.stack_method(loaded)
434+
426435
if self.use_buffer and len(self._queue) > 0:
427-
return self._queue.popleft()
436+
result = self._queue.popleft()
437+
return result
438+
428439
data = next(self.endless_dataloader)
429440
# Some heuristic here:
430441
# if data is a map, assume its keys match the keys in spec
@@ -450,7 +461,11 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
450461
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
451462
)
452463
if self.use_buffer:
453-
self._queue.extend(out.unbind(0))
464+
if not out.ndim:
465+
out = out.unsqueeze(0)
466+
self._queue.extend(
467+
[d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
468+
)
454469
return self._queue.popleft()
455470
return out
456471

torchrl/envs/transforms/transforms.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -7352,7 +7352,9 @@ def _reset(
73527352
else:
73537353
# It may be the case that reset did not provide a done state, in which case
73547354
# we fall back on the spec
7355-
done = self.parent.output_spec["full_done_spec", entry_name].zero()
7355+
done = self.parent.output_spec_unbatched[
7356+
"full_done_spec", entry_name
7357+
].zero(tensordict_reset.shape)
73567358
reset = torch.ones_like(done)
73577359

73587360
step_count = tensordict.get(step_count_key, default=None)
@@ -7362,7 +7364,7 @@ def _reset(
73627364
step_count = step_count.to(reset.device, non_blocking=True)
73637365

73647366
# zero the step count if reset is needed
7365-
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
7367+
step_count = torch.where(~reset, step_count.expand_as(reset), 0)
73667368
tensordict_reset.set(step_count_key, step_count)
73677369
if self.max_steps is not None:
73687370
truncated = step_count >= self.max_steps

0 commit comments

Comments
 (0)