Skip to content

Commit 68999b9

Browse files
committed
[Feature] TensorDictPrimer with single default_value callable
ghstack-source-id: b9f7df7bf2abd312dc8de56cac757c4b2975c62c Pull Request resolved: #2732
1 parent 20a19fe commit 68999b9

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

torchrl/envs/custom/pendulum.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269269
batch_size = (
270270
tensordict.batch_size if tensordict is not None else self.batch_size
271271
)
272-
if tensordict is None or tensordict.is_empty():
272+
if tensordict is None or "params" not in tensordict:
273273
# if no ``tensordict`` is passed, we generate a single set of hyperparameters
274274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275
# parameters to get started.
276276
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
277+
elif "th" in tensordict and "thdot" in tensordict:
278+
# we can hard-reset the env too
279+
return tensordict
280+
out = self._reset_random_data(
281+
tensordict.shape, batch_size, tensordict["params"]
282+
)
283+
return out
284+
285+
def _reset_random_data(self, shape, batch_size, params):
277286

278287
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
279288
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284293
# of simulators run simultaneously. In other contexts, the initial
285294
# random state's shape will depend upon the environment batch-size instead.
286295
th = (
287-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
296+
torch.rand(shape, generator=self.rng, device=self.device)
288297
* (high_th - low_th)
289298
+ low_th
290299
)
291300
thdot = (
292-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
301+
torch.rand(shape, generator=self.rng, device=self.device)
293302
* (high_thdot - low_thdot)
294303
+ low_thdot
295304
)
296305
out = TensorDict(
297306
{
298307
"th": th,
299308
"thdot": thdot,
300-
"params": tensordict["params"],
309+
"params": params,
301310
},
302311
batch_size=batch_size,
303312
)

torchrl/envs/transforms/transforms.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -5674,14 +5674,20 @@ class TensorDictPrimer(Transform):
56745674
Defaults to `False`.
56755675
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
56765676
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
5677-
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
5678-
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
5679-
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
5680-
be used to generate the corresponding tensors. Defaults to `0.0`.
5677+
all elements of the tensors will be set to that value.
5678+
If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor
5679+
fitting the specs (ie, ``default_value()`` will be called independently for each leaf spec). If it is a
5680+
callable and ``single_default_value=True``, then the callable will be called just once and it is expected
5681+
that the structure of its returned TensorDict instance or equivalent will match the provided specs.
5682+
Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching
5683+
those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`.
56815684
reset_key (NestedKey, optional): the reset key to be used as partial
56825685
reset indicator. Must be unique. If not provided, defaults to the
56835686
only reset key of the parent environment (if it has only one)
56845687
and raises an exception otherwise.
5688+
single_default_value (bool, optional): if ``True`` and `default_value` is a callable, it will be expected that
5689+
``default_value`` returns a single tensordict matching the specs. If `False`, `default_value()` will be
5690+
called independently for each leaf. Defaults to ``False``.
56855691
**kwargs: each keyword argument corresponds to a key in the tensordict.
56865692
The corresponding value has to be a TensorSpec instance indicating
56875693
what the value must be.
@@ -5781,6 +5787,7 @@ def __init__(
57815787
| Dict[NestedKey, Callable] = None,
57825788
reset_key: NestedKey | None = None,
57835789
expand_specs: bool = None,
5790+
single_default_value: bool = False,
57845791
**kwargs,
57855792
):
57865793
self.device = kwargs.pop("device", None)
@@ -5821,10 +5828,13 @@ def __init__(
58215828
raise ValueError(
58225829
"If a default_value dictionary is provided, it must match the primers keys."
58235830
)
5831+
elif single_default_value:
5832+
pass
58245833
else:
58255834
default_value = {
58265835
key: default_value for key in self.primers.keys(True, True)
58275836
}
5837+
self.single_default_value = single_default_value
58285838
self.default_value = default_value
58295839
self._validated = False
58305840
self.reset_key = reset_key
@@ -5937,6 +5947,14 @@ def _validate_value_tensor(self, value, spec):
59375947
return True
59385948

59395949
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
5950+
if self.single_default_value and callable(self.default_value):
5951+
tensordict.update(self.default_value())
5952+
for key, spec in self.primers.items(True, True):
5953+
if not self._validated:
5954+
self._validate_value_tensor(tensordict.get(key), spec)
5955+
if not self._validated:
5956+
self._validated = True
5957+
return tensordict
59405958
for key, spec in self.primers.items(True, True):
59415959
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
59425960
raise RuntimeError(
@@ -5991,6 +6009,14 @@ def _reset(
59916009
):
59926010
self.primers = self._expand_shape(self.primers)
59936011
if _reset.any():
6012+
if self.single_default_value and callable(self.default_value):
6013+
tensordict_reset.update(self.default_value())
6014+
for key, spec in self.primers.items(True, True):
6015+
if not self._validated:
6016+
self._validate_value_tensor(tensordict_reset.get(key), spec)
6017+
self._validated = True
6018+
return tensordict_reset
6019+
59946020
for key, spec in self.primers.items(True, True):
59956021
if self.random:
59966022
shape = (

0 commit comments

Comments
 (0)