Skip to content

Commit d7e6fce

Browse files
committed
[CI] Fix envnames in SOTA tests
ghstack-source-id: 809d6a5 Pull-Request-resolved: #2921
1 parent fb9c628 commit d7e6fce

19 files changed

+110
-52
lines changed

.github/unittest/linux_sota/scripts/environment.yml

-2
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,3 @@ dependencies:
2929
- coverage
3030
- vmas
3131
- transformers
32-
- gym[atari]
33-
- gym[accept-rom-license]

.github/unittest/linux_sota/scripts/run_all.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,6 @@ python -c """import gym;import d4rl"""
111111

112112
# install ale-py: manylinux names are broken for CentOS so we need to manually download and
113113
# rename them
114-
pip install "gymnasium[atari]>=1.1.0"
115114

116115
# ============================================================================================ #
117116
# ================================ PyTorch & TorchRL ========================================= #
@@ -128,6 +127,8 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
128127
# submodules
129128
git submodule sync && git submodule update --init --recursive
130129

130+
pip3 install "gym[atari,accept-rom-license]" "gymnasium[atari,ale-py]>=1.1.0" -U
131+
131132
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
132133
if [[ "$TORCH_VERSION" == "nightly" ]]; then
133134
if [ "${CU_VERSION:-}" == cpu ] ; then

.github/unittest/linux_sota/scripts/test_sota.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
collector.frames_per_batch=20 \
4040
collector.num_workers=1 \
4141
logger.backend= \
42+
env.env_name=ALE/Pong-v5 \
4243
logger.test_interval=10
4344
""",
4445
"ppo_mujoco": """python sota-implementations/ppo/ppo_mujoco.py \
@@ -56,6 +57,7 @@
5657
loss.mini_batch_size=20 \
5758
loss.ppo_epochs=2 \
5859
logger.backend= \
60+
env.env_name=ALE/Pong-v5 \
5961
logger.test_interval=10
6062
""",
6163
"ddpg": """python sota-implementations/ddpg/ddpg.py \
@@ -82,6 +84,7 @@
8284
collector.frames_per_batch=20 \
8385
loss.mini_batch_size=20 \
8486
logger.backend= \
87+
env.env_name=ALE/Pong-v5 \
8588
logger.test_interval=40
8689
""",
8790
"dqn_atari": """python sota-implementations/dqn/dqn_atari.py \
@@ -91,6 +94,7 @@
9194
buffer.batch_size=10 \
9295
loss.num_updates=1 \
9396
logger.backend= \
97+
env.env_name=ALE/Pong-v5 \
9498
buffer.buffer_size=120
9599
""",
96100
"discrete_cql_online": """python sota-implementations/cql/discrete_cql_online.py \

sota-implementations/a2c/a2c_atari.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def main(cfg: DictConfig): # noqa: F821
107107
)
108108

109109
# Create test environment
110-
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
110+
test_env = make_parallel_env(
111+
cfg.env.env_name, 1, device, gym_backend=cfg.env.gym_backend, is_test=True
112+
)
111113
test_env.set_seed(0)
112114
if cfg.logger.video:
113115
test_env = test_env.insert_transform(

sota-implementations/a2c/config_atari.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45
num_envs: 16
56

67
# collector

sota-implementations/a2c/utils_atari.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ParallelEnv,
2222
Resize,
2323
RewardSum,
24+
set_gym_backend,
2425
SignTransform,
2526
StepCounter,
2627
ToTensorImage,
@@ -45,27 +46,33 @@
4546

4647

4748
def make_base_env(
48-
env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
49+
env_name="BreakoutNoFrameskip-v4",
50+
gym_backend="gymnasium",
51+
frame_skip=4,
52+
device="cpu",
53+
is_test=False,
4954
):
50-
env = GymEnv(
51-
env_name,
52-
frame_skip=frame_skip,
53-
from_pixels=True,
54-
pixels_only=False,
55-
device=device,
56-
)
55+
with set_gym_backend(gym_backend):
56+
env = GymEnv(
57+
env_name,
58+
frame_skip=frame_skip,
59+
from_pixels=True,
60+
pixels_only=False,
61+
device=device,
62+
)
5763
env = TransformedEnv(env)
5864
env.append_transform(NoopResetEnv(noops=30, random=True))
5965
if not is_test:
6066
env.append_transform(EndOfLifeTransform())
6167
return env
6268

6369

64-
def make_parallel_env(env_name, num_envs, device, is_test=False):
70+
def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
6571
env = ParallelEnv(
6672
num_envs,
6773
EnvCreator(lambda: make_base_env(env_name)),
6874
serial_for_single=True,
75+
gym_backend=gym_backend,
6976
device=device,
7077
)
7178
env = TransformedEnv(env)

sota-implementations/dqn/config_atari.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ device: null
33
# Environment
44
env:
55
env_name: PongNoFrameskip-v4
6+
backend: gymnasium
67

78
# collector
89
collector:

sota-implementations/dqn/dqn_atari.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ def transform(td):
114114
)
115115

116116
# Create the test environment
117-
test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True)
117+
test_env = make_env(
118+
cfg.env.env_name,
119+
frame_skip,
120+
device,
121+
gym_backend=cfg.env.gym_backend,
122+
is_test=True,
123+
)
118124
if cfg.logger.video:
119125
test_env.insert_transform(
120126
0,
@@ -154,7 +160,9 @@ def update(sampled_tensordict):
154160

155161
# Create the collector
156162
collector = SyncDataCollector(
157-
create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
163+
create_env_fn=make_env(
164+
cfg.env.env_name, frame_skip, device, gym_backend=cfg.env.gym_backend
165+
),
158166
policy=model_explore,
159167
frames_per_batch=frames_per_batch,
160168
total_frames=total_frames,

sota-implementations/dqn/utils_atari.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NoopResetEnv,
1717
Resize,
1818
RewardSum,
19+
set_gym_backend,
1920
SignTransform,
2021
StepCounter,
2122
ToTensorImage,
@@ -32,15 +33,16 @@
3233
# --------------------------------------------------------------------
3334

3435

35-
def make_env(env_name, frame_skip, device, is_test=False):
36-
env = GymEnv(
37-
env_name,
38-
frame_skip=frame_skip,
39-
from_pixels=True,
40-
pixels_only=False,
41-
device=device,
42-
categorical_action_encoding=True,
43-
)
36+
def make_env(env_name, frame_skip, device, gym_backend, is_test=False):
37+
with set_gym_backend(gym_backend):
38+
env = GymEnv(
39+
env_name,
40+
frame_skip=frame_skip,
41+
from_pixels=True,
42+
pixels_only=False,
43+
device=device,
44+
categorical_action_encoding=True,
45+
)
4446
env = TransformedEnv(env)
4547
env.append_transform(NoopResetEnv(noops=30, random=True))
4648
if not is_test:

sota-implementations/impala/config_multi_node_ray.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45

56
# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
67
ray_init_config:

sota-implementations/impala/config_multi_node_submitit.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45

56
# Device for the forward and backward passes
67
local_device:

sota-implementations/impala/config_single_node.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Environment
22
env:
33
env_name: PongNoFrameskip-v4
4+
backend: gymnasium
45

56
# Device for the forward and backward passes
67
device:

sota-implementations/impala/impala_multi_node_ray.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def main(cfg: DictConfig): # noqa: F821
5959
) * cfg.loss.sgd_updates
6060

6161
# Create models (check utils.py)
62-
actor, critic = make_ppo_models(cfg.env.env_name)
62+
actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.gym_backend)
6363
actor, critic = actor.to(device), critic.to(device)
6464

6565
# Create collector
@@ -91,7 +91,10 @@ def main(cfg: DictConfig): # noqa: F821
9191
"memory": cfg.remote_worker_resources.memory,
9292
}
9393
collector = RayCollector(
94-
create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
94+
create_env_fn=[
95+
make_env(cfg.env.env_name, device, gym_backend=cfg.env.gym_backend)
96+
]
97+
* num_workers,
9598
policy=actor,
9699
collector_class=SyncDataCollector,
97100
frames_per_batch=frames_per_batch,
@@ -154,7 +157,9 @@ def main(cfg: DictConfig): # noqa: F821
154157
)
155158

156159
# Create test environment
157-
test_env = make_env(cfg.env.env_name, device, is_test=True)
160+
test_env = make_env(
161+
cfg.env.env_name, device, gym_backend=cfg.env.gym_backend, is_test=True
162+
)
158163
test_env.eval()
159164

160165
# Main loop

sota-implementations/impala/impala_multi_node_submitit.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def main(cfg: DictConfig): # noqa: F821
6161
) * cfg.loss.sgd_updates
6262

6363
# Create models (check utils.py)
64-
actor, critic = make_ppo_models(cfg.env.env_name)
64+
actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.gym_backend)
6565
actor, critic = actor.to(device), critic.to(device)
6666

6767
slurm_kwargs = {
@@ -81,7 +81,10 @@ def main(cfg: DictConfig): # noqa: F821
8181
f"device assignment not implemented for backend {cfg.collector.backend}"
8282
)
8383
collector = DistributedDataCollector(
84-
create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
84+
create_env_fn=[
85+
make_env(cfg.env.env_name, device, gym_backend=cfg.env.gym_backend)
86+
]
87+
* num_workers,
8588
policy=actor,
8689
num_workers_per_collector=1,
8790
frames_per_batch=frames_per_batch,
@@ -146,7 +149,9 @@ def main(cfg: DictConfig): # noqa: F821
146149
)
147150

148151
# Create test environment
149-
test_env = make_env(cfg.env.env_name, device, is_test=True)
152+
test_env = make_env(
153+
cfg.env.env_name, device, gym_backend=cfg.env.gym_backend, is_test=True
154+
)
150155
test_env.eval()
151156

152157
# Main loop

sota-implementations/impala/impala_single_node.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,14 @@ def main(cfg: DictConfig): # noqa: F821
5858
) * cfg.loss.sgd_updates
5959

6060
# Create models (check utils.py)
61-
actor, critic = make_ppo_models(cfg.env.env_name)
61+
actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.gym_backend)
6262

6363
# Create collector
6464
collector = MultiaSyncDataCollector(
65-
create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
65+
create_env_fn=[
66+
make_env(cfg.env.env_name, device, gym_backend=cfg.env.gym_backend)
67+
]
68+
* num_workers,
6669
policy=actor,
6770
frames_per_batch=frames_per_batch,
6871
total_frames=total_frames,
@@ -123,7 +126,9 @@ def main(cfg: DictConfig): # noqa: F821
123126
)
124127

125128
# Create test environment
126-
test_env = make_env(cfg.env.env_name, device, is_test=True)
129+
test_env = make_env(
130+
cfg.env.env_name, device, gym_backend=cfg.env.gym_backend, is_test=True
131+
)
127132
test_env.eval()
128133

129134
# Main loop

sota-implementations/impala/utils.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
NoopResetEnv,
1818
Resize,
1919
RewardSum,
20+
set_gym_backend,
2021
SignTransform,
2122
StepCounter,
2223
ToTensorImage,
@@ -38,10 +39,11 @@
3839
# --------------------------------------------------------------------
3940

4041

41-
def make_env(env_name, device, is_test=False):
42-
env = GymEnv(
43-
env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
44-
)
42+
def make_env(env_name, device, gym_backend, is_test=False):
43+
with set_gym_backend(gym_backend):
44+
env = GymEnv(
45+
env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device
46+
)
4547
env = TransformedEnv(env)
4648
env.append_transform(NoopResetEnv(noops=30, random=True))
4749
if not is_test:
@@ -139,9 +141,11 @@ def make_ppo_modules_pixels(proof_environment):
139141
return common_module, policy_module, value_module
140142

141143

142-
def make_ppo_models(env_name):
144+
def make_ppo_models(env_name, gym_backend):
143145

144-
proof_environment = make_env(env_name, device="cpu")
146+
proof_environment = make_env(
147+
env_name, device="cpu", gym_backend=gym_backend
148+
)
145149
common_module, policy_module, value_module = make_ppo_modules_pixels(
146150
proof_environment
147151
)

sota-implementations/ppo/config_atari.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
env:
33
env_name: PongNoFrameskip-v4
44
num_envs: 8
5+
backend: gymnasium
56

67
# collector
78
collector:

sota-implementations/ppo/ppo_atari.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def main(cfg: DictConfig): # noqa: F821
6666

6767
# Create collector
6868
collector = SyncDataCollector(
69-
create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
69+
create_env_fn=make_parallel_env(
70+
cfg.env.env_name, cfg.env.num_envs, device, gym_backend=cfg.env.gym_backend
71+
),
7072
policy=actor,
7173
frames_per_batch=frames_per_batch,
7274
total_frames=total_frames,
@@ -137,7 +139,9 @@ def main(cfg: DictConfig): # noqa: F821
137139
logger_video = False
138140

139141
# Create test environment
140-
test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
142+
test_env = make_parallel_env(
143+
cfg.env.env_name, 1, device, is_test=True, gym_backend=cfg.env.gym_backend
144+
)
141145
if logger_video:
142146
test_env = test_env.append_transform(
143147
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels_int"])

0 commit comments

Comments
 (0)