@@ -4112,17 +4112,21 @@ def test_parallel_partial_steps(
41124112 use_buffers = use_buffers ,
41134113 device = device ,
41144114 )
4115- td = penv .reset ()
4116- psteps = torch .zeros (4 , dtype = torch .bool )
4117- psteps [[1 , 3 ]] = True
4118- td .set ("_step" , psteps )
4119-
4120- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4121- td = penv .step (td )
4122- assert (td [0 ].get ("next" ) == 0 ).all ()
4123- assert (td [1 ].get ("next" ) != 0 ).any ()
4124- assert (td [2 ].get ("next" ) == 0 ).all ()
4125- assert (td [3 ].get ("next" ) != 0 ).any ()
4115+ try :
4116+ td = penv .reset ()
4117+ psteps = torch .zeros (4 , dtype = torch .bool )
4118+ psteps [[1 , 3 ]] = True
4119+ td .set ("_step" , psteps )
4120+
4121+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4122+ td = penv .step (td )
4123+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4124+ assert (td [1 ].get ("next" ) != 0 ).any ()
4125+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4126+ assert (td [3 ].get ("next" ) != 0 ).any ()
4127+ finally :
4128+ penv .close ()
4129+ del penv
41264130
41274131 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41284132 def test_parallel_partial_step_and_maybe_reset (
@@ -4135,17 +4139,21 @@ def test_parallel_partial_step_and_maybe_reset(
41354139 use_buffers = use_buffers ,
41364140 device = device ,
41374141 )
4138- td = penv .reset ()
4139- psteps = torch .zeros (4 , dtype = torch .bool )
4140- psteps [[1 , 3 ]] = True
4141- td .set ("_step" , psteps )
4142-
4143- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4144- td , tdreset = penv .step_and_maybe_reset (td )
4145- assert (td [0 ].get ("next" ) == 0 ).all ()
4146- assert (td [1 ].get ("next" ) != 0 ).any ()
4147- assert (td [2 ].get ("next" ) == 0 ).all ()
4148- assert (td [3 ].get ("next" ) != 0 ).any ()
4142+ try :
4143+ td = penv .reset ()
4144+ psteps = torch .zeros (4 , dtype = torch .bool )
4145+ psteps [[1 , 3 ]] = True
4146+ td .set ("_step" , psteps )
4147+
4148+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4149+ td , tdreset = penv .step_and_maybe_reset (td )
4150+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4151+ assert (td [1 ].get ("next" ) != 0 ).any ()
4152+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4153+ assert (td [3 ].get ("next" ) != 0 ).any ()
4154+ finally :
4155+ penv .close ()
4156+ del penv
41494157
41504158 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41514159 def test_serial_partial_steps (self , use_buffers , device , env_device ):
@@ -4156,17 +4164,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
41564164 use_buffers = use_buffers ,
41574165 device = device ,
41584166 )
4159- td = penv .reset ()
4160- psteps = torch .zeros (4 , dtype = torch .bool )
4161- psteps [[1 , 3 ]] = True
4162- td .set ("_step" , psteps )
4163-
4164- td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4165- td = penv .step (td )
4166- assert (td [0 ].get ("next" ) == 0 ).all ()
4167- assert (td [1 ].get ("next" ) != 0 ).any ()
4168- assert (td [2 ].get ("next" ) == 0 ).all ()
4169- assert (td [3 ].get ("next" ) != 0 ).any ()
4167+ try :
4168+ td = penv .reset ()
4169+ psteps = torch .zeros (4 , dtype = torch .bool )
4170+ psteps [[1 , 3 ]] = True
4171+ td .set ("_step" , psteps )
4172+
4173+ td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
4174+ td = penv .step (td )
4175+ assert_allclose_td (td [0 ].get ("next" ), td [0 ], intersection = True )
4176+ assert (td [1 ].get ("next" ) != 0 ).any ()
4177+ assert_allclose_td (td [2 ].get ("next" ), td [2 ], intersection = True )
4178+ assert (td [3 ].get ("next" ) != 0 ).any ()
4179+ finally :
4180+ penv .close ()
4181+ del penv
41704182
41714183 @pytest .mark .parametrize ("use_buffers" , [False , True ])
41724184 def test_serial_partial_step_and_maybe_reset (self , use_buffers , device , env_device ):
@@ -4184,9 +4196,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
41844196
41854197 td .set ("action" , penv .full_action_spec [penv .action_key ].one ())
41864198 td = penv .step (td )
4187- assert (td [0 ].get ("next" ) == 0 ). all ( )
4199+ assert_allclose_td (td [0 ].get ("next" ), td [ 0 ], intersection = True )
41884200 assert (td [1 ].get ("next" ) != 0 ).any ()
4189- assert (td [2 ].get ("next" ) == 0 ). all ( )
4201+ assert_allclose_td (td [2 ].get ("next" ), td [ 2 ], intersection = True )
41904202 assert (td [3 ].get ("next" ) != 0 ).any ()
41914203
41924204
0 commit comments