diff --git a/src/world_model/__init__.py b/src/world_model/__init__.py index 785ff53..2b9f46f 100644 --- a/src/world_model/__init__.py +++ b/src/world_model/__init__.py @@ -1,3 +1,3 @@ -from .world_model import WorldModel -from .mdp_world_model import MDPWorldModel -from .simple_gridworld import SimpleGridworld +# from .world_model import WorldModel +# from .mdp_world_model import MDPWorldModel +# from .simple_gridworld import SimpleGridworld diff --git a/src/world_model/mdp_world_model.py b/src/world_model/mdp_world_model.py index aff687d..a292719 100644 --- a/src/world_model/mdp_world_model.py +++ b/src/world_model/mdp_world_model.py @@ -1,6 +1,6 @@ from typing import Any, Generic, TypeVar, TypeVarTuple -from . import WorldModel +from .world_model import WorldModel ObsType = TypeVar("ObsType") Action = TypeVar("Action") diff --git a/src/world_model/objects.py b/src/world_model/objects.py new file mode 100644 index 0000000..b368137 --- /dev/null +++ b/src/world_model/objects.py @@ -0,0 +1,184 @@ +from enum import IntEnum +from typing import ( + ClassVar, + Generic, + Self, + TypeVar, +) + +from world_model.types import Direction, Location + +ObjState = TypeVar("ObjState", bound=IntEnum) + + +class ObjectType(Generic[ObjState]): + symbol: ClassVar[str] + immobile: ClassVar[bool] + + by_symbol: ClassVar[dict[str, type["ObjectType"]]] = {} + by_id: ClassVar[dict[int, "ObjectType"]] = {} + + def __init_subclass__(cls, symbol: str) -> None: + cls.symbol = symbol + cls.by_symbol[symbol] = cls + + @classmethod + def next_id(cls): + return len(cls.by_id) + 1 + + id: int + state: ObjState | None + """ None is interpreted as "initial state" """ + location: Location + + def __init__(self, location: Location, *, copying: Self | None = None) -> None: + if copying is not None: + self.id = copying.id + self.state = copying.state + self.location = location + else: + self.id = self.next_id() + self.by_id[self.id] = self + self.state = None + self.location = location + + def __str__(self) -> str: + return f"<{self.symbol}-{self.id}>" + + def __hash__(self) -> int: + return hash((self.id, self.state, self.location)) + + def flat_state(self) -> tuple: + return (*self.location, self.state.value if self.state else 0) + + def moved(self: Self, direction: Direction) -> "Self": + return type(self)(self.location + direction.value, copying=self) + + +class Wall(ObjectType, symbol="#"): + pass + + +class EmptySpace(ObjectType, symbol=" "): + def __init__(self, location: Location) -> None: + self.id = 0 + self.state = None + self.location = location + + def __str__(self) -> str: + return "< >" + + +class UnevenGround(ObjectType, symbol="~"): + # Agents/boxes might fall off to any side except to where agent came from, + # with equal probability + pass + + +class Pinnacle(ObjectType, symbol="^"): + # Pinnacle (Climbing on it will result in falling off to any side except + # to where agent came from, with equal probability) + pass + + +class Box(ObjectType, symbol="X"): + # can be pushed around but not pulled, can slide and fall off. + # Heavy, so agent can only push one at a time + pass + + +class Agent(ObjectType, symbol="A"): + collected_delta: float + + def __init__(self, location: Location, delta: float = 0.0, *, copying=None) -> None: + super().__init__(location, copying=copying) + self.collected_delta = delta + + def collect(self: Self, additional_delta: float) -> Self: + return type(self)( + self.location, self.collected_delta + additional_delta, copying=self + ) + + def moved(self: Self, direction: Direction) -> "Self": + res = super().moved(direction) + res.collected_delta = self.collected_delta + return res + + def flat_state(self) -> tuple: + return (*super().flat_state(), self.collected_delta) + + +class SlipperyGround(ObjectType, symbol="-"): + # Slippery ground (Agents and boxes might slide along in a straight line; + # after sliding by one tile, + # a coin is tossed to decide whether we slide another tile, and this is repeated + # until the coin shows heads or we hit an obstacle. + # All this motion takes place within a single time step.) + pass + + +class Goal(ObjectType, symbol="G"): + # Goal or exit door (acting while on it ends the episode) + pass + + +class EmptyToWallState(IntEnum): + WALL = 1 + + +class EmptyToWall(ObjectType[EmptyToWallState], symbol=","): + # Empty tile that turns into a wall after leaving it (so that one cannot go back) + def to_wall(self) -> Self: + n = type(self)(self.location, copying=self) + n.state = EmptyToWallState.WALL + return n + + +class DeltaState(IntEnum): + COLLECTED = 1 + + +class Delta(ObjectType[DeltaState], symbol="Δ"): + # positive or negative, can be collected once, does not end the episode + delta: float + + def __init__(self, location: Location, delta: float, *, copying=None) -> None: + super().__init__(location, copying=copying) + self.delta = delta + + def __str__(self) -> str: + return f"<Δ {int(self.delta)}>" + + def collect(self: Self) -> Self: + res = type(self)(self.location, self.delta, copying=self) + res.state = DeltaState.COLLECTED + return res + + def flat_state(self) -> tuple: + return (*super().flat_state(), self.delta) + + +class GlassState(IntEnum): + BROKEN = 1 + + +class Glass(ObjectType[GlassState], symbol="|"): + # A pane of glass, will break if anything moves into it from left or right, + # and can be pushed up or down + def break_glass(self): + res = self.moved(Direction.STAY) + res.state = GlassState.BROKEN + return res + + +class FragileState(IntEnum): + DESTROYED = 1 + + +class Fragile(ObjectType[FragileState], symbol="F"): + # A fragile object or organism (might move around on its own, + # is destroyed when stepped upon by the agent) + def destroy(self): + res = self.moved(Direction.STAY) + res.state = FragileState.DESTROYED + return res diff --git a/src/world_model/simple_gridworld.py b/src/world_model/simple_gridworld.py index 3022e22..7904adf 100644 --- a/src/world_model/simple_gridworld.py +++ b/src/world_model/simple_gridworld.py @@ -1,812 +1,772 @@ -from functools import cache -import os -from sre_parse import State -from typing import Generic, NamedTuple, Self, TypeVar, overload +from itertools import chain +from enum import IntEnum +from typing import ( + Generic, + Iterable, + Never, + NamedTuple, + Self, + Type, + TypeAlias, + TypeVar, + overload, +) +from logging import getLogger -from satisfia.util import distribution -from . import MDPWorldModel +import numpy as np +# import pygame +# from gymnasium import spaces -# based in large part on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ +from satisfia.util.distribution import categorical +from world_model.mdp_world_model import MDPWorldModel + +from .types import * +from .objects import * + + +log = getLogger(__name__) + +T = TypeVar("T") + + +def tup_replace_at(tup: tuple[T, ...], ix: int, val: T) -> tuple[T, ...]: + return tup[:ix] + (val,) + tup[ix + 1 :] -import numpy as np -import pygame -from gymnasium import spaces +class Cell(tuple[ObjectType, ...]): + def get(self, ot: Type[T] | tuple[Type[T], ...]) -> T | None: + for t in self: + if isinstance(t, ot): + return t + return None -unenterable_immobile_cell_types = ['#'] # can't run into walls -unenterable_mobile_object_types = ['A'] # can't run into agents -unsteady_cell_types = ['~', '^', '-'] -what_can_move_into_agent = ['A'] + def __str__(self) -> str: + n = len(self) + if n == 0: + return "< >" + elif n == 1: + return str(self[0]) + return f"<{''.join(str(o.id) for o in self)}>" -immobile_object_types = [',','Δ'] -mobile_constant_object_types = ['X','|','F'] -mobile_variable_object_types = [] + def __sub__(self, obj: ObjectType) -> "Cell": + return Cell(tuple(o for o in self if o is not obj)) -render_as_char_types = unsteady_cell_types + immobile_object_types + ['G'] + def __add__(self, obj: ObjectType) -> "Cell": # type: ignore[override] + return Cell(tuple((*self, obj))) -max_n_object_states = 2 +ObsType = TypeVar("ObsType") +Exact = bool +Probability = float +Action = Direction +StateChanges = dict[ObjectType, ObjectType] +Grid = tuple[tuple[Cell, ...], ...] -def set_entry(iterable, index, value): - if type(iterable) is tuple: - l = list(iterable) - l[index] = value - return tuple(l) - else: - iterable[index] = value - return iterable -def set_loc(locs, index, loc): - return set_entry(set_entry(locs, 2*index+1, loc[1]), 2*index, loc[0]) +def _print_changes(changes: StateChanges): + for bef, aft in changes.items(): + log.debug(f"{bef} ({bef.location}) -> {aft} ({aft.location})") -def get_loc(locs, index): - return (locs[2*index], locs[2*index+1]) -def state_embedding_for_distance(state): - """return an embedding of state where all entries -2 are replaced by -10000""" - return tuple(-10000 if x == -2 else x for x in state) +# based in large part on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ -class Location(NamedTuple): - x: int - y: int -class SimpleGWState(NamedTuple): +class SimpleGridworld(Generic[ObsType], MDPWorldModel[ObsType, Action, None]): + size: Location + grid: Grid + agents: list[Agent] t: int - locp: Location - locc: Location - immobiles_s: tuple[Location] - mobiles_s: tuple[Location] + max_episode_length: int + uneven_ground_prob: float + move_probability_F: float -ObsType = TypeVar("ObsType") -State= TypeVar("State") - -Action = int -class SimpleGridworld(Generic[ObsType, State], MDPWorldModel[ObsType, Action, State]): - """A world model of a simple MDP-type Gridworld environment. - - A *state* here is a tuple of integers encoding the following sequence of items, - each one encoded as one or two integers: - - - position 0: the current time step - - positions 1+2: the previous location x,y of the agent - - positions 3+4: the current location x,y of the agent - - positions 5...4+k: for each of k immobile objects with variable state, its state - - positions 5+k..4+k+2*l: for each of l mobile objects without a variable state, its location x,y - - positions 5+k+2*l...4+k+2*l+2*m: for each of m mobile objects with a variable state, its location x,y - - positions 5+k+2*l+2*m...4+k+2*l+3*m: for each of m mobile objects with a variable state, its state - - A *coordinate* in a location is encoded in as an integer, where: - - -2 means the object is not present - - -1 means the object is in the agent's inventory - - >= 0 is a coordinate in the grid, counted from top-left to bottom-right - - Objects are *ordered* by their initial location in the ascii-art grid representation in row-major order. - - The *grid* and the agent's and all objects' *initial locations* are given as a 2d array of characters, - each representing one cell of the grid, with the following character meanings: - - - already implemented: - - '#' (hash): wall - - ' ' (blank): empty space - - '~': Uneven ground (Agents/boxes might fall off to any side except to where agent came from, - with equal probability) - - '^': Pinnacle (Climbing on it will result in falling off to any side except to where agent came from, - with equal probability) - - 'A': agent's initial location - - 'X': Box (can be pushed around but not pulled, can slide and fall off. Heavy, so agent can only push one at a time) - - - not yet implemented, but are planned to be implemented in the future: - - ',': Empty tile that turns into a wall after leaving it (so that one cannot go back) - - '-': Slippery ground (Agents and boxes might slide along in a straight line; after sliding by one tile, - a coin is tossed to decide whether we slide another tile, and this is repeated - until the coin shows heads or we hit an obstacle. All this motion takes place within a single time step.) - - '%': Death trap (Episode ends when agent steps on it) - - '|': A pane of glass, will break if anything moves into it from left or right, and can be pushed up or down - - 'B': Button (can be stepped on) - - 'C': Collaborator (might move around) - - 'D': Door (can only be entered after having collected a key) - - 'E': Enemy (might move around on its own) - - 'F': A fragile object or organism (might move around on its own, is destroyed when stepped upon by the agent) - - 'f': A stationary even more fragile object that is destroyed when *anything* moves onto it - - 'Δ': Delta (positive or negative, can be collected once, does not end the episode) - - 'G': Goal or exit door (acting while on it ends the episode) - - 'I': (Potential) interruption (agent might get stuck in forever) - - 'K': Key (must be collected to be able to pass a door) - - 'O': Ball (when pushed, will move straight until meeting an obstacle) - - 'S': Supervisor (might move around on their own) - - 'T': Teleporter (sends the agent to some destination t) - - 't': Destination of a teleporter (stepping on it does nothing) - (TODO: compare with pycolab asciiart conventions, try to harmonize them, and add things that are missing) - - *Deltas* (rewards) can accrue from the following events: - - Time passing. This is specified by time_delta or a list time_deltas of length max_episode_length. - - The agent stepping onto a certain object. This is specified by a list object_deltas - ordered by the objects' initial locations in the ascii-art grid representation in row-major order. - - The agent currently being in a certain location. This is specified by - - another 2d array of characters, delta_grid, of the same size as the grid, - containing cell_codes with the following character meanings: - - ' ' (space): no Delta - - '': Delta as specified by cell_code2delta[''] - - a dictionary cell_code2delta listing the actual Delta values for each cell_code in that grid - Note that the delta accrues at each time point when the agent is in a cell, - not at the time point it steps onto it! - """ - - - ## parameters: - xygrid = None - """(2d array of characters) The grid as an array of strings, each string representing one row of the grid, - each character representing one cell of the grid""" - delta_xygrid = None - """(2d array of characters) codes for deltas (rewards) for each cell of the grid""" - cell_code2delta = None - """(dictionary) maps cell codes to deltas (rewards)""" - - n_immobile_objects = None - """The number of immobile objects with variable state.""" - n_mobile_constant_objects = None - """The number of mobile objects without a variable state.""" - n_mobile_variable_objects = None - """The number of mobile objects with a variable state.""" - max_episode_length = None - """The maximum number of steps in an episode.""" - initial_agent_location = None - """(pair of ints) The initial location of the agent starting with zero.""" - time_deltas = None - """(list of floats) The deltas (rewards) for each time step.""" - timeout_delta = None - """(float) The delta (reward) for the timeout event.""" - move_probability_F = None - """(float) The probability with which objects of type 'F' move uniformly at random.""" - - # additional attributes: - _state = None - """(singleton list of state) The current state encoded as a tuple of ints.""" - t = None - """The current time step.""" - _agent_location = None - """(pair of ints) The current location of the agent starting with zero.""" - _previous_agent_location = None - """(pair of ints) The previous location of the agent starting with zero.""" - - metadata = {"render_modes": ["human", "rgb_array"]} - - def __init__(self, render_mode = None, - grid = [['A','G']], - delta_grid = None, - cell_code2delta = {'1': 1}, - max_episode_length = 1e10, - time_deltas = [0], - timeout_delta = 0, - uneven_ground_prob = 0.25, - move_probability_F = 0, - fps = 4 - ): - - self.xygrid = xygrid = np.array(grid).T - self.delta_xygrid = delta_xygrid = np.array(delta_grid).T if delta_grid is not None else np.full(xygrid.shape, ' ') - self.cell_code2delta = cell_code2delta - self.max_episode_length = max_episode_length - self.time_deltas = np.array(time_deltas).flatten() - self.timeout_delta = timeout_delta - self.move_probability_F = move_probability_F - self.uneven_ground_prob = uneven_ground_prob - self._fps = fps - - self._window_shape = 800 * np.array(xygrid.shape) / np.max(xygrid.shape) # The size of the PyGame window in pixels - - # The initial agent location is the first occurrence of 'A' in the grid: - wh = np.where(xygrid == 'A') - self.initial_agent_location = (wh[0][0], wh[1][0]) - - self.n_immobile_objects = self.n_mobile_constant_objects = self.n_mobile_variable_objects = 0 # TODO: extract from grid - - # Construct an auxiliary grid that contains a unique index of each immobile object - # (cells of a type in immobile_object_types), or None if there is none. - # Also, get lists of objects and their types and initial locations. - self.immobile_object_types = [] - self.immobile_object_indices = np.full(xygrid.shape, None) - self.immobile_object_locations = [] - self.immobile_object_state0_deltas = [] # delta collected when meeting an immobile object that is in state 0 - self.mobile_constant_object_types = [] - self.mobile_constant_object_initial_locations = [] - self.mobile_constant_object_deltas = [] # delta collected when meeting a mobile constant object - self.mobile_variable_object_types = [] - self.mobile_variable_object_initial_locations = [] - self.mobile_variable_object_state0_deltas = [] # delta collected when meeting a mobile variable object that is in state 0 - for x in range(xygrid.shape[0]): - for y in range(xygrid.shape[1]): - if xygrid[x, y] in immobile_object_types: - self.immobile_object_types.append(xygrid[x, y]) - self.immobile_object_locations += [x, y] - self.immobile_object_indices[x, y] = self.n_immobile_objects - self.immobile_object_state0_deltas.append(cell_code2delta[delta_xygrid[x, y]] if delta_xygrid[x, y] != ' ' else 0) - self.n_immobile_objects += 1 - elif xygrid[x, y] in mobile_constant_object_types: - self.mobile_constant_object_types.append(xygrid[x, y]) - self.mobile_constant_object_initial_locations += [x, y] - self.mobile_constant_object_deltas.append(cell_code2delta[delta_xygrid[x, y]] if delta_xygrid[x, y] != ' ' else 0) - self.n_mobile_constant_objects += 1 - elif xygrid[x, y] in mobile_variable_object_types: - self.mobile_variable_object_types.append(xygrid[x, y]) - self.mobile_variable_object_initial_locations += [x, y] - self.mobile_variable_object_state0_deltas.append(cell_code2delta[delta_xygrid[x, y]] if delta_xygrid[x, y] != ' ' else 0) - self.n_mobile_variable_objects += 1 - - # The observation returned for reinforcement learning equals state, as described above. - # TODO how to specify start range of each dimension for MultiDiscrete? - nx, ny = xygrid.shape[0], xygrid.shape[1] - self.observation_space = spaces.MultiDiscrete( - [max_episode_length+1, # current time step - nx+2, ny+2] # current location - + [max_n_object_states] * self.n_immobile_objects - + [nx+2, ny+2] * self.n_mobile_constant_objects - + [nx+2, ny+2] * self.n_mobile_variable_objects - + [max_n_object_states] * self.n_mobile_variable_objects - , start = - [0, # current time step - -2, -2] # current location - + [0] * self.n_immobile_objects - + [-2, -2] * self.n_mobile_constant_objects - + [-2, -2] * self.n_mobile_variable_objects - + [0] * self.n_mobile_variable_objects - ) - - """ - return (state[0], # time step - (state[3], state[4]), # current location - (state[1], state[2]), # previous location - state[5 - : 5+self.n_immobile_objects], # immobile object states - state[5+self.n_immobile_objects - : 5+self.n_immobile_objects+2*self.n_mobile_constant_objects], # mobile constant object locations - state[5+self.n_immobile_objects+2*self.n_mobile_constant_objects - : 5+self.n_immobile_objects+2*self.n_mobile_constant_objects+2*self.n_mobile_variable_objects], # mobile variable object locations - state[5+self.n_immobile_objects+2*self.n_mobile_constant_objects+2*self.n_mobile_variable_objects - : 5+self.n_immobile_objects+2*self.n_mobile_constant_objects+3*self.n_mobile_variable_objects] # mobile variable object states - """ + TransitionDistribution: TypeAlias = dict[Self, tuple[Probability, Exact]] - # We have 4 actions, corresponding to "right", "up", "left", "down" - self.action_space = spaces.Discrete(5) + @overload + def __init__( + self, + grid: list[list[str]], + delta_grid: list[list[str]] | None = None, + *, + cell_code2delta: dict[str, float] | None = None, + max_episode_length=1e10, + uneven_ground_prob: Probability = 0.25, + move_probability_F: Probability = 0, + # render_mode=None, + # time_deltas=[0], + # timeout_delta=0, + # fps=4, + copying: None = None, + changes: None = None, + ): + ... - """ - The following dictionary maps abstract actions from `self.action_space` to - the direction we will walk in if that action is taken. - """ - self._action_to_direction: dict[Action, tuple[int, int]] = { - 0: (0,-1),# up - 1: (1,0),# right - 2: (0,1),# down - 3: (-1,0),# left - 4: (0,0),# stay in place - } + @overload + def __init__( + self, + grid: None, + delta_grid: None = None, + *, + cell_code2delta: None = None, + max_episode_length=1e10, + uneven_ground_prob: Probability = 0.25, + move_probability_F: Probability = 0, + # render_mode=None, + # time_deltas=[0], + # timeout_delta=0, + # fps=4, + copying: Self, + changes: dict[ObjectType, ObjectType] | None = None, + ): + ... + + def __init__( + self, + grid: list[list[str]] | None, + delta_grid: list[list[str]] | None | Never = None, + *, + cell_code2delta: dict[str, float] | None = None, + max_episode_length=1e10, + uneven_ground_prob: Probability = 0.25, + move_probability_F: Probability = 0, + # render_mode=None, + # time_deltas=[0], + # timeout_delta=0, + # fps=4, + copying: Self | None = None, + changes: dict[ObjectType, ObjectType] | None = None, + ): + if copying is not None: + assert grid is None and changes is not None + self.size = copying.size + self.agents = [changes.get(a, a) for a in copying.agents] + self.t = copying.t + self.max_episode_length = copying.max_episode_length + self.uneven_ground_prob = copying.uneven_ground_prob + self.move_probability_F = copying.move_probability_F + self.grid = copying.grid + # _print_changes(changes) + for old, new in changes.items(): + self._set_cell(new.location, self.at(new.location) + new) + self._set_cell(old.location, self.at(old.location) - old) + else: + assert grid is not None + assert copying is None + self.size = Location(len(grid[0]), len(grid)) + self.agents = [] + self.t = 0 + self.max_episode_length = int(max_episode_length) + self.uneven_ground_prob = uneven_ground_prob + self.move_probability_F = move_probability_F + + _grid: list[tuple[Cell, ...]] = [] + for x in range(self.size.x): + inner: list[Cell] = [] + for y in range(self.size.y): + cls = ObjectType.by_symbol[grid[y][x]] + loc = Location(x, y) + i: ObjectType + if cls is Delta: + assert delta_grid + assert cell_code2delta + i = Delta(loc, cell_code2delta[delta_grid[y][x]]) + elif cls is EmptySpace: + inner.append(Cell(())) + continue + else: + i = cls(loc) + inner.append(Cell((i,))) + if isinstance(i, Agent): + self.agents.append(i) + _grid.append(tuple(inner)) + self.grid = tuple(_grid) + + def __hash__(self) -> int: + return hash(self.grid) + + def flat_state(self) -> tuple[float | int]: + return ( + self.t, + *chain.from_iterable(o.flat_state() for o in ObjectType.by_id.values()), + ) + + def state_embedding(self): + # Discart time + return np.array(self.flat_state()[1:], dtype=np.float32) - assert render_mode is None or render_mode in self.metadata["render_modes"] - self.render_mode = render_mode - if render_mode == "human": - self._init_human_rendering() + def __getitem__(self, x: int) -> tuple[Cell, ...]: + return self.grid[x] + def at(self, loc: tuple[int, int]) -> Cell: + return self.grid[loc[0]][loc[1]] + + def _set_cell(self, loc: Location, cell: Cell) -> None: + self.grid = tup_replace_at( + self.grid, + loc.x, + tup_replace_at(self.grid[loc.x], loc.y, cell), + ) + + def __str__(self) -> str: + res = "" + for y in range(self.size.y): + for x in range(self.size.x): + res += str(self.grid[x][y]) + res += "\n" + return res + + def _can_move( + self, + direction: Direction, + who: ObjectType | None = None, + ): """ - If human-rendering is used, `self.window` will be a reference - to the window that we draw to. `self.clock` will be a clock that is used - to ensure that the environment is rendered at the correct framerate in - human-mode. They will remain `None` until human-mode is used for the - first time. + Return True if the agent or other object (designated by the who parameter) + can move from the given location to the given target_location. """ - self._window = None - self.clock = None - - def get_prolonged_version(self: Self, horizon=None) -> Self: - """Return a copy of this gridworld in which the episode length is prolonged by horizon steps.""" - # get a copy of the original grid, the delta grid, and the delta table: - xygrid = self.xygrid.copy() - delta_xygrid = self.delta_xygrid.copy() - cell_code2delta = self.cell_code2delta.copy() - # replace all 'G' states by 'Δ' states to make them non-terminal: - xygrid[xygrid == 'G'] = 'Δ' - # return a new SimpleGridworld with this data: - return type(self)(render_mode = self.render_mode, - grid = xygrid.T, - delta_grid = delta_xygrid.T, - cell_code2delta = cell_code2delta, - max_episode_length = self.max_episode_length + horizon, - time_deltas = self.time_deltas, - timeout_delta = self.timeout_delta, - uneven_ground_prob = self.uneven_ground_prob, - move_probability_F = self.move_probability_F, - fps = self._fps - ) - - def _get_target_location(self, location: Location, action: Action) -> Location: - """Return the next location of the agent if it takes the given action from the given location.""" - direction = self._action_to_direction[action] - return Location( - location[0] + direction[0], - location[1] + direction[1] - ) + if who is None: + assert len(self.agents) == 1 + who = self.agents[0] - def _can_move(self, from_loc, to_loc, state, who='A'): - """Return True if the agent or other object (designated by the who parameter) - can move from the given location to the given target_location.""" - if not (0 <= to_loc[0] < self.xygrid.shape[0] - and 0 <= to_loc[1] < self.xygrid.shape[1] - and not self.xygrid[to_loc] in unenterable_immobile_cell_types): + to_loc = who.location + direction.value + if not (0 <= to_loc.x < self.size.x and 0 <= to_loc.y < self.size.y): return False - # TODO: add other conditions for not being able to move, e.g. because of other objects - t, agent_loc, imm_states, mc_locs, mv_locs, mv_states = self._extract_state_attributes(state) - if self.xygrid[to_loc] == ',': - # can only move there if it hasn't turned into a wall yet: - if imm_states[self.immobile_object_indices[to_loc]] > 0: - return False - if to_loc == agent_loc and who not in what_can_move_into_agent: - return False - # loop through all mobile objects and see if they hinder the movement: - for i, object_type in enumerate(self.mobile_constant_object_types): - if to_loc == (mc_locs[2*i],mc_locs[2*i+1]): - if object_type in unenterable_mobile_object_types: - return False - if object_type in ['X','|']: # a box - if who != 'A' and (object_type == 'X' or - (object_type == '|' and from_loc[1]!=to_loc[1]) # attempt to push glass pane up or down - ): - return False # only the agent can push a box or glass pane! - # see if it can be pushed: - obj_target_loc = tuple(2*np.array(to_loc) - np.array(from_loc)) - if not self._can_move(to_loc, obj_target_loc, state, who=object_type): - return False - # TODO: implement destroying an 'F' by pushing a 'X' onto it - return True - - def opposite_action(self, action): - """Return the opposite action to the given action.""" - return 4 if action == 4 else (action + 2) % 4 - - def state_embedding(self, state): - res = np.array(state_embedding_for_distance(state), dtype=np.float32)[3:] # make time and previous position irrelevant - return res - @cache - def possible_actions(self, state=None): + def non_obstat(o: ObjectType): + match o: + case Wall(): + return False + case EmptyToWall(): + return o.state != EmptyToWallState.WALL + case Agent() as a: + return a == who + case Box() as b: + return isinstance(who, Agent) and self._can_move(direction, b) + case Glass() as g: + if direction in (Direction.RIGHT, Direction.LEFT): + return True + return isinstance(who, (Agent, Fragile)) and self._can_move( + direction, g + ) + case UnevenGround(), Pinnacle(), SlipperyGround() if isinstance( + o, (Box, Glass) + ): + raise NotImplementedError( + "boxes cannot slide/fall yet" + ) # TODO: let boxes slide/fall like agents! + return True + + target = self.at(to_loc) + return all(non_obstat(o) for o in target) + + def possible_actions(self, a: ObjectType, *, include_stay) -> Iterable[Action]: """Return a list of possible actions from the given state.""" - if state is None: - state = self._state - t, loc, imm_states, mc_locs, mv_locs, mv_states = self._extract_state_attributes(state) - actions = [action for action in range(5) - if self._can_move(loc, self._get_target_location(loc, action), state)] - if len(actions) == 0: - raise ValueError(f"No possible actions from state {state}") # FIXME: raise a more specific exception + actions = [ + action + for action in Direction + if self._can_move(action, a) + if include_stay or Direction.STAY is not action + ] + if not actions: + raise ValueError(f"No possible actions from state {self.flat_state()}") return actions - def default_policy(self, state): + def default_policy(self) -> categorical: """Return a default action, if any""" - return distribution.categorical([4], [1]) # staying in place + return categorical([Direction.STAY], [1.0]) - @overload - def _extract_state_attributes(self, state, gridcontents=False): - pass - @overload - def _extract_state_attributes(self, state, gridcontents=True): - pass - def _extract_state_attributes(self, state, gridcontents=False) -> tuple: - """Return the individual attributes of a state.""" - t, loc, imm_states, mc_locs, mv_locs, mv_states = ( - state[0], # time step - (state[1], state[2]), # current location - state[3 - : 3+self.n_immobile_objects], # immobile object states - state[3+self.n_immobile_objects - : 3+self.n_immobile_objects+2*self.n_mobile_constant_objects], # mobile constant object locations - state[3+self.n_immobile_objects+2*self.n_mobile_constant_objects - : 3+self.n_immobile_objects+2*self.n_mobile_constant_objects+2*self.n_mobile_variable_objects], # mobile variable object locations - state[3+self.n_immobile_objects+2*self.n_mobile_constant_objects+2*self.n_mobile_variable_objects - : 3+self.n_immobile_objects+2*self.n_mobile_constant_objects+3*self.n_mobile_variable_objects] # mobile variable object states - ) - if not gridcontents: - return t, loc, imm_states, mc_locs, mv_locs, mv_states - gc = { get_loc(mc_locs, i): (self.mobile_constant_object_types[i], i) - for i in range(self.n_mobile_constant_objects) } - gc.update( - { get_loc(mv_locs, i): (self.mobile_variable_object_types[i], i) - for i in range(self.n_mobile_variable_objects) } - ) - return t, loc, imm_states, mc_locs, mv_locs, mv_states, gc - - def _set_state(self, state): - """Set the current state to the provided one.""" - self._previous_agent_location = self._agent_location - self._state = state - self.t, loc, imm_states, mc_locs, mv_locs, mv_states = self._extract_state_attributes(state) - self._agent_location = loc - self._immobile_object_states = imm_states - self._mobile_constant_object_locations = mc_locs - self._mobile_variable_object_locations = mv_locs - self._mobile_variable_object_states = mv_states - - def _make_state(self, t = 0, loc = None, - imm_states = None, mc_locs = None, mv_locs = None, mv_states = None): - """Compile the given attributes into a state encoding that can be returned as an observation.""" - if loc is None: - loc = self.initial_agent_location - if mc_locs is None: - mc_locs = self.mobile_constant_object_initial_locations - if mv_locs is None: - mv_locs = self.mobile_variable_object_initial_locations - # default states are 0: - if imm_states is None: - imm_states = np.zeros(self.n_immobile_objects, dtype = int) - if mv_states is None: - mv_states = np.zeros(self.n_mobile_variable_objects, dtype = int) - return (t, - loc[0], loc[1], - *imm_states, - *mc_locs, - *mv_locs, - *mv_states - ) - - @cache - def is_terminal(self, state: State): + def is_terminal(self): """Return True if the given state is a terminal state.""" - t, loc, _, _, _, _ = self._extract_state_attributes(state) - is_at_goal = self.xygrid[loc] == 'G' - return is_at_goal or (t == self.max_episode_length) + self.flat_state + is_at_goal = any( + any(isinstance(o, Goal) for o in self.at(agent.location)) + for agent in self.agents + ) + return is_at_goal or (self.t == self.max_episode_length) + + def tick( + self: Self, + changes_map: dict[ObjectType, ObjectType] | None = None, + *, + increase_time=True, + ) -> Self: + res = type(self)(None, None, copying=self, changes=changes_map or {}) + if increase_time: + res.t += 1 + return res - @cache - def state_distance(self, state1, state2): + def state_distance(self, other: "SimpleGridworld") -> float: """Return the distance between the two given states, disregarding time.""" - return np.sqrt(np.sum(np.power(np.array(state_embedding_for_distance(state1))[1:] - - np.array(state_embedding_for_distance(state2))[1:], 2))) - - @cache - def transition_distribution(self, state, action, n_samples = None) -> dict: - if state is None and action is None: - successor = self._make_state() - return {successor: (1, True)} - - t, loc, imm_states, mc_locs, mv_locs, mv_states = self._extract_state_attributes(state) - cell_type = self.xygrid[loc] - at_goal = cell_type == 'G' - if at_goal: - successor = self._make_state(t + 1, loc, loc, imm_states, mc_locs, mv_locs, mv_states) - return {successor: (1, True)} - - if cell_type == ',': - # turn into a wall: - imm_states = set_entry(imm_states, self.immobile_object_indices[loc], 1) - elif cell_type == 'Δ': - if imm_states[self.immobile_object_indices[loc]] == 0: - # turn state to 1: - imm_states = set_entry(imm_states, self.immobile_object_indices[loc], 1) - - target_loc = self._get_target_location(loc, action) - target_type = self.xygrid[target_loc] - - # loop through all mobile constant objects and see if they are affected by the action: - for i, object_type in enumerate(self.mobile_constant_object_types): - if (mc_locs[2*i],mc_locs[2*i+1]) != target_loc: - continue - if object_type == 'X': # a box - # see if we can push it: - box_target_loc = self._get_target_location(target_loc, action) - if self._can_move(target_loc, box_target_loc, state): - if self.xygrid[box_target_loc] in unsteady_cell_types: - raise NotImplementedError("boxes cannot slide/fall yet") # TODO: let boxes slide/fall like agents! - mc_locs = set_loc(mc_locs, i, box_target_loc) - elif object_type == '|': # a glass pane - if action in [0,2]: - # see if we can push it: - pane_target_loc = self._get_target_location(target_loc, action) - if self._can_move(target_loc, pane_target_loc, state): - if self.xygrid[pane_target_loc] in unsteady_cell_types: - raise NotImplementedError("glass panes cannot slide/fall yet") # TODO: let boxes slide/fall like agents! - mc_locs = set_loc(mc_locs, i, pane_target_loc) - else: - # it will break - mc_locs = set_loc(mc_locs, i, (-2,-2)) - - if target_type in ['^', '~']: - # see what "falling-off" actions are possible: - simulated_actions = [a for a in range(4) - if a != self.opposite_action(action) # won't fall back to where we came from - and self._can_move(target_loc, self._get_target_location(target_loc, a), state)] - if len(simulated_actions) == 0: - return None - p0 = 1 if target_type == '^' else self.uneven_ground_prob # probability of falling off - intermediate_state = self._make_state(t, target_loc, loc, imm_states, mc_locs, mv_locs, mv_states) - trans_dist = {} - # compose the transition distribution recursively: - for simulate_action in simulated_actions: - for (successor, (probability, _)) in self.transition_distribution(intermediate_state, simulate_action, n_samples).items(): - dp = p0 * probability / len(simulated_actions) - if successor in trans_dist: - trans_dist[successor] += dp - else: - trans_dist[successor] = dp - if target_type == '~': - trans_dist[intermediate_state] = 1 - p0 - return { successor: (probability, True) for (successor,probability) in trans_dist.items() } - - # implement all deterministic changes: - # (none yet) - - # initialize a dictionary of possible successor states as keys and their probabilities as values, - # which will subsequently be adjusted: - trans_dist = { self._make_state(t + 1, target_loc, imm_states, mc_locs, mv_locs, mv_states): 1 } # stay in the same state with probability 1 - - # implement all probabilistic changes: - - # again loop through all variable mobile objects encoded in mv_locs and mv_states: - for i, object_type in enumerate(self.mobile_constant_object_types): - object_loc = get_loc(mc_locs, i) - if object_type != 'F': # a non-fragile object - continue - if not(object_loc != (-2,-2) and self.move_probability_F > 0): # object may not move - continue - - # loop through all possible successor states in trans_dist and split them into at most 5 depending on whether F moves and where: - new_trans_dist = {} - for (successor, probability) in trans_dist.items(): - succ_t, succ_loc, succ_imm_states, succ_mc_locs, succ_mv_locs, succ_mv_states, gridcontents = self._extract_state_attributes(successor, gridcontents=True) - if object_loc == target_loc: # object is destroyed - default_successor = self._make_state(succ_t, succ_loc, succ_imm_states, set_loc(succ_mc_locs, i, (-2,-2)), succ_mv_locs, succ_mv_states) - else: # it stays in place - default_successor = successor - direction_locs = tuple((direction, self._get_target_location(object_loc, direction)) - for direction in range(4)) - direction_locs = tuple((direction, loc) for (direction, loc) in direction_locs - if self._can_move(object_loc, loc, successor, who='F')) - n_directions = len(direction_locs) - if n_directions == 0: - new_trans_dist[default_successor] = probability + return ((self.state_embedding() - other.state_embedding()) ** 2).sum().sqrt() + + def _handle_falling( + self, + falling: Pinnacle | UnevenGround, + action: Action, + agent: Agent, + n_samples: int, + ) -> dict[Self, Probability]: + simulated_actions = [ + a + for a in self.possible_actions(agent, include_stay=False) + if a is not action.opposite # won't fall back to where we came from + ] + if not simulated_actions: + return "none" # None + + p0 = ( + 1 if isinstance(falling, Pinnacle) else self.uneven_ground_prob + ) # probability of falling off + trans_dist: dict[Self, Probability] = {} + # intermediate_state = self.tick(changes, increase_time=False) + for simulate_action in simulated_actions: + for successor, ( + probability, + _, + ) in self.transition_distribution( + agent, simulate_action, n_samples, timestep=False + ).items(): + dp = p0 * probability / len(simulated_actions) + if successor in trans_dist: + trans_dist[successor] += dp else: - new_trans_dist[default_successor] = probability * (1 - self.move_probability_F) - p = probability * self.move_probability_F / n_directions - for (direction, obj_target_loc) in direction_locs: - if obj_target_loc == target_loc: # object is destroyed - new_successor = self._make_state(succ_t, succ_loc, succ_imm_states, set_loc(succ_mc_locs, i, (-2,-2)), succ_mv_locs, succ_mv_states) - else: # it moves - new_mc_locs = set_loc(succ_mc_locs, i, obj_target_loc) - # see if there's a glass pane at obj_target_loc: - inhabitant_type, inhabitant_index = gridcontents.get(obj_target_loc, (None, None)) - if inhabitant_type == '|': - # glass pane breaks - new_mc_locs = set_loc(new_mc_locs, inhabitant_index, (-2,-2)) - new_successor = self._make_state(succ_t, succ_loc, succ_imm_states, new_mc_locs, succ_mv_locs, succ_mv_states) - new_trans_dist[new_successor] = p - trans_dist = new_trans_dist + trans_dist[successor] = dp - # TODO: update object states and/or object locations, e.g. if the agent picks up an object or moves an object + if isinstance(falling, UnevenGround): + trans_dist[self] = 1 - p0 - return {successor: (probability, True) for (successor,probability) in trans_dist.items()} + return trans_dist - @cache - def observation_and_reward_distribution(self, state, action, successor, n_samples = None): - """ - Delta for a state accrues when entering the state, so it depends on successor: - """ - if state is None and action is None: - return {(self._make_state(), 0): (1, True)} - t, loc, imm_states, mc_locs, mv_locs, mv_states = self._extract_state_attributes(successor) - delta = self.time_deltas[t % self.time_deltas.size] - if self.delta_xygrid[loc] in self.cell_code2delta: - delta += self.cell_code2delta[self.delta_xygrid[loc]] - # loop through all immobile objects with state 0, see if agent has met it, and if so add the corresponding Delta: - for i in range(self.n_immobile_objects): - if imm_states[i] == 0 and loc == self.immobile_object_locations[i]: - delta += self.immobile_object_state0_deltas[i] - # do the same for all mobile variable objects: - for i in range(self.n_mobile_variable_objects): - if mv_states[i] == 0 and loc == get_loc(mv_locs, i): - delta += self.mobile_variable_object_state0_deltas[i] - # do the same for all mobile constant objects: - for i in range(self.n_mobile_constant_objects): - if loc == get_loc(mc_locs, i): - delta += self.mobile_constant_object_deltas[i] - # add timeout Delta: - if t == self.max_episode_length and self.xygrid[loc] != 'G': - delta += self.timeout_delta - return {(successor, delta): (1, True)} - - # reset() and step() are inherited from MDPWorldModel and use the above transition_distribution(): - - def initial_state(self): - return self._make_state() - - def reset(self, seed = None, options = None): - ret = super().reset(seed = seed, options = options) - if self.render_mode == "human" and self._previous_agent_location is not None: - self._render_frame() - return ret - - def step(self, action): - ret = super().step(action) - if self.render_mode == "human": - self._render_frame() - return ret - - def render(self, additional_data=None): -# if self.render_mode == "rgb_array": - return self._render_frame(additional_data=additional_data) - - def _init_human_rendering(self): - pygame.font.init() # you have to call this at the start, - # if you want to use this module. - self._cell_font = pygame.font.SysFont('Helvetica', 30) - self._delta_font = pygame.font.SysFont('Helvetica', 10) - self._cell_data_font = pygame.font.SysFont('Helvetica', 10) - self._action_data_font = pygame.font.SysFont('Helvetica', 10) - - def _render_frame(self, additional_data=None): - if self._window is None and self.render_mode == "human": - os.environ['SDL_VIDEO_WINDOW_POS'] = "%d,%d" % (900,0) - pygame.init() - pygame.display.init() - self._window = pygame.display.set_mode( - self._window_shape - ) - if self.clock is None and self.render_mode == "human": - self.clock = pygame.time.Clock() - - canvas = pygame.Surface(self._window_shape) - canvas.fill((255, 255, 255)) - pix_square_size = self._window_shape[0] / self.xygrid.shape[0] # The size of a single grid square in pixels - - # Draw grid contents: - for x in range(self.xygrid.shape[0]): - for y in range(self.xygrid.shape[1]): - cell_type = self.xygrid[x, y] - cell_code = self.delta_xygrid[x, y] - if cell_code in self.cell_code2delta: - pygame.draw.rect( - canvas, - (255, 255, 240), - (x * pix_square_size, y * pix_square_size, pix_square_size, pix_square_size), - ) - if cell_type == "#" or (cell_type == "," and self._immobile_object_states[self.immobile_object_indices[x, y]] == 1): - pygame.draw.rect( - canvas, - (64, 64, 64), - (x * pix_square_size, y * pix_square_size, pix_square_size, pix_square_size), - ) - elif cell_type == "G": - pygame.draw.rect( - canvas, - (0, 255, 0), - (x * pix_square_size, y * pix_square_size, pix_square_size, pix_square_size), - ) - elif (cell_type == "," and self._immobile_object_states[self.immobile_object_indices[x, y]] != 1): - pygame.draw.rect( - canvas, - (64, 64, 64), - ((x+.3) * pix_square_size, (y+.8) * pix_square_size, .4*pix_square_size, .1*pix_square_size), - ) - elif cell_type == "Δ": - if self._immobile_object_states[self.immobile_object_indices[x, y]] == 0: - # draw a small triangle: - pygame.draw.polygon( - canvas, - (224, 224, 0), - (((x+.3) * pix_square_size, (y+.7) * pix_square_size), - ((x+.7) * pix_square_size, (y+.7) * pix_square_size), - ((x+.5) * pix_square_size, (y+.3) * pix_square_size)), - ) - elif cell_type in render_as_char_types: - canvas.blit(self._cell_font.render(cell_type, True, (0, 0, 0)), - ((x+.3) * pix_square_size, (y+.3) * pix_square_size)) - if self._window is None and self.render_mode == "human": - canvas.blit(self._delta_font.render( - f"{x},{y}", True, (128, 128, 128)), - ((x+.8) * pix_square_size, (y+.1) * pix_square_size)) - if cell_code in self.cell_code2delta: - canvas.blit(self._delta_font.render( - cell_code + f" {self.cell_code2delta[cell_code]}", True, (0, 0, 0)), - ((x+.1) * pix_square_size, (y+.1) * pix_square_size)) - - # Render all mobile objects: - for i, object_type in enumerate(self.mobile_constant_object_types): - x, y = get_loc(self._mobile_constant_object_locations, i) - if object_type == 'X': # a box - pygame.draw.rect( - canvas, - (128, 128, 128), - ((x+.1) * pix_square_size, (y+.1) * pix_square_size, .8*pix_square_size, .8*pix_square_size), - ) - elif object_type == '|': # a glass pane - pygame.draw.rect( - canvas, - (192, 192, 192), - ((x+.45) * pix_square_size, (y+.1) * pix_square_size, .1*pix_square_size, .8*pix_square_size), - ) - elif object_type == 'F': # a fragile object - pygame.draw.circle( - canvas, - (255, 0, 0), - ((x+.5) * pix_square_size, (y+.5) * pix_square_size), - pix_square_size / 4, - ) - -# for i, object_type in enumerate(self.mobile_variable_object_types): -# x, y = get_loc(self._mobile_variable_object_locations, i) - - # Now we draw the agent and its previous location: - pygame.draw.circle( - canvas, - (0, 0, 255), - (np.array(self._previous_agent_location) + 0.5) * pix_square_size, - pix_square_size / 4, - width = 3, - ) - pygame.draw.circle( - canvas, - (0, 0, 255), - (np.array(self._agent_location) + 0.5) * pix_square_size, - pix_square_size / 3, - ) + def transition_distribution( + self, agent: Agent, action: Action, n_samples=None, timestep=True + ) -> TransitionDistribution: + assert self._can_move(action, agent) - # Optionally print some additional data: - if additional_data is not None: - if 'cell' in additional_data: # draw some list of values onto each cell - for x in range(self.xygrid.shape[0]): - for y in range(self.xygrid.shape[1]): - values = set(additional_data['cell'].get((x,y), [])) - if len(values) > 0: # then it is a list - surf = self._cell_data_font.render( - "|".join([str(v) for v in values]), True, - (0,0,255)) - canvas.blit(surf, - ((x+.5) * pix_square_size - .5 * surf.get_width(), - (y+.35) * pix_square_size - .5 * surf.get_height())) - if 'action' in additional_data: # draw some list of values next to each cell boundary - for x in range(self.xygrid.shape[0]): - for y in range(self.xygrid.shape[1]): - for action in range(4): - values = set(additional_data['action'].get((x,y,action), [])) - if len(values) > 0: # then it is a list - dx,dy = self._action_to_direction[action] if action < 4 else (0,0) - surf = self._action_data_font.render( - "|".join([str(v) for v in values]), - True, (0,0,255)) - canvas.blit(surf, - ((x+.5+dx*.48) * pix_square_size - [.5,1,.5,0,.5][action] * surf.get_width(), - (y+.5+dx*0.04+dy*.48) * pix_square_size - [0,0.5,1,0.5,.5][action] * surf.get_height())) - - # Finally, add some gridlines - for x in range(self.xygrid.shape[0] + 1): - pygame.draw.line( - canvas, - (128, 128, 128), - (pix_square_size * x, 0), - (pix_square_size * x, self._window_shape[1]), - width=3, - ) - for y in range(self.xygrid.shape[1] + 1): - pygame.draw.line( - canvas, - (128, 128, 128), - (0, pix_square_size * y), - (self._window_shape[0], pix_square_size * y), - width=3, - ) - # And print the time left into the top-right cell: - if self.render_mode == "human": - canvas.blit(self._cell_font.render( - f"{self.max_episode_length - self.t}", True, (0, 0, 0)), - ((self.xygrid.shape[0]-1+.3) * pix_square_size, (.3) * pix_square_size)) - - # The following line copies our drawings from `canvas` to the visible window - self._window.blit(canvas, canvas.get_rect()) - pygame.event.pump() - pygame.display.update() - - # We need to ensure that human-rendering occurs at the predefined framerate. - # The following line will automatically add a delay to keep the framerate stable. - self.clock.tick(self._fps) - else: # rgb_array - return np.transpose( - np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2) + # if state is None and action is None: + # successor = self._make_state() + # return {successor: (1, True)} + + if self.is_terminal(): + return {self.tick(): (1.0, True)} + + current: Cell = self.at(agent.location) + new_agent = agent.moved(action) + changes: StateChanges = {} + + # Why don't we do these against `target` at the end of the turn? + if etw := current.get(EmptyToWall): + changes[etw] = etw.to_wall() + + if delta := current.get(Delta): + if delta.state is not DeltaState.COLLECTED: + changes[delta] = delta.collect() + new_agent = new_agent.collect(delta.delta) + + changes[agent] = new_agent + target = self.at(agent.location + action.value) + + if fragile := target.get(Fragile): + assert self._can_move(action, fragile) + changes[fragile] = fragile.destroy() + + if box := target.get(Box): + assert self._can_move(action, delta) + changes[box] = box.moved(action) + + if glass := target.get(Glass): + if action in (Direction.UP, Direction.DOWN): + assert self._can_move(action, glass) + changes[glass] = glass.moved(action) + else: + changes[glass] = glass.break_glass() + + world = self.tick(changes, increase_time=timestep) + + trans_dist: dict[Self, Probability] + if falling := target.get((Pinnacle, UnevenGround)): + trans_dist = world._handle_falling( + falling, action, new_agent, n_samples or 1 ) - - def close(self): - if self._window is not None: - pygame.display.quit() - pygame.quit() + else: + trans_dist = {world: 1.0} + + # TODO: implement all deterministic changes + + # if self.move_probability_F > 0: + # for fr in ObjectType.by_id.values(): + # if not isinstance(fr, Fragile) or fr.state is FragileState.DESTROYED: + # continue + # + # trans_dist = self._handle_fragile(fr, trans_dist, changes) + + # TODO: update object states and/or object locations, e.g. if the agent picks up an object or moves an object + + return { + successor: (probability, True) + for (successor, probability) in trans_dist.items() + } + + def _handle_fragile( + self, + fragile: Fragile, + trans_dist: dict[Self, Probability], + changes: StateChanges, + ) -> dict[Self, Probability]: + new_trans_dist: dict[Self, Probability] = {} + # + # for successor, probability in trans_dist.items(): + # direction_locs = tuple( + # (direction, fragile.location + direction.value) + # for direction in successor.possible_actions(fragile, include_stay=False) + # ) + # if (n_directions := len(direction_locs)) == 0: + # new_trans_dist[successor] = probability + # continue + # + # new_trans_dist[successor] = probability * (1 - self.move_probability_F) + # p = probability * self.move_probability_F / n_directions + # for direction, obj_target_loc in direction_locs: + # hits_an_agent = obj_target_loc in (a.location for a in self.agents) + # if hits_an_agent: + # changes[fragile] = fragile.destroy() + # else: # it moves + # changes[fragile] = fragile.moved(direction) + # if not (glass := successor.at(obj_target_loc).get(Glass)): + # continue + # changes[glass] = glass.break_glass() + # new_successor = successor.tick(changes, increase_time=False) + # new_trans_dist[new_successor] = p + return new_trans_dist + + +# @cache +# def observation_and_reward_distribution( +# self, state, action, successor, n_samples=None +# ): +# """ +# Delta for a state accrues when entering the state, so it depends on successor: +# """ +# if state is None and action is None: +# return {(self._make_state(), 0): (1, True)} +# ( +# t, +# loc, +# imm_states, +# mc_locs, +# mv_locs, +# mv_states, +# ) = self._extract_state_attributes(successor) +# delta = self.time_deltas[t % self.time_deltas.size] +# if self.delta_xygrid[loc] in self.cell_code2delta: +# delta += self.cell_code2delta[self.delta_xygrid[loc]] +# # loop through all immobile objects with state 0, see if agent has met it, and if so add the corresponding Delta: +# for i in range(self.n_immobile_objects): +# if imm_states[i] == 0 and loc == self.immobile_object_locations[i]: +# delta += self.immobile_object_state0_deltas[i] +# # do the same for all mobile variable objects: +# for i in range(self.n_mobile_variable_objects): +# if mv_states[i] == 0 and loc == get_loc(mv_locs, i): +# delta += self.mobile_variable_object_state0_deltas[i] +# # do the same for all mobile constant objects: +# for i in range(self.n_mobile_constant_objects): +# if loc == get_loc(mc_locs, i): +# delta += self.mobile_constant_object_deltas[i] +# # add timeout Delta: +# if t == self.max_episode_length and self.xygrid[loc] != "G": +# delta += self.timeout_delta +# return {(successor, delta): (1, True)} +# +# # reset() and step() are inherited from MDPWorldModel and use the above transition_distribution(): +# +# def initial_state(self): +# return self._make_state() +# +# def reset(self, seed=None, options=None): +# ret = super().reset(seed=seed, options=options) +# if self.render_mode == "human" and self._previous_agent_location is not None: +# self._render_frame() +# return ret +# +# def step(self, action): +# ret = super().step(action) +# if self.render_mode == "human": +# self._render_frame() +# return ret +# +# def render(self, additional_data=None): +# # if self.render_mode == "rgb_array": +# return self._render_frame(additional_data=additional_data) +# +# def _init_human_rendering(self): +# pygame.font.init() # you have to call this at the start, +# # if you want to use this module. +# self._cell_font = pygame.font.SysFont("Helvetica", 30) +# self._delta_font = pygame.font.SysFont("Helvetica", 10) +# self._cell_data_font = pygame.font.SysFont("Helvetica", 10) +# self._action_data_font = pygame.font.SysFont("Helvetica", 10) +# +# def _render_frame(self, additional_data=None): +# if self._window is None and self.render_mode == "human": +# os.environ["SDL_VIDEO_WINDOW_POS"] = "%d,%d" % (900, 0) +# pygame.init() +# pygame.display.init() +# self._window = pygame.display.set_mode(self._window_shape) +# if self.clock is None and self.render_mode == "human": +# self.clock = pygame.time.Clock() +# +# canvas = pygame.Surface(self._window_shape) +# canvas.fill((255, 255, 255)) +# pix_square_size = ( +# self._window_shape[0] / self.xygrid.shape[0] +# ) # The size of a single grid square in pixels +# +# # Draw grid contents: +# for x in range(self.xygrid.shape[0]): +# for y in range(self.xygrid.shape[1]): +# cell_type = self.xygrid[x, y] +# cell_code = self.delta_xygrid[x, y] +# if cell_code in self.cell_code2delta: +# pygame.draw.rect( +# canvas, +# (255, 255, 240), +# ( +# x * pix_square_size, +# y * pix_square_size, +# pix_square_size, +# pix_square_size, +# ), +# ) +# if cell_type == "#" or ( +# cell_type == "," +# and self._immobile_object_states[self.immobile_object_indices[x, y]] +# == 1 +# ): +# pygame.draw.rect( +# canvas, +# (64, 64, 64), +# ( +# x * pix_square_size, +# y * pix_square_size, +# pix_square_size, +# pix_square_size, +# ), +# ) +# elif cell_type == "G": +# pygame.draw.rect( +# canvas, +# (0, 255, 0), +# ( +# x * pix_square_size, +# y * pix_square_size, +# pix_square_size, +# pix_square_size, +# ), +# ) +# elif ( +# cell_type == "," +# and self._immobile_object_states[self.immobile_object_indices[x, y]] +# != 1 +# ): +# pygame.draw.rect( +# canvas, +# (64, 64, 64), +# ( +# (x + 0.3) * pix_square_size, +# (y + 0.8) * pix_square_size, +# 0.4 * pix_square_size, +# 0.1 * pix_square_size, +# ), +# ) +# elif cell_type == "Δ": +# if ( +# self._immobile_object_states[self.immobile_object_indices[x, y]] +# == 0 +# ): +# # draw a small triangle: +# pygame.draw.polygon( +# canvas, +# (224, 224, 0), +# ( +# ( +# (x + 0.3) * pix_square_size, +# (y + 0.7) * pix_square_size, +# ), +# ( +# (x + 0.7) * pix_square_size, +# (y + 0.7) * pix_square_size, +# ), +# ( +# (x + 0.5) * pix_square_size, +# (y + 0.3) * pix_square_size, +# ), +# ), +# ) +# elif cell_type in render_as_char_types: +# canvas.blit( +# self._cell_font.render(cell_type, True, (0, 0, 0)), +# ((x + 0.3) * pix_square_size, (y + 0.3) * pix_square_size), +# ) +# if self._window is None and self.render_mode == "human": +# canvas.blit( +# self._delta_font.render(f"{x},{y}", True, (128, 128, 128)), +# ((x + 0.8) * pix_square_size, (y + 0.1) * pix_square_size), +# ) +# if cell_code in self.cell_code2delta: +# canvas.blit( +# self._delta_font.render( +# cell_code + f" {self.cell_code2delta[cell_code]}", +# True, +# (0, 0, 0), +# ), +# ((x + 0.1) * pix_square_size, (y + 0.1) * pix_square_size), +# ) +# +# # Render all mobile objects: +# for i, object_type in enumerate(self.mobile_constant_object_types): +# x, y = get_loc(self._mobile_constant_object_locations, i) +# if object_type == "X": # a box +# pygame.draw.rect( +# canvas, +# (128, 128, 128), +# ( +# (x + 0.1) * pix_square_size, +# (y + 0.1) * pix_square_size, +# 0.8 * pix_square_size, +# 0.8 * pix_square_size, +# ), +# ) +# elif object_type == "|": # a glass pane +# pygame.draw.rect( +# canvas, +# (192, 192, 192), +# ( +# (x + 0.45) * pix_square_size, +# (y + 0.1) * pix_square_size, +# 0.1 * pix_square_size, +# 0.8 * pix_square_size, +# ), +# ) +# elif object_type == "F": # a fragile object +# pygame.draw.circle( +# canvas, +# (255, 0, 0), +# ((x + 0.5) * pix_square_size, (y + 0.5) * pix_square_size), +# pix_square_size / 4, +# ) +# +# # for i, object_type in enumerate(self.mobile_variable_object_types): +# # x, y = get_loc(self._mobile_variable_object_locations, i) +# +# # Now we draw the agent and its previous location: +# pygame.draw.circle( +# canvas, +# (0, 0, 255), +# (np.array(self._previous_agent_location) + 0.5) * pix_square_size, +# pix_square_size / 4, +# width=3, +# ) +# pygame.draw.circle( +# canvas, +# (0, 0, 255), +# (np.array(self._agent_location) + 0.5) * pix_square_size, +# pix_square_size / 3, +# ) +# +# # Optionally print some additional data: +# if additional_data is not None: +# if "cell" in additional_data: # draw some list of values onto each cell +# for x in range(self.xygrid.shape[0]): +# for y in range(self.xygrid.shape[1]): +# values = set(additional_data["cell"].get((x, y), [])) +# if len(values) > 0: # then it is a list +# surf = self._cell_data_font.render( +# "|".join([str(v) for v in values]), True, (0, 0, 255) +# ) +# canvas.blit( +# surf, +# ( +# (x + 0.5) * pix_square_size +# - 0.5 * surf.get_width(), +# (y + 0.35) * pix_square_size +# - 0.5 * surf.get_height(), +# ), +# ) +# if ( +# "action" in additional_data +# ): # draw some list of values next to each cell boundary +# for x in range(self.xygrid.shape[0]): +# for y in range(self.xygrid.shape[1]): +# for action in range(4): +# values = set( +# additional_data["action"].get((x, y, action), []) +# ) +# if len(values) > 0: # then it is a list +# dx, dy = ( +# self._action_to_direction[action] +# if action < 4 +# else (0, 0) +# ) +# surf = self._action_data_font.render( +# "|".join([str(v) for v in values]), +# True, +# (0, 0, 255), +# ) +# canvas.blit( +# surf, +# ( +# (x + 0.5 + dx * 0.48) * pix_square_size +# - [0.5, 1, 0.5, 0, 0.5][action] +# * surf.get_width(), +# (y + 0.5 + dx * 0.04 + dy * 0.48) +# * pix_square_size +# - [0, 0.5, 1, 0.5, 0.5][action] +# * surf.get_height(), +# ), +# ) +# +# # Finally, add some gridlines +# for x in range(self.xygrid.shape[0] + 1): +# pygame.draw.line( +# canvas, +# (128, 128, 128), +# (pix_square_size * x, 0), +# (pix_square_size * x, self._window_shape[1]), +# width=3, +# ) +# for y in range(self.xygrid.shape[1] + 1): +# pygame.draw.line( +# canvas, +# (128, 128, 128), +# (0, pix_square_size * y), +# (self._window_shape[0], pix_square_size * y), +# width=3, +# ) +# # And print the time left into the top-right cell: +# if self.render_mode == "human": +# canvas.blit( +# self._cell_font.render( +# f"{self.max_episode_length - self.t}", True, (0, 0, 0) +# ), +# ( +# (self.xygrid.shape[0] - 1 + 0.3) * pix_square_size, +# (0.3) * pix_square_size, +# ), +# ) +# +# # The following line copies our drawings from `canvas` to the visible window +# self._window.blit(canvas, canvas.get_rect()) +# pygame.event.pump() +# pygame.display.update() +# +# # We need to ensure that human-rendering occurs at the predefined framerate. +# # The following line will automatically add a delay to keep the framerate stable. +# self.clock.tick(self._fps) +# else: # rgb_array +# return np.transpose( +# np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2) +# ) +# +# def close(self): +# if self._window is not None: +# pygame.display.quit() +# pygame.quit() diff --git a/src/world_model/test_simple_gridworld.py b/src/world_model/test_simple_gridworld.py new file mode 100644 index 0000000..c60c76d --- /dev/null +++ b/src/world_model/test_simple_gridworld.py @@ -0,0 +1,215 @@ +import pytest +from world_model.objects import Agent, Delta, DeltaState, EmptyToWall, EmptyToWallState +from world_model.simple_gridworld import ( + Box, + GlassState, + SimpleGridworld, + Direction, + Glass, +) + + +def print_td(td: dict): + print("==============") + for gw, (prob, exact) in td.items(): + print(f"- prob {prob} ({'non-' if not exact else ''}exact), t={gw.t}") + print(gw) + + +def get_where_agent_at( + td, loc: tuple[int, int] +) -> tuple[SimpleGridworld, float] | None: + for gw, prob in td.items(): + if gw.agents[0].location == loc: + return gw, prob[0] + + +def test_pinnacle_one_way(): + g: SimpleGridworld = SimpleGridworld( + [ + ["X", "A", "^"], + ["~", "X", " "], + ["X", " ", " "], + ], + ) + a = g.agents[0] + assert g._can_move(Direction.DOWN, a) + assert g._can_move(Direction.RIGHT, a) + assert not g._can_move(Direction.LEFT, a) + assert not g._can_move(Direction.UP, a) + td = g.transition_distribution(a, Direction.RIGHT) + assert len(td) == 1 + [(gw, (prob, _))] = td.items() + assert gw.agents[0].location == (2, 1) + assert prob == 1 + + td = g.transition_distribution(a, Direction.RIGHT) + + +def test_pinnacle_threeway(): + g: SimpleGridworld = SimpleGridworld( + [ + [" ", " ", " "], + ["A", "^", " "], + ["X", "X", " "], + ["X", " ", " "], + ], + ) + a = g.agents[0] + td = g.transition_distribution(a, Direction.RIGHT) + print_td(g.transition_distribution(a, Direction.RIGHT)) + + assert len(td) == 3 + for loc in ((1, 0), (2, 1), (1, 2)): + w, prob = get_where_agent_at(td, loc) + assert w.t == g.t + 1 + assert prob == 1.0 / 3 + box_moved = w.at((1, 3)).get(Box) + assert bool(box_moved) == (loc == (1, 2)) + + +def test_unstable(): + g: SimpleGridworld = SimpleGridworld( + [ + [" ", " ", " "], + ["A", "~", " "], + ["X", "X", " "], + ["X", " ", " "], + ], + uneven_ground_prob=0.6, + ) + a = g.agents[0] + td = g.transition_distribution(a, Direction.RIGHT) + print_td(g.transition_distribution(a, Direction.RIGHT)) + + assert len(td) == 4 + for loc in ((1, 0), (2, 1), (1, 2)): + w, prob = get_where_agent_at(td, loc) + assert w.t == g.t + 1 + assert prob == (1.0 - 0.4) / 3 + box_moved = w.at((1, 3)).get(Box) + assert bool(box_moved) == (loc == (1, 2)) + + w, prob = get_where_agent_at(td, (1, 1)) + assert prob == 0.4 + + +def test_glass(): + g: SimpleGridworld = SimpleGridworld( + [ + ["#", " ", "#"], + ["#", "|", "#"], + ["|", "A", "|"], + ["X", "|", " "], + ["X", "X", " "], + ], + uneven_ground_prob=0.6, + ) + a = g.agents[0] + + td = g.transition_distribution(a, Direction.RIGHT) + assert len(td) == 1 + w, prob = get_where_agent_at(td, (2, 2)) + assert w.t == g.t + 1 + assert prob == 1 + glass = w.at((2, 2)).get(Glass) + assert glass.state is GlassState.BROKEN + + td = g.transition_distribution(a, Direction.UP) + assert len(td) == 1 + w, prob = get_where_agent_at(td, (1, 1)) + assert prob == 1 + glass = w.at((1, 0)).get(Glass) + assert glass.state is None + + with pytest.raises(AssertionError): + td = g.transition_distribution(a, Direction.DOWN) + + td = g.transition_distribution(a, Direction.LEFT) + assert len(td) == 1 + w, prob = get_where_agent_at(td, (0, 2)) + assert prob == 1 + glass = w.at((0, 2)).get(Glass) + assert glass.state is GlassState.BROKEN + + +def test_btw(): + g: SimpleGridworld = SimpleGridworld( + [ + ["#", " ", "#"], + ["#", ",", "#"], + ["|", "A", "|"], + ], + ) + a = g.agents[0] + etw = g.at((1, 1)).get(EmptyToWall) + assert etw.state is None + + td = g.transition_distribution(a, Direction.UP) + assert len(td) == 1 + w1, prob = get_where_agent_at(td, (1, 1)) + assert w1.t == g.t + 1 + assert prob == 1 + etw = w1.at((1, 1)).get(EmptyToWall) + assert etw.state is None + + td = w1.transition_distribution(w1.agents[0], Direction.UP) + assert len(td) == 1 + w2, prob = get_where_agent_at(td, (1, 0)) + assert w2.t == w1.t + 1 + assert prob == 1 + etw = w2.at((1, 1)).get(EmptyToWall) + assert etw.state is EmptyToWallState.WALL + + with pytest.raises(AssertionError): + td = w2.transition_distribution(w2.agents[0], Direction.UP) + with pytest.raises(AssertionError): + td = w2.transition_distribution(w2.agents[0], Direction.DOWN) + + +def test_delta(): + g: SimpleGridworld = SimpleGridworld( + [ + ["A", "Δ", " "], + ], + [ + [" ", "D", " "], + ], + cell_code2delta = {"D": 3} + ) + delta = g.at((1, 0)).get(Delta) + assert delta.state is None + assert g.agents[0].collected_delta == 0 + + td = g.transition_distribution(g.agents[0], Direction.RIGHT) + assert len(td) == 1 + w1, prob = get_where_agent_at(td, (1, 0)) + assert w1.t == g.t + 1 + assert prob == 1 + + td = w1.transition_distribution(w1.agents[0], Direction.RIGHT) + assert len(td) == 1 + w2, prob = get_where_agent_at(td, (2, 0)) + assert w2.t == w1.t + 1 + assert prob == 1 + delta = w2.at((1, 0)).get(Delta) + assert delta.state is DeltaState.COLLECTED + assert w2.agents[0].collected_delta == 3 + + td = w2.transition_distribution(w2.agents[0], Direction.LEFT) + assert len(td) == 1 + w3, prob = get_where_agent_at(td, (1, 0)) + print(w3) + print(w3.at((1,0)).get(Agent).collected_delta) + assert w3.t == w2.t + 1 + assert prob == 1 + assert w3.agents[0].collected_delta == 3 + + td = w3.transition_distribution(w3.agents[0], Direction.LEFT) + assert len(td) == 1 + w4, prob = get_where_agent_at(td, (0, 0)) + assert w4.t == w3.t + 1 + assert prob == 1 + delta = w4.at((1, 0)).get(Delta) + assert delta.state is DeltaState.COLLECTED + assert w4.agents[0].collected_delta == 3 diff --git a/src/world_model/types.py b/src/world_model/types.py new file mode 100644 index 0000000..da77284 --- /dev/null +++ b/src/world_model/types.py @@ -0,0 +1,26 @@ +from enum import Enum +from functools import cached_property +from typing import NamedTuple + +class Location(NamedTuple): + x: int + y: int + + def __sub__(self, other: "Location") -> "Location": + return Location(self.x - other.x, self.y - other.y) + + def __add__(self, other: "Location") -> "Location": # type: ignore[override] + return Location(self.x + other.x, self.y + other.y) + +class Direction(Enum): + UP = Location(0, -1) + RIGHT = Location(1, 0) + DOWN = Location(0, 1) + LEFT = Location(-1, 0) + STAY = Location(0, 0) + + @cached_property + def opposite(self) -> "Direction": + return type(self)(Location(-self.value.x, -self.value.y)) + +