
Commit a0d650f

[Test] Test RB+Isaac+Ray
ghstack-source-id: ce70c80 Pull-Request: #3228
1 parent b599d9b commit a0d650f

File tree

8 files changed: +363 -70 lines


examples/distributed/collectors/multi_nodes/ray_buffer_infra.py (2 additions, 1 deletion)

@@ -75,7 +75,8 @@ def env_maker():
         # break at some point
         break
 
-    await distributed_collector.async_shutdown()
+    await distributed_collector.async_shutdown(shutdown_ray=False)
+    buffer.close()
 
 
 if __name__ == "__main__":
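
The change above makes the teardown order explicit: the collector is shut down with `shutdown_ray=False`, so the Ray runtime that also backs the buffer stays alive, and only then is the buffer closed. A minimal sketch of that ordering, reusing the names from the example file (the `teardown` wrapper itself is illustrative, not part of the commit):

```python
# Sketch of the teardown ordering introduced above; `distributed_collector`
# and `buffer` are the objects from ray_buffer_infra.py, while `teardown`
# is a hypothetical wrapper added here only for illustration.
async def teardown(distributed_collector, buffer):
    # Stop collection but keep the Ray runtime alive: the Ray-backed buffer
    # still needs it in order to release its resources cleanly.
    await distributed_collector.async_shutdown(shutdown_ray=False)
    # Close the buffer only once nothing is writing to it anymore.
    buffer.close()
```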

test/test_libs.py (146 additions, 24 deletions)
@@ -12,6 +12,7 @@
 import os
 import urllib.error
 
+import torchrl.testing.env_helper
 
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
 

@@ -50,11 +51,13 @@
 
 from torchrl._utils import implement_for, logger as torchrl_logger
 from torchrl.collectors import SyncDataCollector
+from torchrl.collectors.distributed import RayCollector
 from torchrl.data import (
     Binary,
     Bounded,
     Categorical,
     Composite,
+    LazyMemmapStorage,
     MultiCategorical,
     MultiOneHot,
     NonTensor,

@@ -134,6 +137,7 @@
     ValueOperator,
 )
 
+_has_ray = importlib.util.find_spec("ray") is not None
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import (
         _make_multithreaded_env,
@@ -5129,28 +5133,8 @@ def test_render(self, rollout_steps):
 class TestIsaacLab:
     @pytest.fixture(scope="class")
     def env(self):
-        torch.manual_seed(0)
-        import argparse
-
-        # This code block ensures that the Isaac app is started in headless mode
-        from isaaclab.app import AppLauncher
-
-        parser = argparse.ArgumentParser(description="Train an RL agent with TorchRL.")
-        AppLauncher.add_app_launcher_args(parser)
-        args_cli, hydra_args = parser.parse_known_args(["--headless"])
-        AppLauncher(args_cli)
-
-        # Imports and env
-        import gymnasium as gym
-        import isaaclab_tasks  # noqa: F401
-        from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
-        from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
-
-        torchrl_logger.info("Making IsaacLab env...")
-        env = gym.make("Isaac-Ant-v0", cfg=AntEnvCfg())
-        torchrl_logger.info("Wrapping IsaacLab env...")
+        env = torchrl.testing.env_helper.make_isaac_env()
         try:
-            env = IsaacLabWrapper(env)
             yield env
         finally:
             torchrl_logger.info("Closing IsaacLab env...")
@@ -5167,11 +5151,17 @@ def test_isaaclab(self, env):
     def test_isaaclab_rb(self, env):
         env = env.append_transform(StepCounter())
         rb = ReplayBuffer(
-            storage=LazyTensorStorage(50, ndim=2), sampler=SliceSampler(num_slices=5)
+            storage=LazyTensorStorage(100_000, ndim=2),
+            sampler=SliceSampler(num_slices=5),
+            batch_size=20,
         )
-        rb.extend(env.rollout(20))
+        r = env.rollout(20, break_when_any_done=False)
+        rb.extend(r)
         # check that rb["step_count"].flatten() is made of sequences of 4 consecutive numbers
-        flat_ranges = rb["step_count"].flatten() % 4
+        flat_ranges = rb.sample()["step_count"]
+        flat_ranges = flat_ranges.view(-1, 4)
+        flat_ranges = flat_ranges - flat_ranges[:, :1]  # substract baseline
+        flat_ranges = flat_ranges.flatten()
         arange = torch.arange(flat_ranges.numel(), device=flat_ranges.device) % 4
         assert (flat_ranges == arange).all()
 
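
The rewritten check is easier to follow with a concrete picture: `batch_size=20` with `SliceSampler(num_slices=5)` yields 5 slices of 4 consecutive steps each, and since each slice's `step_count` values are consecutive, subtracting the first entry of every slice (the baseline) normalizes it to 0..3. This also makes the assertion robust to slices that start at arbitrary step counts (e.g. after a mid-rollout reset under `break_when_any_done=False`), which the raw modulo check on the whole buffer could not guarantee. A standalone toy version of the arithmetic, with synthetic numbers rather than a real rollout:

```python
# Toy illustration of the step_count check above, with made-up slice values.
import torch

sampled = torch.tensor(
    [[10, 11, 12, 13],   # 5 slices of 4 consecutive step counts each,
     [ 3,  4,  5,  6],   # starting at arbitrary offsets
     [ 0,  1,  2,  3],
     [ 7,  8,  9, 10],
     [20, 21, 22, 23]]
)
baselined = sampled - sampled[:, :1]  # each slice now starts at 0
flat = baselined.flatten()
assert (flat == torch.arange(flat.numel()) % 4).all()
```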

@@ -5187,6 +5177,138 @@ def test_isaac_collector(self, env):
         # We must do that, otherwise `__del__` calls `shutdown` and the next test will fail
         col.shutdown(close_env=False)
 
+    @pytest.fixture(scope="function")
+    def clean_ray(self):
+        import ray
+
+        ray.shutdown()
+        ray.init(ignore_reinit_error=True)
+        yield
+        ray.shutdown()
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not found")
+    @pytest.mark.parametrize("use_rb", [False, True], ids=["rb_false", "rb_true"])
+    @pytest.mark.parametrize("num_collectors", [1, 4], ids=["1_col", "4_col"])
+    def test_isaaclab_ray_collector(self, env, use_rb, clean_ray, num_collectors):
+        from torchrl.data import RayReplayBuffer
+
+        # Create replay buffer if requested
+        replay_buffer = None
+        if use_rb:
+            replay_buffer = RayReplayBuffer(
+                # We place the storage on memmap to make it shareable
+                storage=partial(LazyMemmapStorage, 10_000, ndim=2),
+                ray_init_config={"num_cpus": 4},
+            )
+
+        col = RayCollector(
+            [torchrl.testing.env_helper.make_isaac_env] * num_collectors,
+            env.full_action_spec.rand_update,
+            frames_per_batch=8192,
+            total_frames=65536,
+            replay_buffer=replay_buffer,
+            num_collectors=num_collectors,
+            collector_kwargs={
+                "trust_policy": True,
+                "no_cuda_sync": True,
+                "extend_buffer": True,
+            },
+        )
+
+        try:
+            if use_rb:
+                # When replay buffer is provided, collector yields None and populates buffer
+                for i, data in enumerate(col):
+                    # Data is None when using replay buffer
+                    assert data is None, "Expected None when using replay buffer"
+
+                    # Check replay buffer is being populated
+                    if i >= 0:
+                        # Wait for buffer to have enough data to sample
+                        if len(replay_buffer) >= 32:
+                            sample = replay_buffer.sample(32)
+                            assert sample.batch_size == (32,)
+                            # Check that we have meaningful data (not all zeros/nans)
+                            assert sample["policy"].isfinite().any()
+                            assert sample["action"].isfinite().any()
+                            # Check shape is correct for Isaac Lab env (should have batch dim from env)
+                            assert len(sample.shape) == 1
+
+                    # Only collect a few batches for the test
+                    if i >= 2:
+                        break
+
+                # Verify replay buffer has data
+                assert len(replay_buffer) > 0, "Replay buffer should not be empty"
+                # Test that we can sample multiple times
+                for _ in range(5):
+                    sample = replay_buffer.sample(16)
+                    assert sample.batch_size == (16,)
+                    assert sample["policy"].isfinite().any()
+
+            else:
+                # Without replay buffer, collector yields data normally
+                collected_frames = 0
+                for i, data in enumerate(col):
+                    assert (
+                        data is not None
+                    ), "Expected data when not using replay buffer"
+                    # Check the data shape matches the batch size
+                    assert (
+                        data.numel() >= 1000
+                    ), f"Expected at least 1000 frames, got {data.numel()}"
+                    collected_frames += data.numel()
+
+                    # Only collect a few batches for the test
+                    if i >= 2:
+                        break
+
+                # Verify we collected some data
+                assert collected_frames > 0, "No frames were collected"
+
+        finally:
+            # Clean shutdown
+            col.shutdown()
+            if use_rb:
+                replay_buffer.close()
+
+    @pytest.mark.skipif(not _has_ray, reason="Ray not found")
+    @pytest.mark.parametrize("num_collectors", [1, 4], ids=["1_col", "4_col"])
+    def test_isaaclab_ray_collector_start(self, env, clean_ray, num_collectors):
+
+        from torchrl.data import LazyTensorStorage, RayReplayBuffer
+
+        rb = RayReplayBuffer(
+            storage=partial(LazyTensorStorage, 100_000, ndim=2),
+            ray_init_config={"num_cpus": 4},
+        )
+        col = RayCollector(
+            [torchrl.testing.env_helper.make_isaac_env] * num_collectors,
+            env.full_action_spec.rand_update,
+            frames_per_batch=8192,
+            total_frames=65536,
+            trust_policy=True,
+            replay_buffer=rb,
+            num_collectors=num_collectors,
+        )
+        col.start()
+        try:
+            time_waiting = 0
+            while time_waiting < 30:
+                if len(rb) >= 4096:
+                    break
+                time.sleep(0.1)
+                time_waiting += 0.1
+            else:
+                raise RuntimeError("Timeout waiting for data")
+            sample = rb.sample(4096)
+            assert sample.batch_size == (4096,)
+            assert sample["policy"].isfinite().any()
+            assert sample["action"].isfinite().any()
+        finally:
+            col.shutdown()
+            rb.close()
+
     def test_isaaclab_reset(self, env):
         # Make a rollout that will stop as soon as a trajectory reaches a done state
         r = env.rollout(1_000_000)
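
Both new tests exercise the same core pattern: a `RayCollector` fed with Isaac Lab env factories that writes straight into a `RayReplayBuffer`, from which the test process samples. Stripped of the test scaffolding, the pattern looks roughly like the sketch below; `env_maker` and `policy` are placeholders for `torchrl.testing.env_helper.make_isaac_env` and `env.full_action_spec.rand_update`, and the function is a distillation, not a reference recipe:

```python
# Sketch distilled from the tests above; env_maker and policy stand in for the
# Isaac-specific pieces used in the diff.
from functools import partial

from torchrl.collectors.distributed import RayCollector
from torchrl.data import LazyTensorStorage, RayReplayBuffer


def collect_into_buffer(env_maker, policy):
    rb = RayReplayBuffer(
        storage=partial(LazyTensorStorage, 100_000, ndim=2),
        ray_init_config={"num_cpus": 4},
    )
    collector = RayCollector(
        [env_maker],
        policy,
        frames_per_batch=8192,
        total_frames=65536,
        replay_buffer=rb,   # the collector writes directly into the buffer
        trust_policy=True,  # skip policy wrapping (see the collectors.py change below)
    )
    sample = None
    try:
        for data in collector:
            # With a replay buffer attached, the collector yields None;
            # the collected frames live in the buffer instead.
            assert data is None
            if len(rb) >= 32:
                sample = rb.sample(32)
                break
    finally:
        collector.shutdown()
        rb.close()
    return sample
```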

torchrl/collectors/collectors.py (32 additions, 5 deletions)
@@ -1028,11 +1028,16 @@ def __init__(
         if not self.trust_policy:
             self.policy = policy
             env = getattr(self, "env", None)
-            wrapped_policy = _make_compatible_policy(
-                policy=policy,
-                observation_spec=getattr(env, "observation_spec", None),
-                env=self.env,
-            )
+            try:
+                wrapped_policy = _make_compatible_policy(
+                    policy=policy,
+                    observation_spec=getattr(env, "observation_spec", None),
+                    env=self.env,
+                )
+            except (TypeError, AttributeError, ValueError) as err:
+                raise TypeError(
+                    "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
+                ) from err
             self._wrapped_policy = wrapped_policy
         else:
             self.policy = self._wrapped_policy = policy
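
With this change, a policy that the automatic wrapping cannot handle now fails with a pointed `TypeError` whose message names the escape hatch: `trust_policy=True`. A hedged sketch of acting on that advice when the policy already reads and writes the expected TensorDict keys (the retry wrapper is illustrative, not part of the commit or of TorchRL's API):

```python
# Illustrative only: shows the failure mode the new error message targets and
# the trust_policy=True escape hatch it suggests; make_collector is a
# hypothetical helper.
from torchrl.collectors import SyncDataCollector


def make_collector(env_maker, policy):
    try:
        return SyncDataCollector(
            env_maker, policy, frames_per_batch=64, total_frames=640
        )
    except TypeError:
        # If the policy already consumes/produces the expected TensorDict keys,
        # declare it trusted so _make_compatible_policy is bypassed entirely.
        return SyncDataCollector(
            env_maker,
            policy,
            frames_per_batch=64,
            total_frames=640,
            trust_policy=True,
        )
```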
@@ -1785,16 +1790,28 @@ def rollout(self) -> TensorDictBase:
                 next_data.clear_device_()
             self._shuttle.set("next", next_data)
 
+            if self.verbose:
+                torchrl_logger.info(
+                    f"Collector: Rollout step completed {self._iter=}."
+                )
             if (
                 self.replay_buffer is not None
                 and not self._ignore_rb
                 and not self.extend_buffer
             ):
+                if self.verbose:
+                    torchrl_logger.info(
+                        f"Collector: Adding {env_output.numel()} frames to replay buffer using add()."
+                    )
                 self.replay_buffer.add(self._shuttle)
                 if self._increment_frames(self._shuttle.numel()):
                     return
             else:
                 if self.storing_device is not None:
+                    if self.verbose:
+                        torchrl_logger.info(
+                            f"Collector: Moving to {self.storing_device} and adding to queue."
+                        )
                     non_blocking = (
                         not self.no_cuda_sync or self.storing_device.type == "cuda"
                     )

@@ -1806,6 +1823,10 @@ def rollout(self) -> TensorDictBase:
                     if not self.no_cuda_sync:
                         self._sync_storage()
                 else:
+                    if self.verbose:
+                        torchrl_logger.info(
+                            "Collector: Adding to queue (no device)."
+                        )
                     tensordicts.append(self._shuttle)
 
             # carry over collector data without messing up devices

@@ -1820,6 +1841,8 @@ def rollout(self) -> TensorDictBase:
                 self.interruptor is not None
                 and self.interruptor.collection_stopped()
             ):
+                if self.verbose:
+                    torchrl_logger.info("Collector: Interruptor stopped.")
                 if (
                     self.replay_buffer is not None
                     and not self._ignore_rb

@@ -1846,6 +1869,7 @@ def rollout(self) -> TensorDictBase:
                     break
         else:
             if self._use_buffers:
+                torchrl_logger.info("Returning final rollout within buffer.")
                 result = self._final_rollout
                 try:
                     result = torch.stack(

@@ -1868,6 +1892,9 @@ def rollout(self) -> TensorDictBase:
                 ):
                     return
             else:
+                torchrl_logger.info(
+                    "Returning final rollout with NO buffer (maybe_dense_stack)."
+                )
                 result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
                 result.refine_names(..., "time")
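
Most of the new log lines above are gated on `self.verbose`; only the two final "Returning final rollout ..." messages are emitted unconditionally. A small sketch of turning the tracing on, with the caveat that setting `verbose` directly on the instance is an assumption here (the hunks only show the attribute being read):

```python
# Assumes torchrl_logger is the standard Python logger exposed from
# torchrl._utils (it is imported that way in test_libs.py above) and that
# setting .verbose on the collector instance enables the gated messages;
# the latter is an assumption.
import logging

from torchrl._utils import logger as torchrl_logger
from torchrl.collectors import SyncDataCollector


def debug_collect(env_maker, policy):
    torchrl_logger.setLevel(logging.INFO)
    collector = SyncDataCollector(
        env_maker, policy, frames_per_batch=64, total_frames=640
    )
    collector.verbose = True  # enable the "Collector: ..." trace messages added above
    for _ in collector:
        pass
    collector.shutdown()
```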
