Skip to content

Commit 27c190b

Browse files
committed
[Feature] DataLoadingPrimer.repeat
ghstack-source-id: df7ee5caf2850303068d073e0c7cf09d8941c5d3 Pull Request resolved: #2822
1 parent 2e74593 commit 27c190b

File tree

4 files changed

+129
-9
lines changed

4 files changed

+129
-9
lines changed

test/test_env.py

+98
Original file line numberDiff line numberDiff line change
@@ -4700,6 +4700,104 @@ def policy(td):
47004700
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
47014701
assert r.ndim == 1
47024702

4703+
@pytest.mark.parametrize(
4704+
"str2str,stack_method",
4705+
[
4706+
[True, None],
4707+
[False, "as_padded_tensor"],
4708+
# TODO: a bit experimental, fails with check_env_specs
4709+
# [False, "as_nested_tensor"],
4710+
[False, None],
4711+
],
4712+
)
4713+
@pytest.mark.parametrize("batched", [True, False])
4714+
@pytest.mark.parametrize("device", [None, "cpu"])
4715+
@pytest.mark.parametrize("batch_size", [0, 4])
4716+
@pytest.mark.parametrize("repeats", [3])
4717+
def test_llm_from_dataloader_repeats(
4718+
self, str2str, batched, stack_method, device, batch_size, repeats
4719+
):
4720+
if str2str:
4721+
kwargs = {
4722+
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4723+
"data_keys": ["observation"],
4724+
"example_data": "a string!",
4725+
"repeats": repeats,
4726+
}
4727+
else:
4728+
if stack_method is None:
4729+
stack_method = as_padded_tensor
4730+
kwargs = {
4731+
"dataloader": self.DummyTensorDataLoader(
4732+
padding=True, batch_size=batch_size
4733+
),
4734+
"data_keys": ["observation"],
4735+
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
4736+
"stack_method": stack_method,
4737+
"repeats": repeats,
4738+
}
4739+
kwargs.update({"str2str": str2str, "device": device})
4740+
env = LLMEnv.from_dataloader(**kwargs)
4741+
assert env.transform.repeats == repeats
4742+
4743+
max_steps = 3
4744+
env.append_transform(StepCounter(max_steps=max_steps))
4745+
4746+
def policy(td):
4747+
if str2str:
4748+
if not td.shape:
4749+
td["action"] = "<nothing>"
4750+
else:
4751+
td["action"] = NonTensorStack(
4752+
*["<nothing>" for _ in range(td.shape[0])]
4753+
)
4754+
else:
4755+
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4756+
return td
4757+
4758+
if batched:
4759+
r = env.rollout(
4760+
100,
4761+
policy,
4762+
tensordict=TensorDict(batch_size=[3]),
4763+
break_when_any_done=False,
4764+
)
4765+
else:
4766+
r = env.rollout(100, policy, break_when_any_done=False)
4767+
# check that r at reset is always the same
4768+
r_reset = r[..., ::max_steps]
4769+
if not batched:
4770+
if str2str:
4771+
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4772+
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4773+
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4774+
else:
4775+
assert (
4776+
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4777+
).all()
4778+
assert (
4779+
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4780+
).all()
4781+
assert (
4782+
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4783+
).any()
4784+
else:
4785+
# When batched, each block contains the 3 reset packs
4786+
if str2str:
4787+
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4788+
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4789+
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4790+
else:
4791+
assert (
4792+
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4793+
).all()
4794+
assert (
4795+
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4796+
).all()
4797+
assert (
4798+
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4799+
).any()
4800+
47034801

47044802
if __name__ == "__main__":
47054803
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/envs/custom/llm.py

+5
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def from_dataloader(
142142
example_data: Any = None,
143143
stack_method: Callable[[Any], Any]
144144
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
145+
repeats: int | None = None,
145146
) -> LLMEnv:
146147
"""Creates an LLMEnv instance from a dataloader.
147148
@@ -165,6 +166,9 @@ def from_dataloader(
165166
example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``.
166167
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
167168
method to use for stacking the data. Defaults to ``None``.
169+
repeats (int, optional): How many times the same sample should appear successively. This can be useful in
170+
situations like GRPO, where a single prompt is used multiple times to estimate the advantage with Monte Carlo
171+
samples (rather than an advantage module).
168172
169173
Returns:
170174
LLMEnv: The created LLMEnv instance.
@@ -178,6 +182,7 @@ def from_dataloader(
178182
data_specs=data_specs,
179183
example_data=example_data,
180184
stack_method=stack_method,
185+
repeats=repeats,
181186
)
182187
env = LLMEnv(
183188
str2str=str2str,

torchrl/envs/transforms/rlhf.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class DataLoadingPrimer(TensorDictPrimer):
103103
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
104104
tensordict returned by the transform will be automatically determined assuming that there is a single batch
105105
dimension.
106+
repeats (int, optional): How many times the same sample should appear successively. This can be useful in
107+
situations like GRPO, where a single prompt is used multiple times to estimate the advantage with Monte Carlo
108+
samples (rather than an advantage module).
106109
107110
Attributes:
108111
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -359,15 +362,21 @@ def __init__(
359362
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
360363
use_buffer: bool | None = None,
361364
auto_batch_size: bool = True,
365+
repeats: int | None = None,
362366
):
363367
self.dataloader = dataloader
364-
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
368+
if repeats is None:
369+
repeats = 0
370+
self.repeats = repeats
371+
if (
372+
getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None
373+
) or repeats > 0:
365374
use_buffer = True
366375

367376
self.use_buffer = use_buffer
368377
# No auto_batch_size if we know we have a single element
369378
self.auto_batch_size = auto_batch_size and (
370-
getattr(dataloader, "dataloader", 1) > 0
379+
getattr(dataloader, "batch_size", 1) > 0
371380
)
372381
self.endless_dataloader = self._endless_iter(self.dataloader)
373382
if primers is None:
@@ -420,11 +429,13 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
420429
if not reset.any():
421430
raise RuntimeError("reset must have at least one True value.")
422431
if reset.ndim > 0:
423-
return self.stack_method(
424-
[self._load_from_dataloader() for i in range(reset.sum())]
425-
)
432+
loaded = [self._load_from_dataloader() for i in range(reset.sum())]
433+
return self.stack_method(loaded)
434+
426435
if self.use_buffer and len(self._queue) > 0:
427-
return self._queue.popleft()
436+
result = self._queue.popleft()
437+
return result
438+
428439
data = next(self.endless_dataloader)
429440
# Some heuristic here:
430441
# if data is a map, assume its keys match the keys in spec
@@ -450,7 +461,11 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
450461
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
451462
)
452463
if self.use_buffer:
453-
self._queue.extend(out.unbind(0))
464+
if not out.ndim:
465+
out = out.unsqueeze(0)
466+
self._queue.extend(
467+
[d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
468+
)
454469
return self._queue.popleft()
455470
return out
456471

torchrl/envs/transforms/transforms.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -7352,7 +7352,9 @@ def _reset(
73527352
else:
73537353
# It may be the case that reset did not provide a done state, in which case
73547354
# we fall back on the spec
7355-
done = self.parent.output_spec["full_done_spec", entry_name].zero()
7355+
done = self.parent.output_spec_unbatched[
7356+
"full_done_spec", entry_name
7357+
].zero(tensordict_reset.shape)
73567358
reset = torch.ones_like(done)
73577359

73587360
step_count = tensordict.get(step_count_key, default=None)
@@ -7362,7 +7364,7 @@ def _reset(
73627364
step_count = step_count.to(reset.device, non_blocking=True)
73637365

73647366
# zero the step count if reset is needed
7365-
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
7367+
step_count = torch.where(~reset, step_count.expand_as(reset), 0)
73667368
tensordict_reset.set(step_count_key, step_count)
73677369
if self.max_steps is not None:
73687370
truncated = step_count >= self.max_steps

0 commit comments

Comments
 (0)