1212import os
1313import urllib .error
1414
15+ import torchrl .testing .env_helper
1516
1617_has_isaac = importlib .util .find_spec ("isaacgym" ) is not None
1718
5051
5152from torchrl ._utils import implement_for , logger as torchrl_logger
5253from torchrl .collectors import SyncDataCollector
54+ from torchrl .collectors .distributed import RayCollector
5355from torchrl .data import (
5456 Binary ,
5557 Bounded ,
5658 Categorical ,
5759 Composite ,
60+ LazyMemmapStorage ,
5861 MultiCategorical ,
5962 MultiOneHot ,
6063 NonTensor ,
134137 ValueOperator ,
135138)
136139
140+ _has_ray = importlib .util .find_spec ("ray" ) is not None
137141if os .getenv ("PYTORCH_TEST_FBCODE" ):
138142 from pytorch .rl .test ._utils_internal import (
139143 _make_multithreaded_env ,
@@ -5129,28 +5133,8 @@ def test_render(self, rollout_steps):
# NOTE(review): this span is a diff excerpt — '-' lines are the removed inline
# IsaacLab construction (AppLauncher headless startup, gym.make, IsaacLabWrapper),
# and the '+' line is its replacement: delegation to the shared helper
# torchrl.testing.env_helper.make_isaac_env(). The fixture's `finally` body
# continues past this hunk (the env.close() call is outside this view — TODO
# confirm against the full file).
51295133class TestIsaacLab :
51305134 @pytest .fixture (scope = "class" )
51315135 def env (self ):
5132- torch .manual_seed (0 )
5133- import argparse
5134-
5135- # This code block ensures that the Isaac app is started in headless mode
5136- from isaaclab .app import AppLauncher
5137-
5138- parser = argparse .ArgumentParser (description = "Train an RL agent with TorchRL." )
5139- AppLauncher .add_app_launcher_args (parser )
5140- args_cli , hydra_args = parser .parse_known_args (["--headless" ])
5141- AppLauncher (args_cli )
5142-
5143- # Imports and env
5144- import gymnasium as gym
5145- import isaaclab_tasks # noqa: F401
5146- from isaaclab_tasks .manager_based .classic .ant .ant_env_cfg import AntEnvCfg
5147- from torchrl .envs .libs .isaac_lab import IsaacLabWrapper
5148-
5149- torchrl_logger .info ("Making IsaacLab env..." )
5150- env = gym .make ("Isaac-Ant-v0" , cfg = AntEnvCfg ())
5151- torchrl_logger .info ("Wrapping IsaacLab env..." )
5136+ env = torchrl .testing .env_helper .make_isaac_env ()
51525137 try :
5153- env = IsaacLabWrapper (env )
51545138 yield env
51555139 finally :
51565140 torchrl_logger .info ("Closing IsaacLab env..." )
@@ -5167,11 +5151,17 @@ def test_isaaclab(self, env):
51675151 def test_isaaclab_rb (self , env ):
51685152 env = env .append_transform (StepCounter ())
51695153 rb = ReplayBuffer (
5170- storage = LazyTensorStorage (50 , ndim = 2 ), sampler = SliceSampler (num_slices = 5 )
5154+ storage = LazyTensorStorage (100_000 , ndim = 2 ),
5155+ sampler = SliceSampler (num_slices = 5 ),
5156+ batch_size = 20 ,
51715157 )
5172- rb .extend (env .rollout (20 ))
5158+ r = env .rollout (20 , break_when_any_done = False )
5159+ rb .extend (r )
51735160 # check that rb["step_count"].flatten() is made of sequences of 4 consecutive numbers
5174- flat_ranges = rb ["step_count" ].flatten () % 4
5161+ flat_ranges = rb .sample ()["step_count" ]
5162+ flat_ranges = flat_ranges .view (- 1 , 4 )
5163+ flat_ranges = flat_ranges - flat_ranges [:, :1 ] # substract baseline
5164+ flat_ranges = flat_ranges .flatten ()
51755165 arange = torch .arange (flat_ranges .numel (), device = flat_ranges .device ) % 4
51765166 assert (flat_ranges == arange ).all ()
51775167
@@ -5187,6 +5177,138 @@ def test_isaac_collector(self, env):
51875177 # We must do that, otherwise `__del__` calls `shutdown` and the next test will fail
51885178 col .shutdown (close_env = False )
51895179
5180+ @pytest .fixture (scope = "function" )
5181+ def clean_ray (self ):
5182+ import ray
5183+
5184+ ray .shutdown ()
5185+ ray .init (ignore_reinit_error = True )
5186+ yield
5187+ ray .shutdown ()
5188+
5189+ @pytest .mark .skipif (not _has_ray , reason = "Ray not found" )
5190+ @pytest .mark .parametrize ("use_rb" , [False , True ], ids = ["rb_false" , "rb_true" ])
5191+ @pytest .mark .parametrize ("num_collectors" , [1 , 4 ], ids = ["1_col" , "4_col" ])
5192+ def test_isaaclab_ray_collector (self , env , use_rb , clean_ray , num_collectors ):
5193+ from torchrl .data import RayReplayBuffer
5194+
5195+ # Create replay buffer if requested
5196+ replay_buffer = None
5197+ if use_rb :
5198+ replay_buffer = RayReplayBuffer (
5199+ # We place the storage on memmap to make it shareable
5200+ storage = partial (LazyMemmapStorage , 10_000 , ndim = 2 ),
5201+ ray_init_config = {"num_cpus" : 4 },
5202+ )
5203+
5204+ col = RayCollector (
5205+ [torchrl .testing .env_helper .make_isaac_env ] * num_collectors ,
5206+ env .full_action_spec .rand_update ,
5207+ frames_per_batch = 8192 ,
5208+ total_frames = 65536 ,
5209+ replay_buffer = replay_buffer ,
5210+ num_collectors = num_collectors ,
5211+ collector_kwargs = {
5212+ "trust_policy" : True ,
5213+ "no_cuda_sync" : True ,
5214+ "extend_buffer" : True ,
5215+ },
5216+ )
5217+
5218+ try :
5219+ if use_rb :
5220+ # When replay buffer is provided, collector yields None and populates buffer
5221+ for i , data in enumerate (col ):
5222+ # Data is None when using replay buffer
5223+ assert data is None , "Expected None when using replay buffer"
5224+
5225+ # Check replay buffer is being populated
5226+ if i >= 0 :
5227+ # Wait for buffer to have enough data to sample
5228+ if len (replay_buffer ) >= 32 :
5229+ sample = replay_buffer .sample (32 )
5230+ assert sample .batch_size == (32 ,)
5231+ # Check that we have meaningful data (not all zeros/nans)
5232+ assert sample ["policy" ].isfinite ().any ()
5233+ assert sample ["action" ].isfinite ().any ()
5234+ # Check shape is correct for Isaac Lab env (should have batch dim from env)
5235+ assert len (sample .shape ) == 1
5236+
5237+ # Only collect a few batches for the test
5238+ if i >= 2 :
5239+ break
5240+
5241+ # Verify replay buffer has data
5242+ assert len (replay_buffer ) > 0 , "Replay buffer should not be empty"
5243+ # Test that we can sample multiple times
5244+ for _ in range (5 ):
5245+ sample = replay_buffer .sample (16 )
5246+ assert sample .batch_size == (16 ,)
5247+ assert sample ["policy" ].isfinite ().any ()
5248+
5249+ else :
5250+ # Without replay buffer, collector yields data normally
5251+ collected_frames = 0
5252+ for i , data in enumerate (col ):
5253+ assert (
5254+ data is not None
5255+ ), "Expected data when not using replay buffer"
5256+ # Check the data shape matches the batch size
5257+ assert (
5258+ data .numel () >= 1000
5259+ ), f"Expected at least 1000 frames, got { data .numel ()} "
5260+ collected_frames += data .numel ()
5261+
5262+ # Only collect a few batches for the test
5263+ if i >= 2 :
5264+ break
5265+
5266+ # Verify we collected some data
5267+ assert collected_frames > 0 , "No frames were collected"
5268+
5269+ finally :
5270+ # Clean shutdown
5271+ col .shutdown ()
5272+ if use_rb :
5273+ replay_buffer .close ()
5274+
5275+ @pytest .mark .skipif (not _has_ray , reason = "Ray not found" )
5276+ @pytest .mark .parametrize ("num_collectors" , [1 , 4 ], ids = ["1_col" , "4_col" ])
5277+ def test_isaaclab_ray_collector_start (self , env , clean_ray , num_collectors ):
5278+
5279+ from torchrl .data import LazyTensorStorage , RayReplayBuffer
5280+
5281+ rb = RayReplayBuffer (
5282+ storage = partial (LazyTensorStorage , 100_000 , ndim = 2 ),
5283+ ray_init_config = {"num_cpus" : 4 },
5284+ )
5285+ col = RayCollector (
5286+ [torchrl .testing .env_helper .make_isaac_env ] * num_collectors ,
5287+ env .full_action_spec .rand_update ,
5288+ frames_per_batch = 8192 ,
5289+ total_frames = 65536 ,
5290+ trust_policy = True ,
5291+ replay_buffer = rb ,
5292+ num_collectors = num_collectors ,
5293+ )
5294+ col .start ()
5295+ try :
5296+ time_waiting = 0
5297+ while time_waiting < 30 :
5298+ if len (rb ) >= 4096 :
5299+ break
5300+ time .sleep (0.1 )
5301+ time_waiting += 0.1
5302+ else :
5303+ raise RuntimeError ("Timeout waiting for data" )
5304+ sample = rb .sample (4096 )
5305+ assert sample .batch_size == (4096 ,)
5306+ assert sample ["policy" ].isfinite ().any ()
5307+ assert sample ["action" ].isfinite ().any ()
5308+ finally :
5309+ col .shutdown ()
5310+ rb .close ()
5311+
51905312 def test_isaaclab_reset (self , env ):
51915313 # Make a rollout that will stop as soon as a trajectory reaches a done state
51925314 r = env .rollout (1_000_000 )