@@ -5674,14 +5674,20 @@ class TensorDictPrimer(Transform):
5674
5674
Defaults to `False`.
5675
5675
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
5676
5676
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
5677
- all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
5678
- return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
5679
- is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
5680
- be used to generate the corresponding tensors. Defaults to `0.0`.
5677
+ all elements of the tensors will be set to that value.
5678
+ If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor
5679
+ fitting the specs (ie, ``default_value()`` will be called independently for each leaf spec). If it is a
5680
+ callable and ``single_default_value=True``, then the callable will be called just once and it is expected
5681
+ that the structure of its returned TensorDict instance or equivalent will match the provided specs.
5682
+ Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching
5683
+ those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`.
5681
5684
reset_key (NestedKey, optional): the reset key to be used as partial
5682
5685
reset indicator. Must be unique. If not provided, defaults to the
5683
5686
only reset key of the parent environment (if it has only one)
5684
5687
and raises an exception otherwise.
5688
+ single_default_value (bool, optional): if ``True`` and `default_value` is a callable, it will be expected that
5689
+ ``default_value`` returns a single tensordict matching the specs. If `False`, `default_value()` will be
5690
+ called independently for each leaf. Defaults to ``False``.
5685
5691
**kwargs: each keyword argument corresponds to a key in the tensordict.
5686
5692
The corresponding value has to be a TensorSpec instance indicating
5687
5693
what the value must be.
@@ -5781,6 +5787,7 @@ def __init__(
5781
5787
| Dict [NestedKey , Callable ] = None ,
5782
5788
reset_key : NestedKey | None = None ,
5783
5789
expand_specs : bool = None ,
5790
+ single_default_value : bool = False ,
5784
5791
** kwargs ,
5785
5792
):
5786
5793
self .device = kwargs .pop ("device" , None )
@@ -5821,10 +5828,13 @@ def __init__(
5821
5828
raise ValueError (
5822
5829
"If a default_value dictionary is provided, it must match the primers keys."
5823
5830
)
5831
+ elif single_default_value :
5832
+ pass
5824
5833
else :
5825
5834
default_value = {
5826
5835
key : default_value for key in self .primers .keys (True , True )
5827
5836
}
5837
+ self .single_default_value = single_default_value
5828
5838
self .default_value = default_value
5829
5839
self ._validated = False
5830
5840
self .reset_key = reset_key
@@ -5937,6 +5947,14 @@ def _validate_value_tensor(self, value, spec):
5937
5947
return True
5938
5948
5939
5949
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
5950
+ if self .single_default_value and callable (self .default_value ):
5951
+ tensordict .update (self .default_value ())
5952
+ for key , spec in self .primers .items (True , True ):
5953
+ if not self ._validated :
5954
+ self ._validate_value_tensor (tensordict .get (key ), spec )
5955
+ if not self ._validated :
5956
+ self ._validated = True
5957
+ return tensordict
5940
5958
for key , spec in self .primers .items (True , True ):
5941
5959
if spec .shape [: len (tensordict .shape )] != tensordict .shape :
5942
5960
raise RuntimeError (
@@ -5991,6 +6009,14 @@ def _reset(
5991
6009
):
5992
6010
self .primers = self ._expand_shape (self .primers )
5993
6011
if _reset .any ():
6012
+ if self .single_default_value and callable (self .default_value ):
6013
+ tensordict_reset .update (self .default_value ())
6014
+ for key , spec in self .primers .items (True , True ):
6015
+ if not self ._validated :
6016
+ self ._validate_value_tensor (tensordict_reset .get (key ), spec )
6017
+ self ._validated = True
6018
+ return tensordict_reset
6019
+
5994
6020
for key , spec in self .primers .items (True , True ):
5995
6021
if self .random :
5996
6022
shape = (
0 commit comments