diff --git a/.github/workflows/commit_checks.yml b/.github/workflows/commit_checks.yml index ae07b02d..51979354 100644 --- a/.github/workflows/commit_checks.yml +++ b/.github/workflows/commit_checks.yml @@ -7,7 +7,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: psf/black@stable - uses: chartboost/ruff-action@v1 - uses: isort/isort-action@v1 diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst index 8e110b61..c6cc7a68 100644 --- a/docs/source/release_notes.rst +++ b/docs/source/release_notes.rst @@ -49,6 +49,8 @@ Development - |version| * Fix a bug where a satellite's initial data was never added to the rewarder. * Fix a bug where using multiple of the same rewarder would cause some settings to be overwritten. +* Add the ability to define metaagents that concatenate satellite action and observation + spaces in the environment. Version 1.1.0 diff --git a/src/bsk_rl/gym.py b/src/bsk_rl/gym.py index c2f34a37..5cc00715 100644 --- a/src/bsk_rl/gym.py +++ b/src/bsk_rl/gym.py @@ -32,6 +32,24 @@ NO_ACTION = int(2**31) - 1 +def is_no_action(action): + """Check if the action is a no-action placeholder.""" + if action is None: + return True + if isinstance(action, (int, np.integer)) and action == NO_ACTION: + return True + if isinstance(action, (np.ndarray, list, tuple)) and np.allclose(action, NO_ACTION): + return True + return False + + +def no_action_like(action): + """Generate an action that is the same type and shape as the no-action placeholder.""" + if isinstance(action, (int, np.integer)): + return NO_ACTION + return action * 0 + NO_ACTION # Aggressively try to convert while retaining type + + class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): def __init__( self, @@ -122,6 +140,7 @@ def __init__( satellites = [satellites] self.satellites = deepcopy(satellites) + self.dtype = dtype while True: for satellite in self.satellites: if [sat.name for sat in self.satellites].count(satellite.name) > 1: @@ -135,8 
def __init__(
    self,
    *args,
    meta_agent_groupings: Optional[dict[AgentID, list[str]]] = None,
    only_retask_idle_meta_agent_members: bool = False,
    **kwargs,
) -> None:
    """Implements the PettingZoo parallel API for the :class:`GeneralSatelliteTasking` environment.

    Args:
        *args: Passed to :class:`GeneralSatelliteTasking`.
        meta_agent_groupings: A dictionary mapping agent names to lists of satellite
            names. Satellites that are not listed in any grouping become
            single-satellite agents named after the satellite. The dictionary
            passed in is not modified.
        only_retask_idle_meta_agent_members: If True, only satellites in a meta agent
            that require retasking will receive actions. Other actions in the meta
            agent output will be ignored. This may also be useful to control in the
            training pipeline.
        **kwargs: Passed to :class:`GeneralSatelliteTasking`.

    Raises:
        ValueError: If a satellite is listed more than once across the groupings,
            if a meta agent name collides with the name of an ungrouped satellite,
            or if a grouping names a satellite that does not exist.
    """
    super().__init__(*args, **kwargs)
    self.only_retask_idle_meta_agent_members = only_retask_idle_meta_agent_members

    # Copy so that neither the caller's dictionary nor its member lists are
    # mutated when the implicit single-satellite agents are added below.
    groupings = {
        agent: list(members)
        for agent, members in (meta_agent_groupings or {}).items()
    }

    # A satellite can take only one action per step, so it may belong to at
    # most one agent.
    grouped_names = set()
    for agent, members in groupings.items():
        for member in members:
            if member in grouped_names:
                raise ValueError(
                    f"Satellite '{member}' is listed more than once in "
                    f"meta_agent_groupings (seen again in '{agent}')."
                )
            grouped_names.add(member)

    for sat in self.satellites:
        if sat.name in grouped_names:
            continue
        if sat.name in groupings:
            # Adding the implicit agent would silently replace the user's group.
            raise ValueError(
                f"Meta agent name '{sat.name}' collides with the name of a "
                f"satellite that is not part of any meta agent."
            )
        groupings[sat.name] = [sat.name]

    # Resolve names to satellite objects; get_satellite raises ValueError for
    # unknown names.
    self.meta_agent_groupings: dict[AgentID, list[Satellite]] = {
        name: [self.get_satellite(member) for member in members]
        for name, members in groupings.items()
    }
@property
def observation_spaces(self) -> dict[AgentID, spaces.Box]:
    """Return the observation space for each agent.

    Single-satellite agents expose that satellite's observation space unchanged.
    Meta agents expose one ``Box`` whose ``low``/``high`` bounds are the members'
    bounds concatenated in grouping order.

    Returns:
        Mapping from agent name to the agent's observation space.
    """
    # Accessed only for its side effects; the returned value is discarded.
    # NOTE(review): presumably this makes the parent environment prepare the
    # satellites' observation spaces before they are read below -- confirm.
    super().observation_space
    self._validate_meta_agent_groupings()

    obs_spaces = {}
    for agent, satellites in self.meta_agent_groupings.items():
        if len(satellites) == 1:
            obs_spaces[agent] = satellites[0].observation_space
            continue

        member_spaces = [sat.observation_space for sat in satellites]
        # Compare against None instead of relying on truthiness: unstructured
        # ``np.dtype`` instances are falsy, which would silently discard a
        # configured dtype. This also matches the check used in ``__init__``.
        dtype = self.dtype if self.dtype is not None else member_spaces[0].dtype
        obs_spaces[agent] = spaces.Box(
            low=np.concatenate([space.low for space in member_spaces]),
            high=np.concatenate([space.high for space in member_spaces]),
            dtype=dtype,
        )
    return obs_spaces
def _get_obs(self) -> "dict[AgentID, SatObs]":
    """Format the observation per the PettingZoo Parallel API.

    Agents listed in ``previously_dead`` are omitted. When
    ``generate_obs_retasking_only`` is set and no member of an agent requires
    retasking, every member contributes an all-zero observation instead of a
    freshly computed one. A meta agent's observation is the concatenation of
    its members' observations in grouping order.

    Returns:
        Mapping from agent name to that agent's observation.
    """
    obs = {}
    for agent, satellites in self.meta_agent_groupings.items():
        # Don't generate observations for agents that are dead
        if agent in self.previously_dead:
            continue

        if self.generate_obs_retasking_only and not self._requires_retasking(agent):
            # Build the zero placeholder directly instead of ``sample() * 0``,
            # which draws from (and advances) the space's random generator
            # only to throw the values away.
            agent_obs = [
                np.zeros(
                    satellite.observation_space.shape,
                    dtype=satellite.observation_space.dtype,
                )
                for satellite in satellites
            ]
        else:
            agent_obs = [satellite.get_obs() for satellite in satellites]

        # Single-satellite agents keep the satellite's native observation.
        obs[agent] = agent_obs[0] if len(agent_obs) == 1 else np.concatenate(agent_obs)

    return obs
def _decompose_meta_action(
    self, agent: "AgentID", action: "SatAct"
) -> "dict[Satellite, SatAct]":
    """Decompose a meta agent action into satellite actions.

    Each member consumes ``action_space.shape[0]`` consecutive entries of the
    action (one entry when the shape is empty). Members that consume a single
    entry receive it as a scalar; others receive a slice.

    Args:
        agent: Name of the agent whose action is being decomposed.
        action: The agent's action. A list, tuple or array is split between the
            members; any other value (and a 0-d array) is handed unchanged to
            every member.

    Returns:
        Mapping from each member satellite to its action. When
        ``only_retask_idle_meta_agent_members`` is set, members of a split action
        that do not require retasking are mapped to ``None``.

    Raises:
        ValueError: If a list/tuple/array action does not have exactly as many
            entries as the members consume in total.
    """
    members = self.meta_agent_groupings[agent]

    # 0-d arrays cannot be indexed positionally; unwrap them to a scalar so
    # they follow the same broadcast path as plain scalars.
    if isinstance(action, np.ndarray) and action.ndim == 0:
        action = action[()]

    if not isinstance(action, (list, tuple, np.ndarray)):
        return {satellite: action for satellite in members}

    widths = []
    for satellite in members:
        shape = satellite.action_space.shape
        widths.append(shape[0] if len(shape) > 0 else 1)

    expected_len = sum(widths)
    if len(action) != expected_len:
        # Fail loudly instead of handing out truncated or misaligned slices.
        raise ValueError(
            f"Action for agent '{agent}' has {len(action)} entries, "
            f"but its members expect {expected_len}."
        )

    sat_to_action_map = {}
    start = 0
    for satellite, width in zip(members, widths):
        if (
            self.only_retask_idle_meta_agent_members
            and not satellite.requires_retasking
        ):
            sat_to_action_map[satellite] = None
        elif width == 1:
            sat_to_action_map[satellite] = action[start]
        else:
            sat_to_action_map[satellite] = action[start : start + width]
        start += width

    return sat_to_action_map
def step( previous_alive = self.agents + sat_to_action_map = {} + for agent, action in actions.items(): + if len(self.meta_agent_groupings[agent]) > 1: + logger.info(f"Decomposing action for meta agent {agent}") + sat_to_action_map.update(self._decompose_meta_action(agent, action)) + action_vector = [] - for agent in self.possible_agents: - if agent in actions.keys(): - action_vector.append(actions[agent]) + for satellite in self.satellites: + if satellite in sat_to_action_map: + action_vector.append(sat_to_action_map[satellite]) else: action_vector.append(None) self._step(action_vector) self.newly_dead = list(set(previous_alive) - set(self.agents)) - for satellite in self.newly_dead: - for attr in [ - "_timed_terminal_event_name", - "_image_event_name", - ]: - event_name = getattr(satellite, attr, None) - if event_name is not None: - self.simulator.delete_event(event_name) + for agent in self.newly_dead: + for satellite in self.meta_agent_groupings[agent]: + for attr in [ + "_timed_terminal_event_name", + "_image_event_name", + ]: + event_name = getattr(satellite, attr, None) + if event_name is not None: + self.simulator.delete_event(event_name) observation = self._get_obs() reward = self._get_reward() diff --git a/src/bsk_rl/utils/rllib/__init__.py b/src/bsk_rl/utils/rllib/__init__.py index 1ebb5dc4..2e440dec 100644 --- a/src/bsk_rl/utils/rllib/__init__.py +++ b/src/bsk_rl/utils/rllib/__init__.py @@ -12,66 +12,61 @@ that are arguments to :class:`EpisodeDataWrapper` can be set in the ``env_config`` dictionary. 
""" -import json from pathlib import Path +from typing import Callable import numpy as np import torch -from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog -from ray.rllib.core.models.base import ACTOR, ENCODER_OUT +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import unsquash_action from ray.tune.registry import register_env -from scipy.special import softmax from bsk_rl import ConstellationTasking, GeneralSatelliteTasking, SatelliteTasking from bsk_rl.utils.rllib.callbacks import EpisodeDataParallelWrapper, EpisodeDataWrapper -def load_torch_mlp_policy(policy_path: str, env: GeneralSatelliteTasking): +def load_torch_mlp_policy( + policy_path: Path, policy_name: str = "default_agent" +) -> Callable: """Load a PyTorch policy from a saved model. Args: - policy_path: The path to the saved model. - env: The environment to load the policy for. + policy_path: Path to the directory containing the policy checkpoint. + policy_name: Name of the policy to load from the checkpoint. 
""" - policy_path = Path(policy_path) - state_dict = torch.load(policy_path / "module_state_dir" / "module_state.pt") - with open(policy_path / "rl_module_metadata.json") as f: - module_config = json.load(f)["module_spec_dict"]["module_config"] - model_config_dict = module_config["model_config_dict"] - - cat = PPOCatalog( - env.satellites[0].observation_space, # TODO do this by agent ID - env.satellites[0].action_space, - model_config_dict, + rl_module = RLModule.from_checkpoint( + Path(policy_path) / "learner_group" / "learner" / "rl_module" / policy_name ) - encoder = cat.build_actor_critic_encoder("torch") - pi_head = cat.build_pi_head("torch") - - encoder_state_dict = { - ".".join(k.split(".")[1:]): torch.from_numpy(v) - for k, v in state_dict.items() - if k.split(".")[0] == "encoder" - } - encoder.load_state_dict(encoder_state_dict) - pi_state_dict = { - ".".join(k.split(".")[1:]): torch.from_numpy(v) - for k, v in state_dict.items() - if k.split(".")[0] == "pi" - } - pi_head.load_state_dict(pi_state_dict) - - def policy(obs, deterministic=True): - action_logits = pi_head( - encoder(dict(obs=torch.from_numpy(obs[None, :])))[ENCODER_OUT][ACTOR] + cat = rl_module.config.get_catalog() + action_dist_cls = cat.get_action_dist_cls(framework="torch") + + def policy(obs: list[float], deterministic: bool = True) -> np.ndarray: + """Policy function that takes observations and returns actions. + + Args: + obs: Observation vector. + deterministic: If True, use loc for action selection; otherwise, sample from the action distribution. 
class ContinuePreviousActionAppended(ConnectorV2):
    def __init__(self, *args, **kwargs):
        """Replace environment-bound actions with ``NO_ACTION`` when the module's action is the placeholder.

        This additional connector must be appended after ``NormalizeAndClipActions``
        in continuous action spaces. For example:

        .. code-block:: python

            # Append extra connector to the pipeline
            old_connector_builder = config.build_module_to_env_connector

            def new_connector_builder(env):
                pipeline = old_connector_builder(env)
                pipeline.insert_after(
                    "NormalizeAndClipActions", ContinuePreviousActionAppended()
                )
                return pipeline

            config.build_module_to_env_connector = new_connector_builder

        """
        super().__init__(*args, **kwargs)

    def __call__(
        self,
        *,
        data: Optional[Any],
        episodes: List[EpisodeType],
        **_,
    ) -> Any:
        """Override actions with ``NO_ACTION`` on connector pass.

        For every entry of the ``ACTIONS_FOR_ENV`` column, the first element is
        replaced by a same-typed ``NO_ACTION`` value whenever the first element
        of the matching ``ACTIONS`` entry is a no-action placeholder.

        :meta private:
        """
        env_actions = data[Columns.ACTIONS_FOR_ENV]
        for id_tuple in env_actions.keys():
            module_action = data[Columns.ACTIONS][id_tuple][0]
            if not is_no_action(module_action):
                continue
            current = env_actions[id_tuple][0]
            env_actions[id_tuple][0] = no_action_like(current)

        return data
if action == NO_ACTION: + if is_no_action(action): new_lookback -= 1 else: break diff --git a/tests/integration/sim/test_int_dynamics.py b/tests/integration/sim/test_int_dynamics.py index 0f7aa0a3..6a28e9ef 100644 --- a/tests/integration/sim/test_int_dynamics.py +++ b/tests/integration/sim/test_int_dynamics.py @@ -149,7 +149,7 @@ class CollisionSat(sats.Satellite): env.reset() - env.step(dict(Collision1=0, Collision2=0, Collision3=0)) + env.step(dict(Collision1=0, Collision2=0, NoCollision3=0)) sat1 = env.unwrapped.satellites[0] sat2 = env.unwrapped.satellites[1] diff --git a/tests/integration/test_int_full_environments.py b/tests/integration/test_int_full_environments.py index 1cf452cc..44acbe91 100644 --- a/tests/integration/test_int_full_environments.py +++ b/tests/integration/test_int_full_environments.py @@ -71,6 +71,44 @@ class FullFeaturedSatellite(sats.ImagingSatellite): time_limit=5700.0, ) +parallel_meta_env = ConstellationTasking( + satellites=[ + FullFeaturedSatellite( + "Sentinel-2A", + sat_args=FullFeaturedSatellite.default_sat_args( + oe=random_orbit, + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ), + FullFeaturedSatellite( + "Sentinel-2B", + sat_args=FullFeaturedSatellite.default_sat_args( + oe=random_orbit, + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ), + FullFeaturedSatellite( + "Sentinel-2C", + sat_args=FullFeaturedSatellite.default_sat_args( + oe=random_orbit, + imageAttErrorRequirement=0.01, + imageRateErrorRequirement=0.01, + ), + ), + ], + scenario=scene.UniformTargets(n_targets=1000), + rewarder=data.UniqueImageReward(), + sim_rate=0.5, + max_step_duration=1e9, + time_limit=5700.0, + meta_agent_groupings={ + "DoubleSat": ["Sentinel-2A", "Sentinel-2C"], + "SingleSat": ["Sentinel-2B"], + }, +) + @pytest.mark.parametrize("env", [multi_env]) def test_reproducibility(env): @@ -100,11 +138,12 @@ def test_reproducibility(env): @pytest.mark.repeat(5) -def test_parallel_api(): 
+@pytest.mark.parametrize("env", [parallel_env, parallel_meta_env]) +def test_parallel_api(env): with pytest.warns(UserWarning): # expect an erroneous warning about the info dict due to our additional info try: - parallel_api_test(parallel_env) + parallel_api_test(env) except AssertionError as e: if "agent cannot be revived once dead" in str(e): warn(f"'{e}' is a known issue (#59)") diff --git a/tests/unittest/test_gym_env.py b/tests/unittest/test_gym_env.py index b0db7cc6..7fbc29ae 100644 --- a/tests/unittest/test_gym_env.py +++ b/tests/unittest/test_gym_env.py @@ -503,7 +503,8 @@ def test_action_spaces(self): def test_obs_spaces(self): env = ConstellationTasking( satellites=[ - MagicMock(observation_space=spaces.Discrete(i + 1)) for i in range(3) + MagicMock(observation_space=spaces.Box(low=0, high=i + 1, shape=(1,))) + for i in range(3) ], world_type=MagicMock(), scenario=MagicMock(), @@ -512,9 +513,9 @@ def test_obs_spaces(self): env.unwrapped.simulator = MagicMock() env.reset = MagicMock() assert env.observation_spaces == { - env.unwrapped.satellites[0].name: spaces.Discrete(1), - env.unwrapped.satellites[1].name: spaces.Discrete(2), - env.unwrapped.satellites[2].name: spaces.Discrete(3), + env.unwrapped.satellites[0].name: spaces.Box(low=0, high=1, shape=(1,)), + env.unwrapped.satellites[1].name: spaces.Box(low=0, high=2, shape=(1,)), + env.unwrapped.satellites[2].name: spaces.Box(low=0, high=3, shape=(1,)), } @patch(