diff --git a/.github/unittest/linux_sota/scripts/environment.yml b/.github/unittest/linux_sota/scripts/environment.yml
index e1a8dadc1e8..6769658bfed 100644
--- a/.github/unittest/linux_sota/scripts/environment.yml
+++ b/.github/unittest/linux_sota/scripts/environment.yml
@@ -29,5 +29,3 @@ dependencies:
   - coverage
   - vmas
   - transformers
-  - gym[atari]
-  - gym[accept-rom-license]
diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh
index 9c458ae4045..d7681e433f7 100755
--- a/.github/unittest/linux_sota/scripts/run_all.sh
+++ b/.github/unittest/linux_sota/scripts/run_all.sh
@@ -111,7 +111,6 @@ python -c """import gym;import d4rl"""
 
 # install ale-py: manylinux names are broken for CentOS so we need to manually download and
 # rename them
-pip install "gymnasium[atari]>=1.1.0"
 
 # ============================================================================================ #
 # ================================ PyTorch & TorchRL ========================================= #
@@ -128,6 +127,9 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
 # submodules
 git submodule sync && git submodule update --init --recursive
 
+pip3 install ale-py -U
+pip3 install "gym[atari,accept-rom-license]" "gymnasium>=1.1.0" -U
+
 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
     if [ "${CU_VERSION:-}" == cpu ] ; then
diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py
index aa40efd1aef..f3513bcfca2 100644
--- a/.github/unittest/linux_sota/scripts/test_sota.py
+++ b/.github/unittest/linux_sota/scripts/test_sota.py
@@ -39,6 +39,7 @@
   collector.frames_per_batch=20 \
   collector.num_workers=1 \
   logger.backend= \
+  env.backend=gym \
   logger.test_interval=10
 """,
     "ppo_mujoco": """python sota-implementations/ppo/ppo_mujoco.py \
@@ -56,6 +57,7 @@
   loss.mini_batch_size=20 \
   loss.ppo_epochs=2 \
   logger.backend= \
+  env.backend=gym \
   logger.test_interval=10
 """,
     "ddpg": """python sota-implementations/ddpg/ddpg.py \
@@ -82,6 +84,7 @@
   collector.frames_per_batch=20 \
   loss.mini_batch_size=20 \
   logger.backend= \
+  env.backend=gym \
   logger.test_interval=40
 """,
     "dqn_atari": """python sota-implementations/dqn/dqn_atari.py \
@@ -91,6 +94,7 @@
   buffer.batch_size=10 \
   loss.num_updates=1 \
   logger.backend= \
+  env.backend=gym \
   buffer.buffer_size=120
 """,
     "discrete_cql_online": """python sota-implementations/cql/discrete_cql_online.py \
diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py
index 4d12a75ea0f..190a7518cab 100644
--- a/sota-implementations/a2c/a2c_atari.py
+++ b/sota-implementations/a2c/a2c_atari.py
@@ -47,7 +47,9 @@ def main(cfg: DictConfig):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
+    actor, critic, critic_head = make_ppo_models(
+        cfg.env.env_name, device=device, gym_backend=cfg.env.backend
+    )
     with from_module(actor).data.to("meta").to_module(actor):
         actor_eval = deepcopy(actor)
         actor_eval.eval()
@@ -107,7 +109,13 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
+    test_env = make_parallel_env(
+        cfg.env.env_name,
+        num_envs=1,
+        device=device,
+        gym_backend=cfg.env.backend,
+        is_test=True,
+    )
     test_env.set_seed(0)
     if cfg.logger.video:
         test_env = test_env.insert_transform(
@@ -162,7 +170,12 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
 
     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        create_env_fn=make_parallel_env(
+            cfg.env.env_name,
+            num_envs=cfg.env.num_envs,
+            device=device,
+            gym_backend=cfg.env.backend,
+        ),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
diff --git a/sota-implementations/a2c/config_atari.yaml b/sota-implementations/a2c/config_atari.yaml
index 59a0a621756..964704efc63 100644
--- a/sota-implementations/a2c/config_atari.yaml
+++ b/sota-implementations/a2c/config_atari.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
   num_envs: 16
 
 # collector
diff --git a/sota-implementations/a2c/utils_atari.py b/sota-implementations/a2c/utils_atari.py
index 6ff62bbe520..cf8372d7b86 100644
--- a/sota-implementations/a2c/utils_atari.py
+++ b/sota-implementations/a2c/utils_atari.py
@@ -21,6 +21,7 @@
     ParallelEnv,
     Resize,
     RewardSum,
+    set_gym_backend,
     SignTransform,
     StepCounter,
     ToTensorImage,
@@ -45,15 +46,20 @@
 
 
 def make_base_env(
-    env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
+    env_name="BreakoutNoFrameskip-v4",
+    gym_backend="gymnasium",
+    frame_skip=4,
+    device="cpu",
+    is_test=False,
 ):
-    env = GymEnv(
-        env_name,
-        frame_skip=frame_skip,
-        from_pixels=True,
-        pixels_only=False,
-        device=device,
-    )
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device=device,
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
@@ -61,11 +67,14 @@ def make_base_env(
     return env
 
 
-def make_parallel_env(env_name, num_envs, device, is_test=False):
+def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
     env = ParallelEnv(
         num_envs,
-        EnvCreator(lambda: make_base_env(env_name)),
+        EnvCreator(
+            lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test),
+        ),
         serial_for_single=True,
+        gym_backend=gym_backend,
         device=device,
     )
     env = TransformedEnv(env)
@@ -175,9 +184,11 @@ def make_ppo_modules_pixels(proof_environment, device):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name, device):
+def make_ppo_models(env_name, device, gym_backend):
 
-    proof_environment = make_parallel_env(env_name, 1, device="cpu")
+    proof_environment = make_parallel_env(
+        env_name, num_envs=1, device="cpu", gym_backend=gym_backend
+    )
     common_module, policy_module, value_module = make_ppo_modules_pixels(
         proof_environment, device=device
     )
diff --git a/sota-implementations/dqn/config_atari.yaml b/sota-implementations/dqn/config_atari.yaml
index 85d513fbb2c..cb5dcf9411c 100644
--- a/sota-implementations/dqn/config_atari.yaml
+++ b/sota-implementations/dqn/config_atari.yaml
@@ -3,6 +3,7 @@ device: null
 
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # collector
diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py
index c2bffd91869..255e32d9a31 100644
--- a/sota-implementations/dqn/dqn_atari.py
+++ b/sota-implementations/dqn/dqn_atari.py
@@ -49,7 +49,12 @@ def main(cfg: DictConfig):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
+    model = make_dqn_model(
+        cfg.env.env_name,
+        gym_backend=cfg.env.backend,
+        frame_skip=frame_skip,
+        device=device,
+    )
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
@@ -114,7 +119,13 @@ def transform(td):
     )
 
     # Create the test environment
-    test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name,
+        frame_skip,
+        device,
+        gym_backend=cfg.env.backend,
+        is_test=True,
+    )
     if cfg.logger.video:
         test_env.insert_transform(
             0,
@@ -154,7 +165,9 @@ def update(sampled_tensordict):
 
     # Create the collector
     collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        create_env_fn=make_env(
+            cfg.env.env_name, frame_skip, device, gym_backend=cfg.env.backend
+        ),
         policy=model_explore,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
diff --git a/sota-implementations/dqn/utils_atari.py b/sota-implementations/dqn/utils_atari.py
index 0956dfeb2ac..1a9e96fec8e 100644
--- a/sota-implementations/dqn/utils_atari.py
+++ b/sota-implementations/dqn/utils_atari.py
@@ -16,6 +16,7 @@
     NoopResetEnv,
     Resize,
     RewardSum,
+    set_gym_backend,
     SignTransform,
     StepCounter,
     ToTensorImage,
@@ -32,15 +33,16 @@
 # --------------------------------------------------------------------
 
 
-def make_env(env_name, frame_skip, device, is_test=False):
-    env = GymEnv(
-        env_name,
-        frame_skip=frame_skip,
-        from_pixels=True,
-        pixels_only=False,
-        device=device,
-        categorical_action_encoding=True,
-    )
+def make_env(env_name, frame_skip, device, gym_backend, is_test=False):
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device=device,
+            categorical_action_encoding=True,
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
@@ -94,8 +96,10 @@ def make_dqn_modules_pixels(proof_environment, device):
     return qvalue_module
 
 
-def make_dqn_model(env_name, frame_skip, device):
-    proof_environment = make_env(env_name, frame_skip, device=device)
+def make_dqn_model(env_name, gym_backend, frame_skip, device):
+    proof_environment = make_env(
+        env_name, frame_skip, gym_backend=gym_backend, device=device
+    )
     qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
     del proof_environment
     return qvalue_module
diff --git a/sota-implementations/impala/config_multi_node_ray.yaml b/sota-implementations/impala/config_multi_node_ray.yaml
index 549428a4725..5a4b8a79d4a 100644
--- a/sota-implementations/impala/config_multi_node_ray.yaml
+++ b/sota-implementations/impala/config_multi_node_ray.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
 ray_init_config:
diff --git a/sota-implementations/impala/config_multi_node_submitit.yaml b/sota-implementations/impala/config_multi_node_submitit.yaml
index 4d4332722aa..18e807bdf0b 100644
--- a/sota-implementations/impala/config_multi_node_submitit.yaml
+++ b/sota-implementations/impala/config_multi_node_submitit.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Device for the forward and backward passes
 local_device:
diff --git a/sota-implementations/impala/config_single_node.yaml b/sota-implementations/impala/config_single_node.yaml
index 655edaddc4e..c937698962f 100644
--- a/sota-implementations/impala/config_single_node.yaml
+++ b/sota-implementations/impala/config_single_node.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Device for the forward and backward passes
 device:
diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py
index 5364c82c7b2..c6c4fcb02e8 100644
--- a/sota-implementations/impala/impala_multi_node_ray.py
+++ b/sota-implementations/impala/impala_multi_node_ray.py
@@ -59,7 +59,7 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
     actor, critic = actor.to(device), critic.to(device)
 
     # Create collector
@@ -91,7 +91,8 @@ def main(cfg: DictConfig):  # noqa: F821
         "memory": cfg.remote_worker_resources.memory,
     }
     collector = RayCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         collector_class=SyncDataCollector,
         frames_per_batch=frames_per_batch,
@@ -154,7 +155,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py
index 527821820ca..51d4ab8d27d 100644
--- a/sota-implementations/impala/impala_multi_node_submitit.py
+++ b/sota-implementations/impala/impala_multi_node_submitit.py
@@ -61,7 +61,7 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
     actor, critic = actor.to(device), critic.to(device)
 
     slurm_kwargs = {
@@ -81,7 +81,8 @@ def main(cfg: DictConfig):  # noqa: F821
             f"device assignment not implemented for backend {cfg.collector.backend}"
         )
     collector = DistributedDataCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         num_workers_per_collector=1,
         frames_per_batch=frames_per_batch,
@@ -146,7 +147,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py
index b7af2adbc38..424bf65ed8f 100644
--- a/sota-implementations/impala/impala_single_node.py
+++ b/sota-implementations/impala/impala_single_node.py
@@ -58,11 +58,12 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
 
     # Create collector
     collector = MultiaSyncDataCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -123,7 +124,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
diff --git a/sota-implementations/impala/utils.py b/sota-implementations/impala/utils.py
index e174bc2e71c..268e63c6caf 100644
--- a/sota-implementations/impala/utils.py
+++ b/sota-implementations/impala/utils.py
@@ -17,6 +17,7 @@
     NoopResetEnv,
     Resize,
     RewardSum,
+    set_gym_backend,
    SignTransform,
     StepCounter,
     ToTensorImage,
@@ -38,10 +39,11 @@
 # --------------------------------------------------------------------
 
 
-def make_env(env_name, device, is_test=False):
-    env = GymEnv(
-        env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
-    )
+def make_env(env_name, device, gym_backend, is_test=False):
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
@@ -139,9 +141,9 @@ def make_ppo_modules_pixels(proof_environment):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name):
+def make_ppo_models(env_name, gym_backend):
 
-    proof_environment = make_env(env_name, device="cpu")
+    proof_environment = make_env(env_name, device="cpu", gym_backend=gym_backend)
     common_module, policy_module, value_module = make_ppo_modules_pixels(
         proof_environment
     )
diff --git a/sota-implementations/ppo/config_atari.yaml b/sota-implementations/ppo/config_atari.yaml
index f7a340e3512..038bc9bc45f 100644
--- a/sota-implementations/ppo/config_atari.yaml
+++ b/sota-implementations/ppo/config_atari.yaml
@@ -2,6 +2,7 @@
 env:
   env_name: PongNoFrameskip-v4
   num_envs: 8
+  backend: gymnasium
 
 # collector
diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 25b6f63e893..918defc6f28 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -62,11 +62,18 @@ def main(cfg: DictConfig):  # noqa: F821
         compile_mode = "reduce-overhead"
 
     # Create models (check utils_atari.py)
-    actor, critic = make_ppo_models(cfg.env.env_name, device=device)
+    actor, critic = make_ppo_models(
+        cfg.env.env_name, device=device, gym_backend=cfg.env.backend
+    )
 
     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        create_env_fn=make_parallel_env(
+            cfg.env.env_name,
+            num_envs=cfg.env.num_envs,
+            device=device,
+            gym_backend=cfg.env.backend,
+        ),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -137,7 +144,9 @@ def main(cfg: DictConfig):  # noqa: F821
         logger_video = False
 
     # Create test environment
-    test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
+    test_env = make_parallel_env(
+        cfg.env.env_name, 1, device, is_test=True, gym_backend=cfg.env.backend
+    )
     if logger_video:
         test_env = test_env.append_transform(
             VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])
diff --git a/sota-implementations/ppo/utils_atari.py b/sota-implementations/ppo/utils_atari.py
index fa9d4bb053e..e3eae253b48 100644
--- a/sota-implementations/ppo/utils_atari.py
+++ b/sota-implementations/ppo/utils_atari.py
@@ -21,6 +21,7 @@
     RenameTransform,
     Resize,
     RewardSum,
+    set_gym_backend,
     SignTransform,
     StepCounter,
     ToTensorImage,
@@ -43,15 +44,21 @@
 # --------------------------------------------------------------------
 
 
-def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False):
-    env = GymEnv(
-        env_name,
-        frame_skip=frame_skip,
-        from_pixels=True,
-        pixels_only=False,
-        device="cpu",
-        categorical_action_encoding=True,
-    )
+def make_base_env(
+    env_name="BreakoutNoFrameskip-v4",
+    frame_skip=4,
+    gym_backend="gymnasium",
+    is_test=False,
+):
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device="cpu",
+            categorical_action_encoding=True,
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
@@ -59,10 +66,10 @@ def make_base_env(env_name="BreakoutNoFrameskip-v4", frame_skip=4, is_test=False
     return env
 
 
-def make_parallel_env(env_name, num_envs, device, is_test=False):
+def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
     env = ParallelEnv(
         num_envs,
-        EnvCreator(lambda: make_base_env(env_name)),
+        EnvCreator(lambda: make_base_env(env_name, gym_backend=gym_backend)),
         serial_for_single=True,
         device=device,
     )
@@ -174,9 +181,11 @@ def make_ppo_modules_pixels(proof_environment, device):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name, device):
+def make_ppo_models(env_name, device, gym_backend):
 
-    proof_environment = make_parallel_env(env_name, 1, device=device)
+    proof_environment = make_parallel_env(
+        env_name, 1, device=device, gym_backend=gym_backend
+    )
     common_module, policy_module, value_module = make_ppo_modules_pixels(
         proof_environment,
         device=device,
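
Note: the pattern this patch threads through every Atari trainer reduces to the sketch below. This is a minimal, hypothetical usage example, not part of the patch; it relies only on names the diff itself imports from torchrl.envs (GymEnv, TransformedEnv, set_gym_backend), and the backend strings mirror the env.backend values set in the yaml configs and in test_sota.py.

# Minimal sketch (not part of the patch): build an Atari env against an
# explicitly chosen gym backend, which is what the new `env.backend` key does.
from torchrl.envs import GymEnv, TransformedEnv, set_gym_backend

# "gym" or "gymnasium": the same values the configs and test_sota.py pass in.
with set_gym_backend("gymnasium"):
    env = GymEnv(
        "PongNoFrameskip-v4",
        frame_skip=4,
        from_pixels=True,
        pixels_only=False,
        device="cpu",
    )
# Transforms are appended outside the backend scope, as in the utils modules.
env = TransformedEnv(env)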