[CI] Fix envnames in SOTA tests #2921


Merged · 20 commits · Apr 25, 2025
2 changes: 0 additions & 2 deletions .github/unittest/linux_sota/scripts/environment.yml
@@ -29,5 +29,3 @@ dependencies:
   - coverage
   - vmas
   - transformers
-  - gym[atari]
-  - gym[accept-rom-license]
4 changes: 3 additions & 1 deletion .github/unittest/linux_sota/scripts/run_all.sh
@@ -111,7 +111,6 @@ python -c """import gym;import d4rl"""
 
 # install ale-py: manylinux names are broken for CentOS so we need to manually download and
 # rename them
-pip install "gymnasium[atari]>=1.1.0"
 
 # ============================================================================================ #
 # ================================ PyTorch & TorchRL ========================================= #
@@ -128,6 +127,9 @@ version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
 # submodules
 git submodule sync && git submodule update --init --recursive
 
+pip3 install ale-py -U
+pip3 install "gym[atari,accept-rom-license]" "gymnasium>=1.1.0" -U
+
 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
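Not part of the diff, but a quick way to verify the reshuffled installs is an import probe in the same style the script already uses for gym/d4rl:

python -c """import ale_py;import gym;import gymnasium"""

If ale-py or either backend failed to install, this fails fast before any SOTA run starts.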
4 changes: 4 additions & 0 deletions .github/unittest/linux_sota/scripts/test_sota.py
@@ -39,6 +39,7 @@
 collector.frames_per_batch=20 \
 collector.num_workers=1 \
 logger.backend= \
+env.backend=gym \
 logger.test_interval=10
 """,
 "ppo_mujoco": """python sota-implementations/ppo/ppo_mujoco.py \
@@ -56,6 +57,7 @@
 loss.mini_batch_size=20 \
 loss.ppo_epochs=2 \
 logger.backend= \
+env.backend=gym \
 logger.test_interval=10
 """,
 "ddpg": """python sota-implementations/ddpg/ddpg.py \
@@ -82,6 +84,7 @@
 collector.frames_per_batch=20 \
 loss.mini_batch_size=20 \
 logger.backend= \
+env.backend=gym \
 logger.test_interval=40
 """,
 "dqn_atari": """python sota-implementations/dqn/dqn_atari.py \
@@ -91,6 +94,7 @@
 buffer.batch_size=10 \
 loss.num_updates=1 \
 logger.backend= \
+env.backend=gym \
 buffer.buffer_size=120
 """,
 "discrete_cql_online": """python sota-implementations/cql/discrete_cql_online.py \
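The added `env.backend=gym` flags are plain Hydra command-line overrides. A minimal sketch of how such an override reaches a script (a toy example, assuming a `config_atari.yaml` like the ones below sits next to it; not code from this PR):

import hydra
from omegaconf import DictConfig


@hydra.main(config_path=".", config_name="config_atari", version_base="1.1")
def main(cfg: DictConfig):
    # `env.backend=gym` on the CLI overrides the YAML default "gymnasium".
    print(cfg.env.backend)


if __name__ == "__main__":
    main()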
19 changes: 16 additions & 3 deletions sota-implementations/a2c/a2c_atari.py
@@ -47,7 +47,9 @@ def main(cfg: DictConfig):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Create models (check utils_atari.py)
-    actor, critic, critic_head = make_ppo_models(cfg.env.env_name, device=device)
+    actor, critic, critic_head = make_ppo_models(
+        cfg.env.env_name, device=device, gym_backend=cfg.env.backend
+    )
     with from_module(actor).data.to("meta").to_module(actor):
         actor_eval = deepcopy(actor)
         actor_eval.eval()
@@ -107,7 +109,13 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True)
+    test_env = make_parallel_env(
+        cfg.env.env_name,
+        num_envs=1,
+        device=device,
+        gym_backend=cfg.env.backend,
+        is_test=True,
+    )
     test_env.set_seed(0)
     if cfg.logger.video:
         test_env = test_env.insert_transform(
@@ -162,7 +170,12 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
 
     # Create collector
     collector = SyncDataCollector(
-        create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device),
+        create_env_fn=make_parallel_env(
+            cfg.env.env_name,
+            num_envs=cfg.env.num_envs,
+            device=device,
+            gym_backend=cfg.env.backend,
+        ),
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
1 change: 1 addition & 0 deletions sota-implementations/a2c/config_atari.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
   num_envs: 16
 
 # collector
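The YAML default stays on gymnasium for regular runs; the CI commands above flip it to gym per invocation, and the same override works for a local smoke test, e.g.:

python sota-implementations/a2c/a2c_atari.py env.backend=gym logger.backend=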
35 changes: 23 additions & 12 deletions sota-implementations/a2c/utils_atari.py
@@ -21,6 +21,7 @@
     ParallelEnv,
     Resize,
     RewardSum,
+    set_gym_backend,
     SignTransform,
     StepCounter,
     ToTensorImage,
@@ -45,27 +46,35 @@
 
 
 def make_base_env(
-    env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
+    env_name="BreakoutNoFrameskip-v4",
+    gym_backend="gymnasium",
+    frame_skip=4,
+    device="cpu",
+    is_test=False,
 ):
-    env = GymEnv(
-        env_name,
-        frame_skip=frame_skip,
-        from_pixels=True,
-        pixels_only=False,
-        device=device,
-    )
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device=device,
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
         env.append_transform(EndOfLifeTransform())
     return env
 
 
-def make_parallel_env(env_name, num_envs, device, is_test=False):
+def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
     env = ParallelEnv(
         num_envs,
-        EnvCreator(lambda: make_base_env(env_name)),
+        EnvCreator(
+            lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test),
+        ),
         serial_for_single=True,
+        gym_backend=gym_backend,
         device=device,
     )
     env = TransformedEnv(env)
@@ -175,9 +184,11 @@ def make_ppo_modules_pixels(proof_environment, device):
     return common_module, policy_module, value_module
 
 
-def make_ppo_models(env_name, device):
+def make_ppo_models(env_name, device, gym_backend):
 
-    proof_environment = make_parallel_env(env_name, 1, device="cpu")
+    proof_environment = make_parallel_env(
+        env_name, num_envs=1, device="cpu", gym_backend=gym_backend
+    )
     common_module, policy_module, value_module = make_ppo_modules_pixels(
         proof_environment, device=device
     )
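`set_gym_backend` is the TorchRL context manager that pins which backend `GymEnv` resolves against; only envs instantiated inside the block are affected, which is why the helper wraps just the `GymEnv` construction. A standalone sketch of the pattern (env name for illustration only):

from torchrl.envs import GymEnv, set_gym_backend

# "gym" and "gymnasium" are the accepted backend names here.
with set_gym_backend("gymnasium"):
    env = GymEnv("PongNoFrameskip-v4", from_pixels=True, pixels_only=False)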
1 change: 1 addition & 0 deletions sota-implementations/dqn/config_atari.yaml
@@ -3,6 +3,7 @@ device: null
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # collector
 collector:
19 changes: 16 additions & 3 deletions sota-implementations/dqn/dqn_atari.py
@@ -49,7 +49,12 @@ def main(cfg: DictConfig):  # noqa: F821
     test_interval = cfg.logger.test_interval // frame_skip
 
     # Make the components
-    model = make_dqn_model(cfg.env.env_name, frame_skip, device=device)
+    model = make_dqn_model(
+        cfg.env.env_name,
+        gym_backend=cfg.env.backend,
+        frame_skip=frame_skip,
+        device=device,
+    )
     greedy_module = EGreedyModule(
         annealing_num_steps=cfg.collector.annealing_frames,
         eps_init=cfg.collector.eps_start,
@@ -114,7 +119,13 @@ def transform(td):
     )
 
     # Create the test environment
-    test_env = make_env(cfg.env.env_name, frame_skip, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name,
+        frame_skip,
+        device,
+        gym_backend=cfg.env.backend,
+        is_test=True,
+    )
     if cfg.logger.video:
         test_env.insert_transform(
             0,
@@ -154,7 +165,9 @@ def update(sampled_tensordict):
 
     # Create the collector
     collector = SyncDataCollector(
-        create_env_fn=make_env(cfg.env.env_name, frame_skip, device),
+        create_env_fn=make_env(
+            cfg.env.env_name, frame_skip, device, gym_backend=cfg.env.backend
+        ),
         policy=model_explore,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
26 changes: 15 additions & 11 deletions sota-implementations/dqn/utils_atari.py
@@ -16,6 +16,7 @@
     NoopResetEnv,
     Resize,
     RewardSum,
+    set_gym_backend,
     SignTransform,
     StepCounter,
     ToTensorImage,
@@ -32,15 +33,16 @@
 # --------------------------------------------------------------------
 
 
-def make_env(env_name, frame_skip, device, is_test=False):
-    env = GymEnv(
-        env_name,
-        frame_skip=frame_skip,
-        from_pixels=True,
-        pixels_only=False,
-        device=device,
-        categorical_action_encoding=True,
-    )
+def make_env(env_name, frame_skip, device, gym_backend, is_test=False):
+    with set_gym_backend(gym_backend):
+        env = GymEnv(
+            env_name,
+            frame_skip=frame_skip,
+            from_pixels=True,
+            pixels_only=False,
+            device=device,
+            categorical_action_encoding=True,
+        )
     env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
@@ -94,8 +96,10 @@ def make_dqn_modules_pixels(proof_environment, device):
     return qvalue_module
 
 
-def make_dqn_model(env_name, frame_skip, device):
-    proof_environment = make_env(env_name, frame_skip, device=device)
+def make_dqn_model(env_name, gym_backend, frame_skip, device):
+    proof_environment = make_env(
+        env_name, frame_skip, gym_backend=gym_backend, device=device
+    )
     qvalue_module = make_dqn_modules_pixels(proof_environment, device=device)
     del proof_environment
     return qvalue_module
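The `proof_environment` idiom above exists only to read env specs before the networks are built. A hedged sketch of the idea (the spec attribute chain is assumed from TorchRL's categorical specs, not shown in this diff):

from torchrl.envs import GymEnv, set_gym_backend

with set_gym_backend("gymnasium"):
    proof_env = GymEnv("PongNoFrameskip-v4", categorical_action_encoding=True)
num_actions = proof_env.action_spec.space.n  # sizes the Q-value head
proof_env.close()  # the throwaway env is not needed once specs are read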
1 change: 1 addition & 0 deletions sota-implementations/impala/config_multi_node_ray.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
 ray_init_config:
1 change: 1 addition & 0 deletions sota-implementations/impala/config_multi_node_submitit.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Device for the forward and backward passes
 local_device:
1 change: 1 addition & 0 deletions sota-implementations/impala/config_single_node.yaml
@@ -1,6 +1,7 @@
 # Environment
 env:
   env_name: PongNoFrameskip-v4
+  backend: gymnasium
 
 # Device for the forward and backward passes
 device:
9 changes: 6 additions & 3 deletions sota-implementations/impala/impala_multi_node_ray.py
@@ -59,7 +59,7 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
     actor, critic = actor.to(device), critic.to(device)
 
     # Create collector
@@ -91,7 +91,8 @@ def main(cfg: DictConfig):  # noqa: F821
         "memory": cfg.remote_worker_resources.memory,
     }
     collector = RayCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         collector_class=SyncDataCollector,
         frames_per_batch=frames_per_batch,
@@ -154,7 +155,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
9 changes: 6 additions & 3 deletions sota-implementations/impala/impala_multi_node_submitit.py
@@ -61,7 +61,7 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
     actor, critic = actor.to(device), critic.to(device)
 
     slurm_kwargs = {
@@ -81,7 +81,8 @@ def main(cfg: DictConfig):  # noqa: F821
             f"device assignment not implemented for backend {cfg.collector.backend}"
         )
     collector = DistributedDataCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         num_workers_per_collector=1,
         frames_per_batch=frames_per_batch,
@@ -146,7 +147,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
9 changes: 6 additions & 3 deletions sota-implementations/impala/impala_single_node.py
@@ -58,11 +58,12 @@ def main(cfg: DictConfig):  # noqa: F821
     ) * cfg.loss.sgd_updates
 
     # Create models (check utils.py)
-    actor, critic = make_ppo_models(cfg.env.env_name)
+    actor, critic = make_ppo_models(cfg.env.env_name, cfg.env.backend)
 
     # Create collector
     collector = MultiaSyncDataCollector(
-        create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers,
+        create_env_fn=[make_env(cfg.env.env_name, device, gym_backend=cfg.env.backend)]
+        * num_workers,
         policy=actor,
         frames_per_batch=frames_per_batch,
         total_frames=total_frames,
@@ -123,7 +124,9 @@ def main(cfg: DictConfig):  # noqa: F821
     )
 
     # Create test environment
-    test_env = make_env(cfg.env.env_name, device, is_test=True)
+    test_env = make_env(
+        cfg.env.env_name, device, gym_backend=cfg.env.backend, is_test=True
+    )
     test_env.eval()
 
     # Main loop
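One reason `gym_backend` is threaded through every `make_env`/`make_ppo_models` call instead of being set once in the driver: with `ParallelEnv` and the Ray/submitit collectors, env construction happens inside worker processes, so the backend choice has to travel with the creating function. A minimal sketch of that pattern (helper names hypothetical, not from this PR):

from torchrl.envs import GymEnv, set_gym_backend


def make_env_fn(env_name, backend):
    # The returned closure is shipped to the worker; the backend string
    # travels with it and is applied where the env is actually built.
    def _make():
        with set_gym_backend(backend):
            return GymEnv(env_name)

    return _make


create_env_fn = [make_env_fn("PongNoFrameskip-v4", "gym")] * 4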