Skip to content

Commit f41888d

Browse files
committed
[Environment] Fix lib CI failures
ghstack-source-id: 13a9f82b2d8fdbb1ec724743bbdd809f52018923 Pull-Request-resolved: #2923
1 parent 243fe3b commit f41888d

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

sota-implementations/impala/utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def make_ppo_modules_pixels(proof_environment):
143143

144144
def make_ppo_models(env_name, gym_backend):
145145

146-
proof_environment = make_env(
147-
env_name, device="cpu", gym_backend=gym_backend
148-
)
146+
proof_environment = make_env(env_name, device="cpu", gym_backend=gym_backend)
149147
common_module, policy_module, value_module = make_ppo_modules_pixels(
150148
proof_environment
151149
)

torchrl/data/datasets/atari_dqn.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ def _download_and_preproc(self):
508508
if not os.listdir(tempdir):
509509
os.makedirs(tempdir, exist_ok=True)
510510
# get the list of runs
511+
try:
512+
subprocess.run(
513+
["gsutil", "version"], check=True, capture_output=True
514+
)
515+
except subprocess.CalledProcessError:
516+
raise RuntimeError("gsutil is not installed or not found in PATH.")
511517
command = f"gsutil -m ls -R gs://atari-replay-datasets/dqn/{self.dataset_id}/replay_logs"
512518
output = subprocess.run(
513519
command, shell=True, capture_output=True
@@ -520,9 +526,7 @@ def _download_and_preproc(self):
520526
self.remote_gz_files = self._list_runs(None, files)
521527
remote_gz_files = list(self.remote_gz_files)
522528
if not len(remote_gz_files):
523-
raise RuntimeError(
524-
"Could not load the file list. Did you install gsutil?"
525-
)
529+
raise RuntimeError("No files in file list.")
526530

527531
total_runs = remote_gz_files[-1]
528532
if self.num_procs == 0:

torchrl/envs/libs/smacv2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821
228228
dtype=torch.bool,
229229
device=self.device,
230230
)
231-
self.action_spec = self._make_action_spec()
231+
self.full_action_spec = self._make_action_spec()
232232
self.observation_spec = self._make_observation_spec()
233233

234234
def _init_env(self) -> None:
@@ -356,7 +356,7 @@ def _reset(
356356
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
357357
# perform actions
358358
action = tensordict.get(("agents", "action"))
359-
action_np = self.action_spec.to_numpy(action)
359+
action_np = self.full_action_spec[self.action_key].to_numpy(action)
360360

361361
# Actions are validated by the environment.
362362
try:

0 commit comments

Comments
 (0)