@@ -8798,29 +8798,37 @@ def __init__(
87988798 def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
87998799 raise RuntimeError (FORWARD_NOT_IMPLEMENTED .format (type (self )))
88008800
8801+ @property
8802+ def action_spec (self ):
8803+ action_spec = self .container .full_action_spec
8804+ keys = self .container .action_keys
8805+ if len (keys ) == 1 :
8806+ action_spec = action_spec [keys [0 ]]
8807+ else :
8808+ raise ValueError (
8809+ f"Too many action keys for { self .__class__ .__name__ } : { keys = } "
8810+ )
8811+ if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8812+ raise ValueError (
8813+ self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8814+ )
8815+ return action_spec
8816+
88018817 def _call (self , next_tensordict : TensorDictBase ) -> TensorDictBase :
88028818 parent = self .parent
88038819 if parent is None :
88048820 raise RuntimeError (
88058821 f"{ type (self )} .parent cannot be None: make sure this transform is executed within an environment."
88068822 )
88078823 mask = next_tensordict .get (self .in_keys [1 ])
8808- action_spec = self .container .action_spec
8809- if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8810- raise ValueError (
8811- self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8812- )
8824+ action_spec = self .action_spec
88138825 action_spec .update_mask (mask .to (action_spec .device ))
88148826 return next_tensordict
88158827
88168828 def _reset (
88178829 self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
88188830 ) -> TensorDictBase :
8819- action_spec = self .container .action_spec
8820- if not isinstance (action_spec , self .ACCEPTED_SPECS ):
8821- raise ValueError (
8822- self .SPEC_TYPE_ERROR .format (self .ACCEPTED_SPECS , type (action_spec ))
8823- )
8831+ action_spec = self .action_spec
88248832 mask = tensordict .get (self .in_keys [1 ], None )
88258833 if mask is not None :
88268834 mask = mask .to (action_spec .device )
0 commit comments