diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..8f61a8e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# SCM syntax highlighting +pixi.lock linguist-language=YAML linguist-generated=true diff --git a/.gitignore b/.gitignore index c4b8a47..4620d00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +pixi.lock # Created by https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode # Edit at https://www.gitignore.io/?templates=linux,python,windows,pycharm+all,visualstudiocode @@ -245,4 +246,7 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk -# End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode \ No newline at end of file +# End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode +# pixi environments +.pixi +*.egg-info diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7ecf543 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[project] +name = "rware" +version = "2.0.0" +description = "Multi-Robot Warehouse environment for reinforcement learning" +readme = { content-type = "text/markdown", file = "README.md" } +maintainers = [{ name = "Filippos Christianos" }] +classifiers = [ + "Intended Audience :: Developers", + "Programming Language :: Python :: 3.7", +] +requires-python = ">=3.7" +urls = { github = "https://github.com/semitable/robotic-warehouse" } +dependencies = ["numpy", "gymnasium", "pyglet<2", "networkx"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project.optional-dependencies] +test = ["pytest"] +pettingzoo = ["pettingzoo"] + +[tool.setuptools.packages.find] +exclude = ["contrib", "docs", "tests"] + +# pixi +[tool.pixi.workspace] +channels = ["conda-forge"] +platforms = ["linux-64"] +preview = ["pixi-build"] + +[tool.pixi.environments] +default = { solve-group = "default" } +test = { features = ["test", "pettingzoo"], solve-group = "default" } + 
from typing import Dict, Tuple, List, Optional
import warnings

import gymnasium as gym
import numpy as np
from pettingzoo import ParallelEnv

from .warehouse import Warehouse

# IDs are str(integers), which represent the agent.id (agent idx+1) in env.agents.
# Set to str for compatibility with TorchRL.
AgentID = str
# TODO: Refactor Action object to include the message bits.
ActionType = object
ObsType = np.ndarray


def to_agentid_dict(data: List) -> Dict[AgentID, object]:
    """Map a per-agent sequence to a dict keyed by 1-based string agent IDs.

    The element at index ``i`` is stored under the key ``str(i + 1)``.
    """
    return {str(i + 1): x for i, x in enumerate(data)}


class PettingZooWrapper(ParallelEnv):
    """Wraps a Warehouse Env object to be compatible with the PettingZoo ParallelEnv API.

    Agent IDs are the strings ``"1"``, ``"2"``, ... (index in the wrapped
    environment plus one). ``agents``, ``possible_agents`` and the space
    dictionaries are empty until :meth:`reset` has been called.
    """

    def __init__(self, env: Warehouse):
        """Store the wrapped environment; agent bookkeeping is filled by ``reset``."""
        super().__init__()
        self._env = env
        # Separate containers on purpose: chained assignment would bind both
        # attributes to the same list/dict, so mutating one would change the other.
        self.agents: List[AgentID] = []
        self.possible_agents: List[AgentID] = []
        self.observation_spaces: Dict[AgentID, gym.spaces.Space] = {}
        self.action_spaces: Dict[AgentID, gym.spaces.Space] = {}

    def _agent_ids(self) -> List[AgentID]:
        """Return the IDs "1".."n_agents" in index order."""
        return [str(i + 1) for i in range(self._env.n_agents)]

    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Reset the wrapped environment and rebuild the agent/space bookkeeping.

        Returns:
            A ``(observations, infos)`` pair of dicts keyed by agent ID. The
            infos are always empty dicts; whatever info the wrapped
            environment returns from its reset is discarded.
        """
        # Gymnasium declares seed/options as keyword-only on Env.reset, so pass
        # them by name rather than by position.
        obs, _ = self._env.reset(seed=seed, options=options)
        agent_ids = self._agent_ids()

        # Reset agents and spaces
        self.agents = [str(agent.id) for agent in self._env.agents]
        # Copy so that in-place edits of `agents` never alter `possible_agents`.
        self.possible_agents = list(self.agents)
        self.observation_spaces = {
            agent_id: self.observation_space(agent_id) for agent_id in agent_ids
        }
        self.action_spaces = {
            agent_id: self.action_space(agent_id) for agent_id in agent_ids
        }

        infos = {agent_id: {} for agent_id in agent_ids}
        return to_agentid_dict(obs), infos

    # typing.Dict instead of the builtin `dict[...]`: annotations are evaluated
    # when the function is defined, and subscripting `dict` fails before
    # Python 3.9 while the project declares support for >=3.7.
    def step(self, actions: Dict[AgentID, ActionType]) -> Tuple[
        Dict[AgentID, ObsType],
        Dict[AgentID, float],
        Dict[AgentID, bool],
        Dict[AgentID, bool],
        Dict[AgentID, dict],
    ]:
        """Advance the wrapped environment by one step.

        Args:
            actions: One action per agent, keyed by agent ID. Keys must be
                integer strings; actions are handed to the wrapped environment
                ordered by ascending integer ID.

        Returns:
            ``(observations, rewards, terminated, truncated, infos)``, each a
            dict keyed by agent ID. The environment-level terminated/truncated
            flags are replicated for every agent and infos are empty dicts.

        Raises:
            AssertionError: If the number of actions differs from the number
                of agents in the wrapped environment.
        """
        # Unwrap to list of actions, ordered by agent index
        ordered = sorted(actions.items(), key=lambda item: int(item[0]))
        actions_unwrapped = [action for _, action in ordered]
        n_agents = self._env.n_agents
        if len(actions_unwrapped) != n_agents:
            # Raised explicitly (instead of `assert`) so the check also runs
            # under `python -O`; the exception type is kept for callers.
            raise AssertionError(
                f"Incorrect number of actions provided. Expected {n_agents} but got {len(actions_unwrapped)}"
            )

        # Step inner environment
        obs, rewards, terminated, truncated, info = self._env.step(actions_unwrapped)

        # Transform to PettingZoo output
        if terminated or truncated:
            self.agents = []  # PettingZoo requires agents to be removed
        if len(info) != 0:
            warnings.warn(
                "Error: expected info dict to be empty. PettingZooWrapper is likely out of date."
            )

        return (
            to_agentid_dict(obs),
            to_agentid_dict(rewards),
            to_agentid_dict([terminated] * n_agents),
            to_agentid_dict([truncated] * n_agents),
            {str(i + 1): {} for i in range(n_agents)},
        )

    def render(self):
        """Delegate rendering to the wrapped environment."""
        return self._env.render()

    def close(self) -> None:
        """Close the wrapped environment."""
        self._env.close()

    def state(self):
        """Return the wrapped environment's ``get_global_image()`` result."""
        return self._env.get_global_image()

    def observation_space(self, agent: AgentID) -> gym.spaces.Space:
        """Return the observation space of ``agent`` (a 1-based integer string)."""
        space = self._env.observation_space
        assert isinstance(space, gym.spaces.Tuple)
        return space[int(agent) - 1]

    def action_space(self, agent: AgentID) -> gym.spaces.Space:
        """Return the action space of ``agent`` (a 1-based integer string)."""
        space = self._env.action_space
        assert isinstance(space, gym.spaces.Tuple)
        return space[int(agent) - 1]
reward_type: RewardType, + observation_type: ObservationType, +): + if not _has_pettingzoo: + pytest.skip("PettingZoo not available.") + return + + env = Warehouse( + shelf_columns=1, + column_height=5, + shelf_rows=3, + n_agents=n_agents, + msg_bits=msg_bits, + sensor_range=sensor_range, + request_queue_size=5, + max_inactivity_steps=max_inactivity_steps, + max_steps=None, + reward_type=reward_type, + observation_type=observation_type, + ) + env = PettingZooWrapper(env) + parallel_api_test(env)