[Feature] DataLoadingPrimer handling of dataloader with batch-size > 0 #2821

Merged · 2 commits · Mar 3, 2025
71 changes: 65 additions & 6 deletions test/test_env.py
@@ -33,6 +33,7 @@
TensorDictBase,
)
from tensordict.nn import TensorDictModuleBase
from tensordict.tensorclass import NonTensorStack
from tensordict.utils import _unravel_key_to_tuple
from torch import nn

@@ -4577,20 +4578,23 @@ def __next__(self):
],
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("batch_size", [0, 4])
@pytest.mark.parametrize("device", [None, "cpu"])
def test_llm_env(self, str2str, batched, stack_method, device):
def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
env = LLMEnv(str2str=str2str, device=device)
if str2str:
primer = DataLoadingPrimer(
dataloader=self.DummyDataLoader(),
dataloader=self.DummyDataLoader(batch_size=batch_size),
data_keys=["observation"],
example_data="a string!",
)
else:
if stack_method is None:
stack_method = as_padded_tensor
primer = DataLoadingPrimer(
dataloader=self.DummyTensorDataLoader(padding=True),
dataloader=self.DummyTensorDataLoader(
batch_size=batch_size, padding=True
),
data_keys=["observation"],
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
stack_method=stack_method,
@@ -4601,6 +4605,7 @@ def test_llm_env(self, str2str, batched, stack_method, device):
if batched:
td = env.reset(TensorDict(batch_size=[3]))
env.check_env_specs(break_when_any_done="both", tensordict=td)
r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
else:
env.check_env_specs(break_when_any_done="both")

@@ -4616,18 +4621,23 @@ def test_llm_env(self, str2str, batched, stack_method, device):
)
@pytest.mark.parametrize("batched", [True, False])
@pytest.mark.parametrize("device", [None, "cpu"])
def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
@pytest.mark.parametrize("batch_size", [0, 4])
def test_llm_from_dataloader(
self, str2str, batched, stack_method, device, batch_size
):
if str2str:
kwargs = {
"dataloader": self.DummyDataLoader(),
"dataloader": self.DummyDataLoader(batch_size=batch_size),
"data_keys": ["observation"],
"example_data": "a string!",
}
else:
if stack_method is None:
stack_method = as_padded_tensor
kwargs = {
"dataloader": self.DummyTensorDataLoader(padding=True),
"dataloader": self.DummyTensorDataLoader(
padding=True, batch_size=batch_size
),
"data_keys": ["observation"],
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
"stack_method": stack_method,
@@ -4640,6 +4650,55 @@ def test_llm_from_dataloader(self, str2str, batched, stack_method, device):
env.check_env_specs(break_when_any_done="both", tensordict=td)
else:
env.check_env_specs(break_when_any_done="both")
if batch_size > 0:

def policy(td):
if str2str:
if not td.shape:
td["action"] = "<nothing>"
else:
td["action"] = NonTensorStack(
*["<nothing>" for _ in range(td.shape[0])]
)
else:
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
return td

if batched:
# Tell the env that we want 3 sub-envs
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
assert r.ndim == 2
if str2str:
assert isinstance(r[0, 0]["observation"], str)
assert isinstance(r[0, 1]["observation"], str)
assert (
r[0, 0]["observation"]
== r[0, 1]["observation"][: -len(r[0, 0]["action"])]
)
assert (
r[0, 1]["observation"]
== r[0, 2]["observation"][: -len(r[0, 1]["action"])]
)
assert (
r[-1, 0]["observation"]
== r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
)
assert (
r[-1, 1]["observation"]
== r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
)
else:
assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
assert (
r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
).all()
assert (
r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
).all()
else:
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
assert r.ndim == 1


if __name__ == "__main__":
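A minimal usage sketch of the pattern exercised above, assuming a dataloader that yields lists of strings (`DummyDataLoader` is a hypothetical stand-in for the test helper):

from tensordict import TensorDict
from tensordict.tensorclass import NonTensorStack
from torchrl.envs.custom.llm import LLMEnv
from torchrl.envs.transforms.rlhf import DataLoadingPrimer

# Hypothetical dataloader yielding lists of 4 strings per iteration.
primer = DataLoadingPrimer(
    dataloader=DummyDataLoader(batch_size=4),
    data_keys=["observation"],
    example_data="a string!",
)
env = LLMEnv(str2str=True).append_transform(primer)

def policy(td):
    # Append a constant "token" to every element of the batch.
    if td.shape:
        td["action"] = NonTensorStack(*["<nothing>"] * td.shape[0])
    else:
        td["action"] = "<nothing>"
    return td

# batch_size=[3] asks the env for 3 sub-envs; the primer serves 3 buffered
# dataset elements at reset time.
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
assert r.ndim == 2  # (n_envs, time)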
19 changes: 15 additions & 4 deletions torchrl/envs/custom/llm.py
@@ -188,6 +188,14 @@ def from_dataloader(
)
return env.append_transform(primer)

@staticmethod
def _check_obs_act_and_cat(obs, action):
if not isinstance(obs, str):
raise TypeError(f"Observation must be a string, got {type(obs)}.")
if not isinstance(action, str):
raise TypeError(f"Action must be a string, got {type(action)}.")
return obs + action

def _step(
self,
tensordict: TensorDictBase,
@@ -202,11 +210,14 @@ def _step(
"The tensordict is batchless, yet the action and/or observations are not "
f"strings but {type(action)} and {type(obs)}, respectivly."
)
observation = obs + action
observation = self._check_obs_act_and_cat(obs, action)
else:
observation = [
_obs + _action for (_obs, _action) in _zip_strict(obs, action)
]
observation = NonTensorStack(
*[
self._check_obs_act_and_cat(_obs, _action)
for (_obs, _action) in _zip_strict(obs, action)
]
)
else:
try:
obs: torch.Tensor = tensordict.get(self.observation_key)
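The new `_step` branch above can be illustrated in isolation: each observation string is concatenated with its action string, and the results are stacked back into a `NonTensorStack` so the batch dimension is preserved. A sketch mirroring the helper added in the diff:

from tensordict.tensorclass import NonTensorStack

def check_obs_act_and_cat(obs, action):
    # Mirrors LLMEnv._check_obs_act_and_cat: type-check, then concatenate.
    if not isinstance(obs, str):
        raise TypeError(f"Observation must be a string, got {type(obs)}.")
    if not isinstance(action, str):
        raise TypeError(f"Action must be a string, got {type(action)}.")
    return obs + action

obs = ["hello", "foo"]
actions = [" world", " bar"]
stacked = NonTensorStack(*[check_obs_act_and_cat(o, a) for o, a in zip(obs, actions)])
# stacked now holds ["hello world", "foo bar"] with batch size 2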
49 changes: 46 additions & 3 deletions torchrl/envs/transforms/rlhf.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections import deque
from collections.abc import Mapping
from copy import copy, deepcopy
from typing import Any, Callable, Iterable, Literal
@@ -87,11 +88,21 @@ class DataLoadingPrimer(TensorDictPrimer):

Args:
dataloader (Iterable[Any]): The dataloader to load data from.

Keyword Args:
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch-size
that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
are loaded in order.
Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
tensordict returned by the transform will be automatically determined assuming that there is a single batch
dimension.

Attributes:
dataloader (Iterable[Any]): The dataloader to load data from.
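A short sketch of what `auto_batch_size` implies, assuming the dataloader yields dict batches with a single leading batch dimension (here a batch of 4 token sequences):

import torch
from tensordict import TensorDict

batch = {"observation": torch.zeros(4, 7, dtype=torch.int64)}
td = TensorDict.from_dict(batch, auto_batch_size=True, batch_dims=1)
assert td.batch_size == torch.Size([4])  # one batch dim inferred, not []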
@@ -339,14 +350,25 @@ def __init__(
def __init__(
self,
dataloader: Iterable[Any],
*,
primers: Composite | None = None,
data_keys: list[NestedKey] | None = None,
data_specs: list[TensorSpec] | None = None,
example_data: Any = None,
stack_method: Callable[[Any], Any]
| Literal["as_nested_tensor", "as_padded_tensor"]
| None = None,
use_buffer: bool | None = None,
auto_batch_size: bool = True,
):
self.dataloader = dataloader
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
use_buffer = True

self.use_buffer = use_buffer
# No auto_batch_size if we know we have a single element
self.auto_batch_size = auto_batch_size and (
getattr(dataloader, "dataloader", 1) > 0
)
self.endless_dataloader = self._endless_iter(self.dataloader)
if primers is None:
if data_keys is None:
@@ -381,34 +403,55 @@ def __init__(
single_default_value=True,
call_before_env_reset=True,
)
if self.use_buffer:
self._queue = deque()

@classmethod
def _endless_iter(cls, obj):
while True:
yield from obj

def _load_from_dataloader(self, reset: torch.Tensor | None = None):
"""Loads a single element from the dataloader, or alternatively from the buffer.

If `reset` is passed, one element per reset will be loaded.
"""
if reset is not None:
if not reset.any():
raise RuntimeError("reset must have at least one True value.")
if reset.ndim > 0:
return self.stack_method(
[self._load_from_dataloader() for _ in range(reset.sum())]
)
if self.use_buffer and len(self._queue) > 0:
return self._queue.popleft()
data = next(self.endless_dataloader)
# Some heuristic here:
# if data is a map, assume its keys match the keys in spec
# TODO: one could rename the keys too
if isinstance(data, Mapping):
out = TensorDict(data)
out = TensorDict.from_dict(
data, auto_batch_size=self.auto_batch_size, batch_dims=1
)
elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)):
out = TensorDict({k: val for k, val in _zip_strict(self.data_keys, data)})
out = TensorDict.from_dict(
{k: val for k, val in _zip_strict(self.data_keys, data)},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
elif len(self.data_keys) == 1:
out = TensorDict({self.data_keys[0]: data})
out = TensorDict.from_dict(
{self.data_keys[0]: data},
auto_batch_size=self.auto_batch_size,
batch_dims=1,
)
else:
raise ValueError(
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
)
if self.use_buffer:
self._queue.extend(out.unbind(0))
return self._queue.popleft()
return out
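
A self-contained sketch of the buffering strategy implemented above (plain Python lists stand in for the tensordicts the transform unbinds along dim 0): a batch is pulled from the endless dataloader only when the queue is empty, split into single elements, and served one reset at a time.

from collections import deque

def endless(dataloader):
    # Mirrors _endless_iter: restart the dataloader when exhausted.
    while True:
        yield from dataloader

queue = deque()
it = endless([["a", "b", "c", "d"]])  # stands in for a dataloader with batch_size=4

def load_one():
    # Mirrors the buffered path of _load_from_dataloader.
    if not queue:
        queue.extend(next(it))  # the transform uses TensorDict.unbind(0) here
    return queue.popleft()

assert [load_one() for _ in range(6)] == ["a", "b", "c", "d", "a", "b"]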

