Skip to content

Commit b599d9b

Browse files
committed
[Refactor] Make env creator optional for Ray
ghstack-source-id: 6444d2c Pull-Request: #3227
1 parent 13434eb commit b599d9b

File tree

1 file changed

+9
-3
lines changed
  • torchrl/collectors/distributed

1 file changed

+9
-3
lines changed

torchrl/collectors/distributed/ray.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ class RayCollector(DataCollectorBase):
268268
:class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances.
269269
This is the recommended way to configure weight synchronization. If not provided,
270270
defaults to ``{"policy": RayWeightSyncScheme()}``.
271+
use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped
272+
in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory
273+
needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed.
274+
Defaults to ``False``.
271275
272276
Examples:
273277
>>> from torch import nn
@@ -332,6 +336,7 @@ def __init__(
332336
| Callable[[], WeightUpdaterBase]
333337
| None = None,
334338
weight_sync_schemes: dict[str, WeightSyncScheme] | None = None,
339+
use_env_creator: bool = False,
335340
):
336341
self.frames_per_batch = frames_per_batch
337342
if remote_configs is None:
@@ -406,9 +411,10 @@ def check_list_length_consistency(*lists):
406411
create_env_fn, collector_kwargs, remote_configs = out_lists
407412
num_collectors = len(create_env_fn)
408413

409-
for i in range(len(create_env_fn)):
410-
if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
411-
create_env_fn[i] = EnvCreator(create_env_fn[i])
414+
if use_env_creator:
415+
for i in range(len(create_env_fn)):
416+
if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
417+
create_env_fn[i] = EnvCreator(create_env_fn[i])
412418

413419
# If ray available, try to connect to an existing Ray cluster or start one and connect to it.
414420
if not _has_ray:

0 commit comments

Comments
 (0)