Skip to content

Commit ee9e9d7

Browse files
committed
[Environment] Fix lib CI failures
ghstack-source-id: 0c77ddc6bae80baa866a9c35a8bc003491dfde43 Pull-Request-resolved: #2923
1 parent e4733b8 commit ee9e9d7

File tree

7 files changed

+32
-19
lines changed

7 files changed

+32
-19
lines changed

.github/unittest/linux_libs/scripts_gym/batch_scripts.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ conda activate ./env
1515
$DIR/install.sh
1616

1717
# Extracted from run_test.sh to run once.
18-
apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-dev g++ cmake
18+
apt-get update && apt-get install -y git wget libglew-dev libx11-dev x11proto-dev g++
1919

2020
# solves "'extras_require' must be a dictionary"
2121
pip install setuptools==65.3.0

.github/unittest/linux_libs/scripts_gym/install.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ unset PYTORCH_VERSION
44
# For unittest, nightly PyTorch is used as the following section,
55
# so no need to set PYTORCH_VERSION.
66
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.
7-
apt-get update && apt-get install -y git wget gcc g++ cmake
7+
apt-get update && apt-get install -y git wget gcc g++
88

99
set -e
1010
set -v

.github/unittest/linux_libs/scripts_gym/setup_env.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ set -e
99

1010
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
1111
# Avoid error: "fatal: unsafe repository"
12-
apt-get update && apt-get install -y git wget gcc g++ cmake
12+
apt-get update && apt-get install -y git wget gcc g++
1313

1414
git config --global --add safe.directory '*'
1515
root_dir="$(git rev-parse --show-toplevel)"
@@ -69,6 +69,8 @@ printf "* Installing dependencies (except PyTorch)\n"
6969
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
7070
cat "${this_dir}/environment.yml"
7171

72+
conda install anaconda::cmake -y
73+
7274
export MUJOCO_GL=egl
7375
conda env config vars set \
7476
MAX_IDLE_COUNT=1000 \

sota-implementations/a2c/utils_atari.py

-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False):
7474
lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test),
7575
),
7676
serial_for_single=True,
77-
gym_backend=gym_backend,
7877
device=device,
7978
)
8079
env = TransformedEnv(env)

torchrl/data/datasets/atari_dqn.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ def _download_and_preproc(self):
508508
if not os.listdir(tempdir):
509509
os.makedirs(tempdir, exist_ok=True)
510510
# get the list of runs
511+
try:
512+
subprocess.run(
513+
["gsutil", "version"], check=True, capture_output=True
514+
)
515+
except subprocess.CalledProcessError:
516+
raise RuntimeError("gsutil is not installed or not found in PATH.")
511517
command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
512518
output = subprocess.run(
513519
command, shell=True, capture_output=True
@@ -520,9 +526,7 @@ def _download_and_preproc(self):
520526
self.remote_gz_files = self._list_runs(None, files)
521527
remote_gz_files = list(self.remote_gz_files)
522528
if not len(remote_gz_files):
523-
raise RuntimeError(
524-
"Could not load the file list. Did you install gsutil?"
525-
)
529+
raise RuntimeError("No files in file list.")
526530

527531
total_runs = remote_gz_files[-1]
528532
if self.num_procs == 0:

torchrl/envs/libs/smacv2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821
228228
dtype=torch.bool,
229229
device=self.device,
230230
)
231-
self.action_spec = self._make_action_spec()
231+
self.full_action_spec = self._make_action_spec()
232232
self.observation_spec = self._make_observation_spec()
233233

234234
def _init_env(self) -> None:
@@ -356,7 +356,7 @@ def _reset(
356356
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
357357
# perform actions
358358
action = tensordict.get(("agents", "action"))
359-
action_np = self.action_spec.to_numpy(action)
359+
action_np = self.full_action_spec[self.action_key].to_numpy(action)
360360

361361
# Actions are validated by the environment.
362362
try:

torchrl/envs/transforms/transforms.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -8798,29 +8798,37 @@ def __init__(
87988798
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
87998799
raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))
88008800

8801+
@property
8802+
def action_spec(self):
8803+
action_spec = self.container.full_action_spec
8804+
keys = self.container.action_keys
8805+
if len(keys) == 1:
8806+
action_spec = action_spec[keys[0]]
8807+
else:
8808+
raise ValueError(
8809+
f"Too many action keys for {self.__class__.__name__}: {keys=}"
8810+
)
8811+
if not isinstance(action_spec, self.ACCEPTED_SPECS):
8812+
raise ValueError(
8813+
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
8814+
)
8815+
return action_spec
8816+
88018817
def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
88028818
parent = self.parent
88038819
if parent is None:
88048820
raise RuntimeError(
88058821
f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
88068822
)
88078823
mask = next_tensordict.get(self.in_keys[1])
8808-
action_spec = self.container.action_spec
8809-
if not isinstance(action_spec, self.ACCEPTED_SPECS):
8810-
raise ValueError(
8811-
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
8812-
)
8824+
action_spec = self.action_spec
88138825
action_spec.update_mask(mask.to(action_spec.device))
88148826
return next_tensordict
88158827

88168828
def _reset(
88178829
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
88188830
) -> TensorDictBase:
8819-
action_spec = self.container.action_spec
8820-
if not isinstance(action_spec, self.ACCEPTED_SPECS):
8821-
raise ValueError(
8822-
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
8823-
)
8831+
action_spec = self.action_spec
88248832
mask = tensordict.get(self.in_keys[1], None)
88258833
if mask is not None:
88268834
mask = mask.to(action_spec.device)

0 commit comments

Comments
 (0)