@@ -8798,29 +8798,37 @@ def __init__(
8798
8798
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
8799
8799
raise RuntimeError (FORWARD_NOT_IMPLEMENTED .format (type (self )))
8800
8800
8801
+ @property
8802
+ def action_spec (self ):
8803
+ action_spec = self .container .full_action_spec
8804
+ keys = self .container .action_keys
8805
+ if len (keys ) == 1 :
8806
+ action_spec = action_spec [keys [0 ]]
8807
+ else :
8808
+ raise ValueError (
8809
+ f"Too many action keys for { self .__class__ .__name__ } : { keys = } "
8810
+ )
8811
+ if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8812
+ raise ValueError (
8813
+ self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8814
+ )
8815
+ return action_spec
8816
+
8801
8817
def _call (self , next_tensordict : TensorDictBase ) -> TensorDictBase :
8802
8818
parent = self .parent
8803
8819
if parent is None :
8804
8820
raise RuntimeError (
8805
8821
f"{ type (self )} .parent cannot be None: make sure this transform is executed within an environment."
8806
8822
)
8807
8823
mask = next_tensordict .get (self .in_keys [1 ])
8808
- action_spec = self .container .action_spec
8809
- if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8810
- raise ValueError (
8811
- self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8812
- )
8824
+ action_spec = self .action_spec
8813
8825
action_spec .update_mask (mask .to (action_spec .device ))
8814
8826
return next_tensordict
8815
8827
8816
8828
def _reset (
8817
8829
self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
8818
8830
) -> TensorDictBase :
8819
- action_spec = self .container .action_spec
8820
- if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8821
- raise ValueError (
8822
- self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8823
- )
8831
+ action_spec = self .action_spec
8824
8832
mask = tensordict .get (self .in_keys [1 ], None )
8825
8833
if mask is not None :
8826
8834
mask = mask .to (action_spec .device )
0 commit comments