diff --git a/.gitignore b/.gitignore
index efaa3648f..9ed20016f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,10 @@ data/
profile.pstats
catanatron-venv
.DS_Store
+wandb
+videos
+models
+runs
# Byte-compiled / optimized / DLL files
__pycache__/
diff --git a/README.md b/README.md
index d000b754c..0e177e439 100644
--- a/README.md
+++ b/README.md
@@ -5,13 +5,14 @@

[](https://colab.research.google.com/github/bcollazo/catanatron/blob/master/examples/Overview.ipynb)
-Catanatron is a high-performance simulator and strong AI player for Settlers of Catan. You can run thousands of games in the order of seconds. The goal is to find the strongest Settlers of Catan bot possible.
+Catanatron is a high-performance simulator and strong AI player for Settlers of Catan. You can run thousands of games in the order of seconds. The goal is to find the strongest Settlers of Catan bot possible.
Get Started with the Full Documentation: https://docs.catanatron.com
Join our Discord: https://discord.gg/FgFmb75TWd!
## Command Line Interface
+
Catanatron provides a `catanatron-play` CLI tool to run large scale simulations.
@@ -22,26 +23,30 @@ Catanatron provides a `catanatron-play` CLI tool to run large scale simulations.
1. Clone the repository:
- ```bash
- git clone git@github.com:bcollazo/catanatron.git
- cd catanatron/
- ```
-2. Create a virtual environment (requires Python 3.11 or higher)
+ ```bash
+ git clone git@github.com:bcollazo/catanatron.git
+ cd catanatron/
+ ```
+
+2. Create a virtual environment (requires Python 3.11 or higher)
+
+ ```bash
+ python -m venv venv
+ source ./venv/bin/activate
+ # ./venv/Scripts/Activate.ps1 (on windows)
+ ```
- ```bash
- python -m venv venv
- source ./venv/bin/activate
- ```
3. Install dependencies
- ```bash
- pip install -e .
- ```
-4. (Optional) Install developer and advanced dependencies
+ ```bash
+ pip install -e .
+ ```
+
+4. (Optional) Install developer and advanced dependencies
- ```bash
- pip install -e ".[web,gym,dev]"
- ```
+ ```bash
+ pip install -e ".[web,gym,dev]"
+ ```
### Usage
@@ -52,13 +57,13 @@ catanatron-play --players=R,R,R,W --num=100
```
Generate datasets from the games to analyze:
+
```bash
catanatron-play --num 100 --output my-data-path/ --output-format json
```
See more examples at https://docs.catanatron.com.
-
## Graphical User Interface
We provide Docker images so that you can watch, inspect, and play games against Catanatron via a web UI!
@@ -67,15 +72,15 @@ We provide Docker images so that you can watch, inspect, and play games against
-
### Installation
1. Ensure you have Docker installed (https://docs.docker.com/engine/install/)
2. Run the `docker-compose.yaml` in the root folder of the repo:
- ```bash
- docker compose up
- ```
+ ```bash
+ docker compose up
+ ```
+
3. Visit http://localhost:3000 in your browser!
## Python Library
@@ -100,14 +105,17 @@ print(game.play()) # returns winning color
See more at http://docs.catanatron.com
## Gymnasium Interface
+
For Reinforcement Learning, catanatron provides an Open AI / Gymnasium Environment.
Install it with:
+
```bash
pip install -e .[gym]
```
and use it like:
+
```python
import random
import gymnasium
@@ -128,8 +136,8 @@ env.close()
See more at: https://docs.catanatron.com
-
## Documentation
+
Full documentation here: https://docs.catanatron.com
## Contributing
@@ -144,6 +152,5 @@ coverage run --source=catanatron -m pytest tests/ && coverage report
See more at: https://docs.catanatron.com
## Appendix
-See the motivation of the project here: [5 Ways NOT to Build a Catan AI](https://medium.com/@bcollazo2010/5-ways-not-to-build-a-catan-ai-e01bc491af17).
-
+See the motivation of the project here: [5 Ways NOT to Build a Catan AI](https://medium.com/@bcollazo2010/5-ways-not-to-build-a-catan-ai-e01bc491af17).
diff --git a/catanatron/catanatron/cli/play.py b/catanatron/catanatron/cli/play.py
index cf098f95a..7a1d83078 100644
--- a/catanatron/catanatron/cli/play.py
+++ b/catanatron/catanatron/cli/play.py
@@ -7,9 +7,8 @@
from rich.console import Console
from rich.table import Table
from rich.progress import Progress
-from rich.progress import Progress, BarColumn, TimeRemainingColumn
+from rich.progress import BarColumn, TimeRemainingColumn
from rich import box
-from rich.console import Console
from rich.theme import Theme
from rich.text import Text
@@ -207,7 +206,7 @@ class OutputOptions:
class GameConfigOptions:
discard_limit: int = 7
vps_to_win: int = 10
- catan_map: Literal["BASE", "TOURNAMENT", "MINI"] = "BASE"
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"] = "BASE"
COLOR_TO_RICH_STYLE = {
@@ -238,7 +237,7 @@ def play_batch_core(num_games, players, game_config, accumulators=[]):
for _ in range(num_games):
for player in players:
player.reset_state()
- catan_map = build_map(game_config.catan_map)
+ catan_map = build_map(game_config.map_type)
game = Game(
players,
discard_limit=game_config.discard_limit,
@@ -275,7 +274,10 @@ def play_batch(
accumulators.append(
CsvDataAccumulator(
- output_options.output, output_options.include_board_tensor
+ tuple(p.color for p in players),
+ game_config.map_type,
+ output_options.output,
+ output_options.include_board_tensor,
)
)
elif output_options.output_format == "parquet":
@@ -284,7 +286,10 @@ def play_batch(
accumulators.append(
ParquetDataAccumulator(
- output_options.output, output_options.include_board_tensor
+ tuple(p.color for p in players),
+ game_config.map_type,
+ output_options.output,
+ output_options.include_board_tensor,
)
)
elif output_options.output_format == "json":
diff --git a/catanatron/catanatron/features.py b/catanatron/catanatron/features.py
index d74e0f891..a09567c79 100644
--- a/catanatron/catanatron/features.py
+++ b/catanatron/catanatron/features.py
@@ -12,7 +12,7 @@
)
from catanatron.models.board import STATIC_GRAPH, get_edges, get_node_distances
from catanatron.models.map import NUM_TILES, CatanMap, build_map, number_probability
-from catanatron.models.player import Player, Color, SimplePlayer
+from catanatron.models.player import Color, SimplePlayer
from catanatron.models.enums import (
DEVELOPMENT_CARDS,
RESOURCES,
@@ -106,7 +106,7 @@ def resource_hand_features(game: Game, p0_color: Color):
]
for card in DEVELOPMENT_CARDS:
features[f"P0_{card}_IN_HAND"] = player_state[key + f"_{card}_IN_HAND"]
- features[f"P0_HAS_PLAYED_DEVELOPMENT_CARD_IN_TURN"] = player_state[
+ features["P0_HAS_PLAYED_DEVELOPMENT_CARD_IN_TURN"] = player_state[
key + "_HAS_PLAYED_DEVELOPMENT_CARD_IN_TURN"
]
@@ -132,7 +132,7 @@ def map_tile_features(catan_map: CatanMap, robber_coordinate):
for tile_id, tile in catan_map.tiles_by_id.items():
for resource in RESOURCES:
features[f"TILE{tile_id}_IS_{resource}"] = tile.resource == resource
- features[f"TILE{tile_id}_IS_DESERT"] = tile.resource == None
+ features[f"TILE{tile_id}_IS_DESERT"] = tile.resource is None
features[f"TILE{tile_id}_PROBA"] = (
0 if tile.resource is None else number_probability(tile.number)
)
diff --git a/catanatron/catanatron/gym/accumulators.py b/catanatron/catanatron/gym/accumulators.py
index fd9e63212..a7116cff1 100644
--- a/catanatron/catanatron/gym/accumulators.py
+++ b/catanatron/catanatron/gym/accumulators.py
@@ -1,15 +1,16 @@
import os
-from collections import defaultdict
import time
+from collections import defaultdict
+from typing import Tuple, Literal
-from catanatron.utils import format_secs
import numpy as np
import pandas as pd
+from catanatron import Action, Color, Game
from catanatron.features import create_sample
from catanatron.game import GameAccumulator
from catanatron.gym.board_tensor_features import create_board_tensor
-from catanatron.gym.envs.catanatron_env import to_action_space, to_action_type_space
+from catanatron.gym.envs.action_space import to_action_space, to_action_type_space
from catanatron.gym.utils import (
DISCOUNT_FACTOR,
get_tournament_total_return,
@@ -17,11 +18,14 @@
populate_matrices,
simple_total_return,
)
+from catanatron.utils import format_secs
class ReinforcementLearningAccumulator(GameAccumulator):
def __init__(
self,
+ player_colors: Tuple[Color],
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"] = "BASE",
include_board_tensor=True,
total_return_fns={
"RETURN": simple_total_return,
@@ -29,6 +33,8 @@ def __init__(
"VICTORY_POINTS_RETURN": get_victory_points_total_return,
},
):
+ self.player_colors = player_colors
+ self.map_type = map_type
self.include_board_tensor = include_board_tensor
# TODO: Generalize to "rewards_fn" that can yield intermediary rewards
# while still rewarding big on terminal states.
@@ -45,14 +51,17 @@ def before(self, game):
if self.include_board_tensor:
self.data["board_tensors"] = []
- def step(self, game_before_action, action):
+ def step(self, game_before_action: Game, action: Action):
self.data["color_action_indices"][action.color].append(
len(self.data["samples"])
)
self.data["acting_color"].append(action.color)
self.data["samples"].append(create_sample(game_before_action, action.color))
self.data["actions"].append(
- [to_action_space(action), to_action_type_space(action.action_type)]
+ [
+ to_action_space(action, self.player_colors, self.map_type),
+ to_action_type_space(action.action_type),
+ ]
)
if self.include_board_tensor:
@@ -130,8 +139,14 @@ def after(self, game):
class CsvDataAccumulator(ReinforcementLearningAccumulator):
- def __init__(self, output, include_board_tensor=True):
- super().__init__(include_board_tensor)
+ def __init__(
+ self,
+ player_colors: Tuple[Color],
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"],
+ output,
+ include_board_tensor=True,
+ ):
+ super().__init__(player_colors, map_type, include_board_tensor)
self.output = output
def after(self, game):
@@ -164,8 +179,14 @@ def after(self, game):
class ParquetDataAccumulator(ReinforcementLearningAccumulator):
- def __init__(self, output, include_board_tensor=True):
- super().__init__(include_board_tensor)
+ def __init__(
+ self,
+ player_colors: Tuple[Color],
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"],
+ output,
+ include_board_tensor=True,
+ ):
+ super().__init__(player_colors, map_type, include_board_tensor)
self.output = output
def after(self, game):
diff --git a/catanatron/catanatron/gym/envs/action_space.py b/catanatron/catanatron/gym/envs/action_space.py
new file mode 100644
index 000000000..68537c726
--- /dev/null
+++ b/catanatron/catanatron/gym/envs/action_space.py
@@ -0,0 +1,103 @@
+from functools import lru_cache
+from typing import Tuple, Literal
+
+from catanatron.models.actions import Action
+from catanatron.models.board import get_edges
+from catanatron.models.enums import RESOURCES, ActionType
+from catanatron.models.player import Color
+from catanatron.models.map import build_map
+
+
+@lru_cache(maxsize=None)
+def get_action_array(
+ player_colors: Tuple[Color], map_type: Literal["BASE", "TOURNAMENT", "MINI"]
+):
+ catan_map = build_map(map_type)
+ num_nodes = len(catan_map.land_nodes)
+
+ # We sort the actions to ensure a consistent ordering and reproducibility
+ # without sorting, we couldn't get gym usages to be reproducible
+ actions_array = sorted(
+ [
+ (ActionType.ROLL, None),
+ (ActionType.DISCARD, None),
+ *[
+ (ActionType.BUILD_ROAD, tuple(sorted(edge)))
+ for edge in get_edges(catan_map.land_nodes)
+ ],
+ *[(ActionType.BUILD_SETTLEMENT, node_id) for node_id in range(num_nodes)],
+ *[(ActionType.BUILD_CITY, node_id) for node_id in range(num_nodes)],
+ (ActionType.BUY_DEVELOPMENT_CARD, None),
+ (ActionType.PLAY_KNIGHT_CARD, None),
+ *[
+ (ActionType.PLAY_YEAR_OF_PLENTY, (first_card, RESOURCES[j]))
+ for i, first_card in enumerate(RESOURCES)
+ for j in range(i, len(RESOURCES))
+ ],
+ *[
+ (ActionType.PLAY_YEAR_OF_PLENTY, (first_card,))
+ for first_card in RESOURCES
+ ],
+ (ActionType.PLAY_ROAD_BUILDING, None),
+ *[(ActionType.PLAY_MONOPOLY, r) for r in RESOURCES],
+ # Move Robber actions include to every tile and from each opponent
+ *[
+ (ActionType.MOVE_ROBBER, (coordinates, victim_color))
+ for coordinates in catan_map.land_tiles.keys()
+ for victim_color in [None] + list(player_colors)
+ ],
+ # 4:1 with bank
+ *[
+ (ActionType.MARITIME_TRADE, tuple(4 * [i] + [j]))
+ for i in RESOURCES
+ for j in RESOURCES
+ if i != j
+ ],
+ # 3:1 with port
+ *[
+ (ActionType.MARITIME_TRADE, tuple(3 * [i] + [None, j])) # type: ignore
+ for i in RESOURCES
+ for j in RESOURCES
+ if i != j
+ ],
+ # 2:1 with port
+ *[
+ (ActionType.MARITIME_TRADE, tuple(2 * [i] + [None, None, j])) # type: ignore
+ for i in RESOURCES
+ for j in RESOURCES
+ if i != j
+ ],
+ (ActionType.END_TURN, None),
+ ],
+ key=lambda x: str(x),
+ )
+ return actions_array
+
+
+ACTION_TYPES = [i for i in ActionType]
+
+
+def to_action_type_space(action_type: ActionType) -> int:
+ return ACTION_TYPES.index(action_type)
+
+
+def to_action_space(
+ action: Action,
+ player_colors: Tuple[Color],
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"],
+):
+ """maps action to space_action equivalent integer"""
+ actions_array = get_action_array(player_colors, map_type)
+ return actions_array.index((action.action_type, action.value))
+
+
+def from_action_space(
+ action_int,
+ color: Color,
+ player_colors: Tuple[Color],
+ map_type: Literal["BASE", "TOURNAMENT", "MINI"],
+):
+ """maps action_int to catantron.models.actions.Action"""
+ actions_array = get_action_array(player_colors, map_type)
+ (action_type, value) = actions_array[action_int]
+ return Action(color, action_type, value)
diff --git a/catanatron/catanatron/gym/envs/catanatron_env.py b/catanatron/catanatron/gym/envs/catanatron_env.py
index a1f43db93..40d71ced8 100644
--- a/catanatron/catanatron/gym/envs/catanatron_env.py
+++ b/catanatron/catanatron/gym/envs/catanatron_env.py
@@ -1,204 +1,126 @@
-from typing import TypedDict, Union
+from typing import TypedDict, Literal, Callable, List, Any
+import random
import gymnasium as gym
from gymnasium import spaces
import numpy as np
+from catanatron import Action
from catanatron.game import Game, TURNS_LIMIT
from catanatron.models.player import Color, Player, RandomPlayer
-from catanatron.models.map import BASE_MAP_TEMPLATE, NUM_NODES, LandTile, build_map
-from catanatron.models.enums import RESOURCES, Action, ActionType
-from catanatron.models.board import get_edges
+from catanatron.models.map import build_map
from catanatron.features import (
create_sample,
get_feature_ordering,
)
+from catanatron.gym.envs.action_space import (
+ to_action_space,
+ from_action_space,
+ get_action_array,
+)
from catanatron.gym.board_tensor_features import (
create_board_tensor,
get_channels,
is_graph_feature,
)
-
-BASE_TOPOLOGY = BASE_MAP_TEMPLATE.topology
-TILE_COORDINATES = [x for x, y in BASE_TOPOLOGY.items() if y == LandTile]
-ACTIONS_ARRAY = [
- (ActionType.ROLL, None),
- # TODO: One for each tile (and abuse 1v1 setting).
- *[(ActionType.MOVE_ROBBER, tile) for tile in TILE_COORDINATES],
- (ActionType.DISCARD, None),
- *[(ActionType.BUILD_ROAD, tuple(sorted(edge))) for edge in get_edges()],
- *[(ActionType.BUILD_SETTLEMENT, node_id) for node_id in range(NUM_NODES)],
- *[(ActionType.BUILD_CITY, node_id) for node_id in range(NUM_NODES)],
- (ActionType.BUY_DEVELOPMENT_CARD, None),
- (ActionType.PLAY_KNIGHT_CARD, None),
- *[
- (ActionType.PLAY_YEAR_OF_PLENTY, (first_card, RESOURCES[j]))
- for i, first_card in enumerate(RESOURCES)
- for j in range(i, len(RESOURCES))
- ],
- *[(ActionType.PLAY_YEAR_OF_PLENTY, (first_card,)) for first_card in RESOURCES],
- (ActionType.PLAY_ROAD_BUILDING, None),
- *[(ActionType.PLAY_MONOPOLY, r) for r in RESOURCES],
- # 4:1 with bank
- *[
- (ActionType.MARITIME_TRADE, tuple(4 * [i] + [j]))
- for i in RESOURCES
- for j in RESOURCES
- if i != j
- ],
- # 3:1 with port
- *[
- (ActionType.MARITIME_TRADE, tuple(3 * [i] + [None, j])) # type: ignore
- for i in RESOURCES
- for j in RESOURCES
- if i != j
- ],
- # 2:1 with port
- *[
- (ActionType.MARITIME_TRADE, tuple(2 * [i] + [None, None, j])) # type: ignore
- for i in RESOURCES
- for j in RESOURCES
- if i != j
- ],
- (ActionType.END_TURN, None),
-]
-ACTION_SPACE_SIZE = len(ACTIONS_ARRAY)
-ACTION_TYPES = [i for i in ActionType]
-
-
-def to_action_type_space(action_type: ActionType) -> int:
- return ACTION_TYPES.index(action_type)
-
-
-# NOTE: I think I don't need this if we separate action and action_record nicely...
-def normalize_action(action):
- normalized = action
- if normalized.action_type == ActionType.ROLL:
- return Action(action.color, action.action_type, None)
- elif normalized.action_type == ActionType.MOVE_ROBBER:
- return Action(action.color, action.action_type, action.value[0])
- elif normalized.action_type == ActionType.BUILD_ROAD:
- return Action(action.color, action.action_type, tuple(sorted(action.value)))
- elif normalized.action_type == ActionType.BUY_DEVELOPMENT_CARD:
- return Action(action.color, action.action_type, None)
- elif normalized.action_type == ActionType.DISCARD:
- return Action(action.color, action.action_type, None)
- return normalized
-
-
-def to_action_space(action):
- """maps action to space_action equivalent integer"""
- normalized = normalize_action(action)
- return ACTIONS_ARRAY.index((normalized.action_type, normalized.value))
-
-
-def from_action_space(action_int, playable_actions):
- """maps action_int to catantron.models.actions.Action"""
- # Get "catan_action" based on space action.
- # i.e. Take first action in playable that matches ACTIONS_ARRAY blueprint
- (action_type, value) = ACTIONS_ARRAY[action_int]
- catan_action = None
- for action in playable_actions:
- normalized = normalize_action(action)
- if normalized.action_type == action_type and normalized.value == value:
- catan_action = action
- break # return the first one
- assert catan_action is not None
- return catan_action
-
-
-FEATURES = get_feature_ordering(num_players=2)
-NUM_FEATURES = len(FEATURES)
-
# Highest features is NUM_RESOURCES_IN_HAND which in theory is all resource cards
HIGH = 19 * 5
-def simple_reward(game, p0_color):
- winning_color = game.winning_color()
- if p0_color == winning_color:
- return 1
- elif winning_color is None:
- return 0
- else:
- return -1
+class ObservationSpec(TypedDict):
+ encode: Callable[[Game, Color], Any]
+ space: gym.Space
+
+
+class CatanatronEnvConfig(TypedDict):
+ # Game Config
+ map_type: Literal["BASE", "MINI"]
+ vps_to_win: int
+ enemies: List[Player]
+ # Env Config
+ observation_spec: ObservationSpec
+ invalid_action_reward: float
+ reward_function: Callable[[Action, Game, Color], float]
-class MixedObservation(TypedDict):
- board: np.ndarray
- numeric: np.ndarray
+ # Render Config
+ render_mode: Literal["rgb_array", "db"]
+ render_scale: float
class CatanatronEnv(gym.Env):
- metadata = {"render_modes": []}
+ metadata = {"render_modes": ["rgb_array", "db"], "render_fps": 10}
+
+ def __init__(self, config: CatanatronEnvConfig = None):
+ self.dtype = np.float32
- def __init__(self, config=None):
self.config = config or dict()
- self.invalid_action_reward = self.config.get("invalid_action_reward", -1)
- self.reward_function = self.config.get("reward_function", simple_reward)
self.map_type = self.config.get("map_type", "BASE")
self.vps_to_win = self.config.get("vps_to_win", 10)
self.enemies = self.config.get("enemies", [RandomPlayer(Color.RED)])
- self.representation = self.config.get("representation", "vector")
-
+ self.player_colors = tuple([Color.BLUE] + [p.color for p in self.enemies])
assert all(p.color != Color.BLUE for p in self.enemies)
- assert self.representation in ["mixed", "vector"]
self.p0 = Player(Color.BLUE)
self.players = [self.p0] + self.enemies # type: ignore
- self.representation = "mixed" if self.representation == "mixed" else "vector"
- self.features = get_feature_ordering(len(self.players), self.map_type)
+
+ self.observation_spec = self.config.get("observation_spec", None)
+ if self.observation_spec is None:
+ self.observation_spec = build_vector_obs_spec(
+ len(self.player_colors), self.map_type
+ )
+ self.observation_space = self.observation_spec["space"]
+
+ # Build action space depending on map type
+ self.action_array = get_action_array(self.player_colors, self.map_type)
+ self.action_space_size = len(self.action_array)
+ self.action_space = spaces.Discrete(self.action_space_size)
+ self.invalid_action_reward = self.config.get("invalid_action_reward", -1)
+ self.reward_function = self.config.get("reward_function", simple_reward)
self.invalid_actions_count = 0
self.max_invalid_actions = 10
- # TODO: Make self.action_space tighter if possible (per map_type)
- self.action_space = spaces.Discrete(ACTION_SPACE_SIZE)
-
- if self.representation == "mixed":
- channels = get_channels(len(self.players))
- board_tensor_space = spaces.Box(
- low=0, high=1, shape=(channels, 21, 11), dtype=np.float64
- )
- self.numeric_features = [
- f for f in self.features if not is_graph_feature(f)
- ]
- # TODO: This could be tigher (e.g. _ROADS_AVAILABLE <= 15)
- numeric_space = spaces.Box(
- low=0, high=HIGH, shape=(len(self.numeric_features),), dtype=np.float64
- )
- mixed = spaces.Dict(
- {
- "board": board_tensor_space,
- "numeric": numeric_space,
- }
- )
- self.observation_space = mixed
- else:
- # TODO: This could be tigher (e.g. _ROADS_AVAILABLE <= 15)
- self.observation_space = spaces.Box(
- low=0, high=HIGH, shape=(len(self.features),), dtype=np.float64
- )
+ # Render config
+ self.render_mode = self.config.get("render_mode", None)
+ self.render_scale = self.config.get("render_scale", 1.0)
+ self.renderer = None # Lazy init on first render()
self.reset()
def get_valid_actions(self):
"""
Returns:
- List[int]: valid actions
+ List[int]: valid actions (sorted for reproducibility)
+ """
+ return sorted(
+ [
+ to_action_space(a, self.player_colors, self.map_type)
+ for a in self.game.playable_actions
+ ]
+ )
+
+ def action_masks(self) -> list[bool]:
+ """
+ This method is to be compatible with SB3 SubprocVecEnv.
+ See https://sb3-contrib.readthedocs.io/en/master/modules/ppo_mask.html
+
+ Returns:
+ List[bool]: action masks
"""
- return list(map(to_action_space, self.game.playable_actions))
+ mask = np.zeros(self.action_space_size, dtype=bool)
+ mask[self.get_valid_actions()] = True
+ return mask
def step(self, action):
try:
- catan_action = from_action_space(action, self.game.playable_actions)
- except Exception as e:
+ catan_action = from_action_space(
+ action, self.p0.color, self.player_colors, self.map_type
+ )
+ assert catan_action in self.game.playable_actions
+ except AssertionError:
self.invalid_actions_count += 1
observation = self._get_observation()
winning_color = self.game.winning_color()
- done = (
- winning_color is not None
- or self.invalid_actions_count > self.max_invalid_actions
- )
terminated = winning_color is not None
truncated = (
self.invalid_actions_count > self.max_invalid_actions
@@ -216,7 +138,7 @@ def step(self, action):
winning_color = self.game.winning_color()
terminated = winning_color is not None
truncated = self.game.state.num_turns >= TURNS_LIMIT
- reward = self.reward_function(self.game, self.p0.color)
+ reward = self.reward_function(catan_action, self.game, self.p0.color)
return observation, reward, terminated, truncated, info
@@ -227,6 +149,9 @@ def reset(
):
super().reset(seed=seed)
+ if seed is not None:
+ # Ensure map generation uses the same seed as the game.
+ random.seed(seed)
catan_map = build_map(self.map_type)
for player in self.players:
player.reset_state()
@@ -245,16 +170,8 @@ def reset(
return observation, info
- def _get_observation(self) -> Union[np.ndarray, MixedObservation]:
- sample = create_sample(self.game, self.p0.color)
- if self.representation == "mixed":
- board_tensor = create_board_tensor(
- self.game, self.p0.color, channels_first=True
- )
- numeric = np.array([float(sample[i]) for i in self.numeric_features])
- return {"board": board_tensor, "numeric": numeric}
-
- return np.array([float(sample[i]) for i in self.features])
+ def _get_observation(self) -> Any:
+ return self.observation_spec["encode"](self.game, self.p0.color)
def _advance_until_p0_decision(self):
while (
@@ -263,14 +180,108 @@ def _advance_until_p0_decision(self):
):
self.game.play_tick() # will play bot
+ def render(self):
+ """Render the game state.
+
+ Returns:
+ np.ndarray: RGB array (height, width, 3) if render_mode is "rgb_array", None otherwise
+ """
+ if self.render_mode == "rgb_array":
+ if self.renderer is None:
+ from catanatron.gym.envs.pygame_renderer import PygameRenderer
+
+ self.renderer = PygameRenderer(render_scale=self.render_scale)
+ return self.renderer.render(self.game)
+ if self.render_mode == "db":
+ from catanatron.web.utils import ensure_link
+
+ link = ensure_link(self.game, get_replay_link=True)
+ if self._is_done():
+ print(f"Replay link: {link}")
+ return None
+ return None
+
+ def close(self):
+ """Clean up resources."""
+ if self.renderer is not None:
+ self.renderer.close()
+ self.renderer = None
+
+ def _is_done(self) -> bool:
+ return (
+ self.game.winning_color() is not None
+ or self.game.state.num_turns >= TURNS_LIMIT
+ or self.invalid_actions_count > self.max_invalid_actions
+ )
+
+
+# Default Reward and Feature Encoders
+def simple_reward(action, game, p0_color):
+ winning_color = game.winning_color()
+ if p0_color == winning_color:
+ return 1
+ elif winning_color is None:
+ return 0
+ else:
+ return -1
+
+
+def build_vector_obs_spec(
+ num_players: int, map_type: Literal["BASE", "MINI"], dtype: np.dtype = np.float32
+) -> ObservationSpec:
+ features = get_feature_ordering(num_players, map_type)
+
+ def vector_encode(game, p0_color):
+ sample = create_sample(game, p0_color)
+ return np.array([sample[i] for i in features], dtype=dtype)
+
+ return ObservationSpec(
+ encode=vector_encode,
+ space=spaces.Box(low=0, high=HIGH, shape=(len(features),), dtype=dtype),
+ )
+
+
+def build_mixed_obs_spec(
+ num_players: int, map_type: Literal["BASE", "MINI"], dtype: np.dtype = np.float32
+) -> ObservationSpec:
+ features = get_feature_ordering(num_players, map_type)
+ numeric_features = [f for f in features if not is_graph_feature(f)]
+
+ channels = get_channels(num_players)
+ board_tensor_space = spaces.Box(
+ low=0, high=1, shape=(channels, 21, 11), dtype=dtype
+ )
+ # TODO: This could be tigher (e.g. _ROADS_AVAILABLE <= 15)
+ numeric_space = spaces.Box(
+ low=0, high=HIGH, shape=(len(numeric_features),), dtype=dtype
+ )
+
+ def mixed_encode(game, p0_color):
+ sample = create_sample(game, p0_color)
+ board_tensor = create_board_tensor(game, p0_color, channels_first=True)
+ numeric = np.array([sample[i] for i in numeric_features], dtype=dtype)
+ return {"board": board_tensor, "numeric": numeric}
+
+ return ObservationSpec(
+ encode=mixed_encode,
+ space=spaces.Dict(
+ {
+ "board": board_tensor_space,
+ "numeric": numeric_space,
+ }
+ ),
+ )
+
+
+CatanatronEnv.__doc__ = """
+Configurable Catan Gym Environment.
-CatanatronEnv.__doc__ = f"""
-1v1 environment against a random player
+By default, it is a 1v1 environment against a random player in the BASE map.
Attributes:
reward_range: -1 if player lost, 1 if player won, 0 otherwise.
- action_space: Integers from the [0, 289] interval.
- See Action Space table below.
+ action_space: Integers from the [0, 327] interval (in 1v1 BASE).
+ It is smaller if the MINI map is used. See Action Space table below.
observation_space: Numeric Feature Vector. See Observation Space table
below for quantities. They appear in vector in alphabetical order,
from the perspective of "current" player (hiding/showing information
@@ -295,7 +306,7 @@ def _advance_until_p0_decision(self):
* - Integer
- Catanatron Action
"""
-for i, v in enumerate(ACTIONS_ARRAY):
+for i, v in enumerate(get_action_array((Color.BLUE, Color.RED), "BASE")):
CatanatronEnv.__doc__ += f" * - {i}\n - {v}\n"
CatanatronEnv.__doc__ += """
diff --git a/catanatron/catanatron/gym/envs/pygame_renderer.py b/catanatron/catanatron/gym/envs/pygame_renderer.py
new file mode 100644
index 000000000..b91faffd6
--- /dev/null
+++ b/catanatron/catanatron/gym/envs/pygame_renderer.py
@@ -0,0 +1,459 @@
+"""
+Pygame renderer for Catanatron environment.
+
+Renders the Catan board using hexagonal tiles in a minimalist style.
+"""
+
+import math
+from typing import Tuple
+import numpy as np
+import pygame
+
+from catanatron.models.enums import WOOD, BRICK, SHEEP, WHEAT, ORE, SETTLEMENT, CITY
+from catanatron.models.player import Color
+from catanatron.game import Game
+
+
+# Base constants (scale with render_scale)
+HEX_SIZE = 70 # pixels (balanced for both MINI and BASE maps)
+SCREEN_WIDTH = 1000
+SCREEN_HEIGHT = 800
+ROAD_WIDTH = 8 # Scaled proportionally
+SETTLEMENT_RADIUS = 12 # Scaled proportionally
+CITY_RADIUS = 18 # Scaled proportionally
+
+# Colors (RGB)
+COLORS = {
+ WOOD: (0, 86, 35), # forest green
+ BRICK: (125, 45, 0), # brick red
+ SHEEP: (46, 159, 53), # light green
+ WHEAT: (250, 131, 11), # gold
+ ORE: (93, 93, 93), # gray
+ None: (246, 195, 104), # tan (desert)
+}
+
+PLAYER_COLORS = {
+ Color.BLUE: (13, 78, 211),
+ Color.RED: (250, 0, 0),
+ Color.ORANGE: (255, 165, 0),
+ Color.WHITE: (250, 250, 250),
+}
+
+BACKGROUND_COLOR = (90, 180, 215) # light blue (water)
+OUTLINE_COLOR = (0, 0, 0) # black
+BORDER_COLOR = (0, 0, 0) # black
+ROBBER_COLOR = (10, 10, 10) # black
+TEXT_COLOR = (5, 5, 5) # black
+NUMBER_TOKEN_COLOR = (245, 222, 179) # beige/wheat color like physical game
+RED_NUMBER_COLOR = (200, 0, 0) # red for 6 and 8
+
+
+class PygameRenderer:
+ """
+ Renders the Catan board using pygame.
+
+ Returns numpy arrays (rgb_array) for compatibility with gymnasium's RecordVideo.
+ """
+
+ def __init__(self, render_scale: float = 1.0):
+ pygame.init()
+ pygame.font.init()
+
+ self.render_scale = max(1.0, float(render_scale))
+ self.base_width = SCREEN_WIDTH
+ self.base_height = SCREEN_HEIGHT
+ self.render_width = int(self.base_width * self.render_scale)
+ self.render_height = int(self.base_height * self.render_scale)
+
+ # Scaled sizing for higher-res rendering
+ self.hex_size = int(HEX_SIZE * self.render_scale)
+ self.road_width = max(1, int(ROAD_WIDTH * self.render_scale))
+ self.settlement_radius = int(SETTLEMENT_RADIUS * self.render_scale)
+ self.city_radius = int(CITY_RADIUS * self.render_scale)
+ self.outline_width = max(1, int(2 * self.render_scale))
+
+ # Create surface for rendering (headless)
+ self.surface = pygame.Surface((self.render_width, self.render_height))
+ self.font = pygame.font.Font(None, int(32 * self.render_scale))
+
+ # Center offset to position the map in the middle of the screen
+ self.center_x = self.render_width // 2
+ self.center_y = self.render_height // 2
+
+ def cube_to_pixel(self, coord: Tuple[int, int, int]) -> Tuple[float, float]:
+ """Convert cube coordinate to pixel position (flat-top hexagon).
+
+ Args:
+ coord: Cube coordinate (x, y, z) where x + y + z = 0
+
+ Returns:
+ Pixel position (px, py)
+ """
+ x, y, z = coord
+ # Convert cube to axial
+ q = x
+ r = z
+ # Axial to pixel (flat-top hexagon)
+ px = self.hex_size * (math.sqrt(3) * q + math.sqrt(3) / 2 * r)
+ py = self.hex_size * (3 / 2 * r)
+
+ # Apply center offset
+ px += self.center_x
+ py += self.center_y
+
+ return (px, py)
+
+ def hexagon_corners(self, center: Tuple[float, float], size: float) -> list:
+ """Get the 6 corner points of a pointy-top hexagon.
+
+ Args:
+ center: Center position (x, y)
+ size: Radius of the hexagon
+
+ Returns:
+ List of 6 corner points
+ """
+ cx, cy = center
+ corners = []
+ for i in range(6):
+ # Pointy-top: vertices at 30°, 90°, 150°, 210°, 270°, 330°
+ angle = math.pi / 6 + math.pi / 3 * i
+ x = cx + size * math.cos(angle)
+ y = cy + size * math.sin(angle)
+ corners.append((x, y))
+ return corners
+
+ def draw_hexagon(
+ self,
+ center: Tuple[float, float],
+ size: float,
+ fill_color: Tuple[int, int, int],
+ outline: bool = True,
+ ):
+ """Draw a hexagon on the surface.
+
+ Args:
+ center: Center position (x, y)
+ size: Radius of the hexagon
+ fill_color: RGB color for fill
+ outline: Whether to draw black outline
+ """
+ corners = self.hexagon_corners(center, size)
+ pygame.draw.polygon(self.surface, fill_color, corners)
+ if outline:
+ pygame.draw.polygon(
+ self.surface, OUTLINE_COLOR, corners, self.outline_width
+ )
+
+ def get_number_pips(self, number: int) -> int:
+ """Get the number of probability pips for a dice number.
+
+ Args:
+ number: Dice number (2-12)
+
+ Returns:
+ Number of pips to display
+ """
+ pips = {
+ 2: 1,
+ 12: 1,
+ 3: 2,
+ 11: 2,
+ 4: 3,
+ 10: 3,
+ 5: 4,
+ 9: 4,
+ 6: 5,
+ 8: 5,
+ }
+ return pips.get(number, 0)
+
+ def draw_number_token(self, center: Tuple[float, float], number: int):
+ """Draw a number token like in the physical Catan game.
+
+ Args:
+ center: Center position (x, y)
+ number: The dice number (2-12)
+ """
+ # Token dimensions (scaled proportionally with HEX_SIZE)
+ token_radius = int(30 * self.render_scale)
+
+ # Determine if this is a red number (6 or 8)
+ is_red = number in [6, 8]
+ number_color = RED_NUMBER_COLOR if is_red else TEXT_COLOR
+
+ # Draw token circle (beige background with black border)
+ pygame.draw.circle(
+ self.surface,
+ NUMBER_TOKEN_COLOR,
+ (int(center[0]), int(center[1])),
+ token_radius,
+ )
+ pygame.draw.circle(
+ self.surface,
+ OUTLINE_COLOR,
+ (int(center[0]), int(center[1])),
+ token_radius,
+ self.outline_width,
+ )
+
+ # Draw the number
+ number_font = pygame.font.Font(None, int(38 * self.render_scale))
+ text = number_font.render(str(number), True, number_color)
+ text_rect = text.get_rect(
+ center=(int(center[0]), int(center[1] - int(6 * self.render_scale)))
+ )
+ self.surface.blit(text, text_rect)
+
+ # Draw pips below the number
+ num_pips = self.get_number_pips(number)
+ if num_pips > 0:
+ pip_size = max(1, int(3 * self.render_scale))
+ pip_spacing = int(5 * self.render_scale)
+ total_width = num_pips * pip_size + (num_pips - 1) * pip_spacing
+ start_x = center[0] - total_width / 2 + pip_size / 2
+ pip_y = center[1] + int(10 * self.render_scale)
+
+ for i in range(num_pips):
+ pip_x = start_x + i * (pip_size + pip_spacing)
+ pygame.draw.circle(
+ self.surface, number_color, (int(pip_x), int(pip_y)), pip_size
+ )
+
+ def draw_tile(
+ self, coord: Tuple[int, int, int], tile, robber_coord: Tuple[int, int, int]
+ ):
+ """Draw a single land tile.
+
+ Args:
+ coord: Tile coordinate (x, y, z)
+ tile: LandTile object
+ robber_coord: Coordinate of the robber
+ """
+ center = self.cube_to_pixel(coord)
+
+ # Draw hexagon with resource color
+ resource_color = COLORS.get(tile.resource, COLORS[None])
+ self.draw_hexagon(center, self.hex_size, resource_color, outline=True)
+
+ # Draw number token if not desert
+ if tile.number is not None:
+ self.draw_number_token(center, tile.number)
+
+ # Draw robber if present
+ if coord == robber_coord:
+ self.draw_robber(center)
+
+ def draw_robber(self, center: Tuple[float, float]):
+ """Draw the robber as a black circle.
+
+ Args:
+ center: Center position (x, y)
+ """
+ robber_radius = int(18 * self.render_scale)
+ offset_x = int(self.hex_size * 0.50)
+ robber_center = (int(center[0] + offset_x), int(center[1]))
+ pygame.draw.circle(self.surface, ROBBER_COLOR, robber_center, robber_radius)
+ # White outline to make it visible on dark tiles
+ pygame.draw.circle(
+ self.surface,
+ BORDER_COLOR,
+ robber_center,
+ robber_radius,
+ self.outline_width,
+ )
+ r_font = pygame.font.Font(None, int(24 * self.render_scale))
+ r_text = r_font.render("R", True, (255, 255, 255))
+ r_rect = r_text.get_rect(center=robber_center)
+ self.surface.blit(r_text, r_rect)
+
+ def get_node_delta(self, direction: str, size: float) -> Tuple[float, float]:
+ """Get the offset from tile center to node based on direction.
+
+ Matches the frontend getNodeDelta function.
+
+ Args:
+ direction: NodeRef direction (NORTH, NORTHEAST, etc.)
+ size: Hexagon size
+
+ Returns:
+ (delta_x, delta_y) offset from tile center
+ """
+ w = math.sqrt(3) * size # SQRT3 * size
+ h = 2 * size
+
+ deltas = {
+ "NORTH": (0, -h / 2),
+ "NORTHEAST": (w / 2, -h / 4),
+ "SOUTHEAST": (w / 2, h / 4),
+ "SOUTH": (0, h / 2),
+ "SOUTHWEST": (-w / 2, h / 4),
+ "NORTHWEST": (-w / 2, -h / 4),
+ }
+ return deltas.get(direction, (0, 0))
+
+ def get_node_pixel_position(self, node_id: int, game: Game) -> Tuple[float, float]:
+ """Get pixel position for a node.
+
+ Args:
+ node_id: Node ID
+ game: Game object
+
+ Returns:
+ Pixel position (x, y)
+ """
+ board = game.state.board
+
+ # Find a tile that contains this node and get its direction
+ for coord, tile in board.map.land_tiles.items():
+ for node_ref, nid in tile.nodes.items():
+ if nid == node_id:
+ # Found the tile and direction
+ tile_center = self.cube_to_pixel(coord)
+ delta = self.get_node_delta(node_ref.value, self.hex_size)
+ return (tile_center[0] + delta[0], tile_center[1] + delta[1])
+
+ return (0, 0)
+
+ def draw_node(self, node_id: int, color: Color, building_type: str, game: Game):
+ """Draw a settlement or city at a node.
+
+ Args:
+ node_id: Node ID
+ color: Player color
+ building_type: "SETTLEMENT" or "CITY"
+ game: Game object
+ """
+ pos = self.get_node_pixel_position(node_id, game)
+ player_color = PLAYER_COLORS.get(color, (128, 128, 128))
+
+ if building_type == SETTLEMENT:
+ # Draw settlement as a small square
+ size = self.settlement_radius * 2
+ rect = pygame.Rect(
+ int(pos[0] - size / 2), int(pos[1] - size / 2), size, size
+ )
+ pygame.draw.rect(self.surface, player_color, rect)
+ pygame.draw.rect(self.surface, BORDER_COLOR, rect, self.outline_width)
+ elif building_type == CITY:
+ # Draw city as a larger square
+ size = self.city_radius * 2
+ rect = pygame.Rect(
+ int(pos[0] - size / 2), int(pos[1] - size / 2), size, size
+ )
+ pygame.draw.rect(self.surface, player_color, rect)
+ pygame.draw.rect(self.surface, BORDER_COLOR, rect, self.outline_width)
+
+ def draw_edge(self, edge_id: Tuple[int, int], color: Color, game: Game):
+ """Draw a road along an edge.
+
+ Args:
+ edge_id: Edge ID (node_id_1, node_id_2)
+ color: Player color
+ game: Game object
+ """
+ # Get positions of both nodes that define this edge
+ node1_pos = self.get_node_pixel_position(edge_id[0], game)
+ node2_pos = self.get_node_pixel_position(edge_id[1], game)
+
+ # Check if either position is invalid
+ if node1_pos == (0, 0) or node2_pos == (0, 0):
+ return
+
+ # Shorten road so borders do not overlap at node positions
+ dx = node2_pos[0] - node1_pos[0]
+ dy = node2_pos[1] - node1_pos[1]
+ length = math.hypot(dx, dy)
+ if length == 0:
+ return
+ pad = max(self.settlement_radius * 0.6, self.road_width * 0.9)
+ trim = min(pad, length * 0.4)
+ ux = dx / length
+ uy = dy / length
+ start = (node1_pos[0] + ux * trim, node1_pos[1] + uy * trim)
+ end = (node2_pos[0] - ux * trim, node2_pos[1] - uy * trim)
+
+ player_color = PLAYER_COLORS.get(color, (128, 128, 128))
+ border_thickness = max(2, int(2 * self.render_scale))
+ half_width = self.road_width / 2
+ half_border_width = (self.road_width + 2 * border_thickness) / 2
+ border_start = (
+ start[0] - ux * border_thickness,
+ start[1] - uy * border_thickness,
+ )
+ border_end = (end[0] + ux * border_thickness, end[1] + uy * border_thickness)
+
+ # Perpendicular unit vector for rectangle width
+ nx = -uy
+ ny = ux
+
+ def rect_corners(p0, p1, half_w):
+ return [
+ (p0[0] + nx * half_w, p0[1] + ny * half_w),
+ (p0[0] - nx * half_w, p0[1] - ny * half_w),
+ (p1[0] - nx * half_w, p1[1] - ny * half_w),
+ (p1[0] + nx * half_w, p1[1] + ny * half_w),
+ ]
+
+ border_rect = rect_corners(border_start, border_end, half_border_width)
+ road_rect = rect_corners(start, end, half_width)
+
+ pygame.draw.polygon(
+ self.surface, BORDER_COLOR, [(int(x), int(y)) for x, y in border_rect]
+ )
+ pygame.draw.polygon(
+ self.surface, player_color, [(int(x), int(y)) for x, y in road_rect]
+ )
+
+ def render(self, game: Game) -> np.ndarray:
+ """Render the game state and return as numpy array.
+
+ Args:
+ game: Game object
+
+ Returns:
+ RGB array (height, width, 3) for gymnasium RecordVideo
+ """
+ # Clear surface
+ self.surface.fill(BACKGROUND_COLOR)
+
+ # Get board state
+ board = game.state.board
+
+ # Draw all land tiles
+ for coord, tile in board.map.land_tiles.items():
+ self.draw_tile(coord, tile, board.robber_coordinate)
+
+ # Draw all roads
+ for edge_id, color in board.roads.items():
+ self.draw_edge(edge_id, color, game)
+
+ # Draw all buildings (settlements and cities)
+ for node_id, (color, building_type) in board.buildings.items():
+ self.draw_node(node_id, color, building_type, game)
+
+ # Watermark
+ watermark = self.font.render("CATANATRON", True, TEXT_COLOR)
+ margin = int(8 * self.render_scale)
+ wm_rect = watermark.get_rect(
+ bottomright=(self.render_width - margin, self.render_height - margin)
+ )
+ self.surface.blit(watermark, wm_rect)
+
+ # Convert surface to numpy array
+ # pygame.surfarray.array3d returns (width, height, 3), we need (height, width, 3)
+ if self.render_scale > 1.0:
+ output_surface = pygame.transform.smoothscale(
+ self.surface, (self.base_width, self.base_height)
+ )
+ else:
+ output_surface = self.surface
+
+ array = pygame.surfarray.array3d(output_surface)
+ array = np.transpose(array, (1, 0, 2)) # Transpose to (height, width, channels)
+
+ return array
+
+ def close(self):
+ """Clean up pygame resources."""
+ pygame.quit()
diff --git a/catanatron/catanatron/models/enums.py b/catanatron/catanatron/models/enums.py
index 49c0debb2..a129cd458 100644
--- a/catanatron/catanatron/models/enums.py
+++ b/catanatron/catanatron/models/enums.py
@@ -74,7 +74,9 @@ class ActionType(Enum):
ROLL = "ROLL" # value is None
MOVE_ROBBER = "MOVE_ROBBER" # value is (coordinate, Color|None).
- DISCARD = "DISCARD" # value is None|Resource[]. TODO: Should always be Resource[].
+
+ # TODO: None for now to avoid complexity, but should be Resource[].
+ DISCARD = "DISCARD" # value is None
# Building/Buying
BUILD_ROAD = "BUILD_ROAD" # value is edge_id
diff --git a/catanatron/catanatron/models/map.py b/catanatron/catanatron/models/map.py
index 2219bac8c..6d8c9a3fc 100644
--- a/catanatron/catanatron/models/map.py
+++ b/catanatron/catanatron/models/map.py
@@ -367,7 +367,7 @@ def initialize_tiles(
port_autoinc += 1
elif tile_type == LandTile:
resource = shuffled_tile_resources.pop()
- if resource != None:
+ if resource is not None:
number = shuffled_numbers.pop()
tile = LandTile(tile_autoinc, resource, number, nodes, edges)
else:
diff --git a/catanatron_experimental/catanatron_experimental/machine_learning/players/reinforcement.py b/catanatron_experimental/catanatron_experimental/machine_learning/players/reinforcement.py
index 705d52ac2..5b8d9fde4 100644
--- a/catanatron_experimental/catanatron_experimental/machine_learning/players/reinforcement.py
+++ b/catanatron_experimental/catanatron_experimental/machine_learning/players/reinforcement.py
@@ -4,19 +4,24 @@
import tensorflow as tf
from tensorflow import keras
-from catanatron.models.player import Player
-from catanatron.models.enums import Action, ActionType
+from catanatron import Color, Player
from catanatron.features import (
create_sample,
create_sample_vector,
get_feature_ordering,
)
-from catanatron.gym.envs.catanatron_env import ACTIONS_ARRAY, ACTION_SPACE_SIZE
+from catanatron.gym.envs.action_space import (
+ get_action_array,
+)
from catanatron.gym.board_tensor_features import (
NUMERIC_FEATURES,
create_board_tensor,
)
+
+ACTIONS_ARRAY = get_action_array((Color.BLUE, Color.RED), "BASE")
+ACTION_SPACE_SIZE = len(ACTIONS_ARRAY)
+
# from catanatron_experimental.rep_b_model import build_model
# Taken from correlation analysis
@@ -94,6 +99,8 @@ def p_model_path(version):
def get_v_model(model_path):
global V_MODEL
if V_MODEL is None:
+ import autokeras as ak
+
custom_objects = None if model_path[:2] != "ak" else ak.CUSTOM_OBJECTS
V_MODEL = keras.models.load_model(model_path, custom_objects=custom_objects)
return V_MODEL
@@ -108,29 +115,12 @@ def get_t_model(model_path):
def hot_one_encode_action(action):
- normalized = normalize_action(action)
- index = ACTIONS_ARRAY.index((normalized.action_type, normalized.value))
+ index = ACTIONS_ARRAY.index((action.action_type, action.value))
vector = np.zeros(ACTION_SPACE_SIZE, dtype=int)
vector[index] = 1
return vector
-def normalize_action(action):
- normalized = action
- if normalized.action_type == ActionType.ROLL:
- return Action(action.color, action.action_type, None)
- elif normalized.action_type == ActionType.MOVE_ROBBER:
- return Action(action.color, action.action_type, action.value[0])
- elif normalized.action_type == ActionType.BUILD_ROAD:
- return Action(action.color, action.action_type, tuple(sorted(action.value)))
- elif normalized.action_type == ActionType.BUY_DEVELOPMENT_CARD:
- return Action(action.color, action.action_type, None)
- elif normalized.action_type == ActionType.DISCARD:
- return Action(action.color, action.action_type, None)
-
- return normalized
-
-
class PRLPlayer(Player):
def __init__(self, color, model_path):
super(PRLPlayer, self).__init__(color)
@@ -147,8 +137,7 @@ def decide(self, game, playable_actions):
# return playable_actions[index]
# Create array like [0,0,1,0,0,0,1,...] representing possible actions
- normalized_playable = [normalize_action(a) for a in playable_actions]
- possibilities = [(a.action_type, a.value) for a in normalized_playable]
+ possibilities = [(a.action_type, a.value) for a in playable_actions]
possible_indices = [ACTIONS_ARRAY.index(x) for x in possibilities]
mask = np.zeros(ACTION_SPACE_SIZE, dtype=np.int)
mask[possible_indices] = 1
diff --git a/examples/ppo/.gitignore b/examples/ppo/.gitignore
new file mode 100644
index 000000000..2fdc2308f
--- /dev/null
+++ b/examples/ppo/.gitignore
@@ -0,0 +1,6 @@
+tensorboard_logs
+checkpoints
+wandb
+videos
+models
+runs
diff --git a/examples/ppo/evaluate.py b/examples/ppo/evaluate.py
new file mode 100644
index 000000000..99f0f8a19
--- /dev/null
+++ b/examples/ppo/evaluate.py
@@ -0,0 +1,261 @@
+"""
+Evaluate a trained PPO agent against ValueFunctionPlayer using the simulator.
+
+This script uses the play_batch method to run games in the simulator
+(not the gym environment), which is much faster for evaluation.
+
+Usage:
+ # Evaluate final model:
+ python evaluate.py --model-path checkpoints/final_model.zip
+
+ # Evaluate with custom number of games:
+ python evaluate.py --model-path checkpoints/final_model.zip --num-games 1000
+
+ # Evaluate from a specific checkpoint:
+ python evaluate.py --model-path checkpoints/rl_model_50000_steps.zip
+"""
+
+import argparse
+import os
+
+import numpy as np
+from sb3_contrib.ppo_mask import MaskablePPO
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
+
+from catanatron.features import create_sample, get_feature_ordering
+from catanatron.models.player import Player, Color
+from catanatron.players.value import ValueFunctionPlayer
+from catanatron.cli.play import play_batch, GameConfigOptions
+from catanatron.gym.envs.action_space import to_action_space, from_action_space
+
+import catanatron.gym
+from ppo_utils import autodetect_vecnormalize_path, make_catan_env
+
+
+class PPOPlayer(Player):
+ """Player that uses a trained PPO model to make decisions."""
+
+ def __init__(
+ self,
+ color,
+ model,
+ env,
+ map_type="MINI",
+ debug=False,
+ ):
+ """
+ Initialize PPO player.
+
+ Args:
+ color: Player color
+ model: Trained MaskablePPO model
+ env: Environment (DummyVecEnv or VecNormalize-wrapped)
+ map_type: Map type for action space conversion
+ player_colors: Tuple of player colors for action space conversion
+ debug: Whether to print debug information
+ """
+ super().__init__(color, is_bot=True)
+ self.model = model
+ self.env = env
+ self.map_type = map_type
+ self.debug = debug
+ self.decision_count = 0
+
+ def decide(self, game, playable_actions):
+ """
+ Use PPO model to choose an action.
+
+ Args:
+ game: Current game state
+ playable_actions: List of valid actions
+
+ Returns:
+ Selected action from playable_actions
+ """
+ # Get observation from the environment
+ features = get_feature_ordering(len(game.state.colors), self.map_type)
+ sample = create_sample(game, self.color)
+ obs = np.array([sample[i] for i in features], dtype=np.float32)
+
+ # Normalize observation if using VecNormalize
+ # VecNormalize expects batched observations, so we need to add batch dimension
+ if isinstance(self.env, VecNormalize):
+ obs = np.expand_dims(obs, axis=0) # Add batch dimension
+ obs = self.env.normalize_obs(obs)
+ obs = obs[0] # Remove batch dimension
+
+ # Create action mask
+ valid_action_indices = sorted(
+ [
+ to_action_space(a, game.state.colors, self.map_type)
+ for a in playable_actions
+ ]
+ )
+ action_mask = np.zeros(self.env.action_space.n, dtype=bool)
+ action_mask[valid_action_indices] = True
+
+ # Predict action using the model
+ # Note: predict expects a single observation (not batched for single prediction)
+ action_idx, _ = self.model.predict(
+ obs, action_masks=np.array([action_mask]), deterministic=True
+ )
+
+ # Convert action index back to catan action
+ if isinstance(action_idx, np.ndarray):
+ action_idx_int = (
+ int(action_idx.item()) if action_idx.ndim == 0 else int(action_idx[0])
+ )
+ else:
+ action_idx_int = int(action_idx)
+ selected_action = from_action_space(
+ action_idx_int, self.color, game.state.colors, self.map_type
+ )
+
+ # Find and return the matching action from playable_actions
+ for action in playable_actions:
+ if (
+ action.action_type == selected_action.action_type
+ and action.value == selected_action.value
+ ):
+ if self.debug and self.decision_count < 10:
+ print(
+ f"Decision {self.decision_count}: Selected {action.action_type} (index {action_idx_int})"
+ )
+ self.decision_count += 1
+ return action
+
+ # Fallback: if model predicted invalid action, return first playable action
+ print(
+ f"Warning: Model predicted invalid action {selected_action.action_type}, using fallback"
+ )
+ if self.debug:
+ print(
+ f" Valid actions were: {[a.action_type for a in playable_actions[:5]]}"
+ )
+ self.decision_count += 1
+ return playable_actions[0]
+
+
+def evaluate_model(
+ model_path,
+ vecnorm_path=None,
+ num_games=100,
+ map_type="MINI",
+ vps_to_win=6,
+):
+ """
+ Evaluate a trained PPO model against ValueFunctionPlayer.
+
+ Args:
+ model_path: Path to trained model (.zip)
+ vecnorm_path: Path to VecNormalize stats (.pkl)
+ num_games: Number of games to play
+ map_type: Map type
+ vps_to_win: Victory points to win
+
+ Returns:
+ Dictionary with evaluation results
+ """
+ print(f"\nEvaluating PPO model: {model_path}")
+ print(f"Playing {num_games} games against ValueFunctionPlayer...")
+
+ # Load model
+ config = {
+ "map_type": map_type,
+ "vps_to_win": vps_to_win,
+ "use_shaped_reward": True,
+ }
+ temp_env = DummyVecEnv([lambda: make_catan_env(config)])
+
+ if vecnorm_path and os.path.exists(vecnorm_path):
+ print(f"Loading VecNormalize stats from: {vecnorm_path}")
+ temp_env = VecNormalize.load(vecnorm_path, temp_env)
+ temp_env.training = False
+ temp_env.norm_reward = False
+
+ print(f"Loading model from: {model_path}")
+ model = MaskablePPO.load(model_path, env=temp_env)
+
+ # Create players
+ ppo_player = PPOPlayer(
+ Color.BLUE,
+ model,
+ temp_env,
+ map_type=map_type,
+ debug=True, # Enable debug output
+ )
+ value_player = ValueFunctionPlayer(Color.RED)
+ players = [ppo_player, value_player]
+
+ game_config = GameConfigOptions(
+ map_type=map_type,
+ vps_to_win=vps_to_win,
+ )
+
+ play_batch(
+ num_games=num_games,
+ players=players,
+ game_config=game_config,
+ quiet=False,
+ )
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Evaluate a trained PPO agent against ValueFunctionPlayer"
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ required=True,
+ help="Path to trained model file (.zip)",
+ )
+ parser.add_argument(
+ "--vecnorm-path",
+ type=str,
+ help="Path to VecNormalize stats (.pkl). If not provided, will auto-detect.",
+ )
+ parser.add_argument(
+ "--num-games",
+ type=int,
+ default=100,
+ help="Number of games to play (default: 100)",
+ )
+ parser.add_argument(
+ "--map-type",
+ type=str,
+ default="MINI",
+ help="Map type (default: MINI)",
+ )
+ parser.add_argument(
+ "--vps-to-win",
+ type=int,
+ default=6,
+ help="Victory points to win (default: 6)",
+ )
+
+ args = parser.parse_args()
+
+ # Validate model path exists
+ if not os.path.exists(args.model_path):
+ parser.error(f"Model file not found: {args.model_path}")
+
+ # Auto-detect VecNormalize stats if not provided
+ vecnorm_path, auto_detected = autodetect_vecnormalize_path(
+ args.model_path, args.vecnorm_path
+ )
+ if auto_detected:
+ print(f"Auto-detected VecNormalize stats: {vecnorm_path}")
+
+ # Run evaluation
+ evaluate_model(
+ model_path=args.model_path,
+ vecnorm_path=vecnorm_path,
+ num_games=args.num_games,
+ map_type=args.map_type,
+ vps_to_win=args.vps_to_win,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/ppo/ppo_utils.py b/examples/ppo/ppo_utils.py
new file mode 100644
index 000000000..f89006634
--- /dev/null
+++ b/examples/ppo/ppo_utils.py
@@ -0,0 +1,66 @@
+"""
+Shared utilities for creating Catanatron environments.
+"""
+
+from pathlib import Path
+
+import gymnasium
+from sb3_contrib.common.wrappers import ActionMasker
+
+import catanatron.gym
+from catanatron import Color
+from catanatron.players.value import ValueFunctionPlayer
+from catanatron.gym.envs.catanatron_env import simple_reward
+from shaped_reward import ShapedRewardFunction
+
+
+def autodetect_vecnormalize_path(model_path, vecnorm_path=None):
+ """
+ Resolve VecNormalize stats path from args or a matching file next to the model.
+
+ Returns:
+ Tuple of (vecnorm_path or None, auto_detected_bool).
+ """
+ if vecnorm_path:
+ return vecnorm_path, False
+
+ model_path = Path(model_path)
+ potential_vecnorm = model_path.parent / f"{model_path.stem}_vecnormalize.pkl"
+ if potential_vecnorm.exists():
+ return str(potential_vecnorm), True
+
+ return None, False
+
+
+def make_catan_env(config):
+ """
+ Factory function to create a Catan environment for vectorization.
+
+ Args:
+ config: Dictionary with environment configuration:
+ - map_type: Map type for Catan (BASE, MINI, etc.)
+ - vps_to_win: Victory points needed to win
+ - use_shaped_reward: Whether to use shaped reward function
+ - render_mode: Render mode (optional, defaults to "rgb_array")
+
+ Returns:
+ Wrapped Catan environment
+ """
+ reward_fn = (
+ ShapedRewardFunction()
+ if config.get("use_shaped_reward", True)
+ else simple_reward
+ )
+
+ env = gymnasium.make(
+ "catanatron/Catanatron-v0",
+ config={
+ "map_type": config.get("map_type", "MINI"),
+ "vps_to_win": config.get("vps_to_win", 6),
+ "enemies": [ValueFunctionPlayer(Color.RED)],
+ "reward_function": reward_fn,
+ "render_mode": config.get("render_mode", "rgb_array"),
+ },
+ )
+ env = ActionMasker(env, lambda env: env.unwrapped.action_masks())
+ return env
diff --git a/examples/ppo/record_video.py b/examples/ppo/record_video.py
new file mode 100644
index 000000000..a636589e4
--- /dev/null
+++ b/examples/ppo/record_video.py
@@ -0,0 +1,154 @@
+"""
+Record videos of a trained PPO agent playing Catanatron.
+
+Usage:
+ # Record from a local model:
+ python record_video.py --model-path checkpoints/final_model.zip
+
+ # Auto-detect VecNormalize stats (looks for matching .pkl file):
+ python record_video.py --model-path checkpoints/rl_model_50000_steps.zip
+
+ # Custom number of episodes:
+ python record_video.py --model-path checkpoints/final_model.zip --num-episodes 10
+
+ # Custom output directory:
+ python record_video.py --model-path checkpoints/final_model.zip --output-dir my_videos/
+"""
+
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+from sb3_contrib.ppo_mask import MaskablePPO
+from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecVideoRecorder
+
+import catanatron.gym
+from ppo_utils import autodetect_vecnormalize_path, make_catan_env
+
+
+def record_videos(
+ model_path,
+ vecnorm_path=None,
+ config=None,
+ output_dir="videos",
+ num_episodes=3,
+):
+ """Record videos of the agent playing and save to local filesystem."""
+ # Default config if not provided
+ if config is None:
+ config = {
+ "map_type": "MINI",
+ "vps_to_win": 6,
+ "use_shaped_reward": True,
+ }
+
+ print(f"\nRecording {num_episodes} episodes...")
+ print(f"Model: {model_path}")
+ if vecnorm_path:
+ print(f"VecNormalize stats: {vecnorm_path}")
+
+ # Create environment
+ env = DummyVecEnv([lambda: make_catan_env(config)])
+
+ # Load VecNormalize stats if available
+ if vecnorm_path and os.path.exists(vecnorm_path):
+ env = VecNormalize.load(vecnorm_path, env)
+ env.training = False # Don't update normalization stats
+ env.norm_reward = False # Don't normalize rewards during evaluation
+
+ # Wrap with video recorder
+ os.makedirs(output_dir, exist_ok=True)
+ env = VecVideoRecorder(
+ env,
+ output_dir,
+ record_video_trigger=lambda x: x % 1 == 0, # Record every episode
+ video_length=500, # Max steps per video
+ name_prefix="catanatron",
+ )
+
+ # Load model
+ print("\nLoading model...")
+ model = MaskablePPO.load(model_path, env=env)
+
+ # Run episodes
+ obs = env.reset()
+ episode_count = 0
+ total_reward = 0
+ episode_rewards = []
+
+ while episode_count < num_episodes:
+ action_masks = np.array([env.envs[0].unwrapped.action_masks()])
+ action, _ = model.predict(obs, action_masks=action_masks, deterministic=True)
+ obs, reward, done, info = env.step(action)
+ total_reward += reward[0]
+
+ if done[0]:
+ episode_count += 1
+ episode_rewards.append(total_reward)
+ print(f" Episode {episode_count}: Reward = {total_reward:.2f}")
+ total_reward = 0
+
+ env.close()
+
+ # Find generated videos
+ video_dir = Path(output_dir)
+ video_files = sorted(video_dir.glob("*.mp4"))
+
+ print(f"\nRecorded {len(video_files)} videos to {output_dir}/")
+ print(
+ f"Average reward: {np.mean(episode_rewards):.2f} ± {np.std(episode_rewards):.2f}"
+ )
+
+ return video_files
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Record videos of a trained PPO agent playing Catanatron"
+ )
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ required=True,
+ help="Path to local model file (.zip)",
+ )
+ parser.add_argument(
+ "--vecnorm-path",
+ type=str,
+ help="Path to VecNormalize stats (.pkl). If not provided, will auto-detect.",
+ )
+ parser.add_argument(
+ "--output-dir", type=str, default="videos", help="Directory to save videos"
+ )
+ parser.add_argument(
+ "--num-episodes", type=int, default=3, help="Number of episodes to record"
+ )
+
+ args = parser.parse_args()
+
+ # Validate model path exists
+ if not os.path.exists(args.model_path):
+ parser.error(f"Model file not found: {args.model_path}")
+
+ # Auto-detect VecNormalize stats if not provided
+ vecnorm_path, auto_detected = autodetect_vecnormalize_path(
+ args.model_path, args.vecnorm_path
+ )
+ if auto_detected:
+ print(f"Auto-detected VecNormalize stats: {vecnorm_path}")
+ elif not vecnorm_path:
+ print("No VecNormalize stats found (will train without normalization)")
+
+ # Record videos
+ record_videos(
+ model_path=args.model_path,
+ vecnorm_path=vecnorm_path,
+ config=None, # Use defaults
+ output_dir=args.output_dir,
+ num_episodes=args.num_episodes,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/ppo/shaped_reward.py b/examples/ppo/shaped_reward.py
new file mode 100644
index 000000000..2aea336d6
--- /dev/null
+++ b/examples/ppo/shaped_reward.py
@@ -0,0 +1,94 @@
+"""
+Shaped reward function for Catan that provides incremental rewards.
+
+Instead of only rewarding at the end (+1 win, -1 loss), this gives
+partial credit for progress during the game.
+"""
+
+from catanatron.state_functions import (
+ get_actual_victory_points,
+ get_longest_road_color,
+ get_largest_army,
+)
+
+
+class ShapedRewardFunction:
+ """
+ Reward function that gives incremental rewards for game progress.
+
+ Rewards:
+ - Victory point gain: +1.0 per VP
+ - Winning: +10.0 bonus
+ - Losing: -10.0 penalty
+ - Longest road acquired: +0.5
+ - Largest army acquired: +0.5
+ """
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ """Reset tracked state for a new game."""
+ self.prev_vp = 0
+ self.prev_has_longest_road = False
+ self.prev_has_largest_army = False
+
+ def __call__(self, action, game, p0_color):
+ """
+ Compute reward for the current step.
+
+ Args:
+ action: The action taken
+ game: The game object
+ p0_color: Player 0's color (BLUE)
+
+ Returns:
+ float: The reward for this step
+ """
+ state = game.state
+ reward = 0.0
+
+ # Get current metrics
+ current_vp = get_actual_victory_points(state, p0_color)
+ longest_road_color = get_longest_road_color(state)
+ largest_army_color, _ = get_largest_army(state)
+
+ has_longest_road = longest_road_color == p0_color
+ has_largest_army = largest_army_color == p0_color
+
+ # Reward for VP gain
+ vp_gain = current_vp - self.prev_vp
+ reward += vp_gain * 1.0
+
+ # Reward for acquiring longest road
+ if has_longest_road and not self.prev_has_longest_road:
+ reward += 0.5
+ elif not has_longest_road and self.prev_has_longest_road:
+ reward -= 0.5 # Lost longest road
+
+ # Reward for acquiring largest army
+ if has_largest_army and not self.prev_has_largest_army:
+ reward += 0.5
+ elif not has_largest_army and self.prev_has_largest_army:
+ reward -= 0.5 # Lost largest army
+
+ # Check for game end
+ winning_color = game.winning_color()
+ if winning_color is not None:
+ if p0_color == winning_color:
+ reward += 10.0 # Big bonus for winning
+ else:
+ reward -= 10.0 # Penalty for losing
+ # Reset for next game
+ self.reset()
+ else:
+ # Update tracked state for next step
+ self.prev_vp = current_vp
+ self.prev_has_longest_road = has_longest_road
+ self.prev_has_largest_army = has_largest_army
+
+ return reward
+
+
+# Create a singleton instance to use
+shaped_reward = ShapedRewardFunction()
diff --git a/examples/ppo/train.py b/examples/ppo/train.py
new file mode 100644
index 000000000..12c5e6c31
--- /dev/null
+++ b/examples/ppo/train.py
@@ -0,0 +1,315 @@
+"""
+Restartable Stable Baselines3 training example with Weights & Biases logging.
+
+Features:
+- Vectorized environments (parallel games for speedup)
+- VecNormalize (observation and reward normalization)
+- Shaped reward function (incremental rewards for progress)
+- Linear learning rate schedule (decreases over time)
+- GPU support (automatic detection)
+- Checkpoint saving/loading for resumable training
+- Weights & Biases integration for monitoring
+
+Usage:
+ # Start new training:
+ python train.py
+
+ # Resume from checkpoint:
+ python train.py --resume
+
+ # Run a wandb sweep:
+ python train.py --sweep
+
+ # Train for custom timesteps:
+ python train.py --timesteps 500000
+"""
+
+import random
+import argparse
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import wandb
+from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
+from sb3_contrib.ppo_mask import MaskablePPO
+from stable_baselines3.common.callbacks import CheckpointCallback
+from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize
+from wandb.integration.sb3 import WandbCallback
+
+import catanatron.gym # noqa: F401
+from ppo_utils import make_catan_env
+
+
+# Configuration object (compatible with wandb)
+DEFAULT_CONFIG = {
+ # Environment parameters
+ "map_type": "MINI", # Map type for Catan (BASE, MINI, etc.)
+ "vps_to_win": 6, # Victory points needed to win
+ "use_shaped_reward": True, # Use shaped vs simple reward function
+ # PPO hyperparameters
+ "n_envs": 8, # Number of parallel environments
+ "n_steps": 1024, # Number of steps to collect before update
+ "batch_size": 128, # Batch size for training
+ "n_epochs": 10, # Number of epochs for PPO update
+ "gamma": 0.99, # Discount factor for future rewards
+ "initial_lr": 0.01, # Initial learning rate
+ "lr_decay_orders": 1, # Orders of magnitude to decay to (final = initial / 10^orders)
+ "ent_coef": 0.01, # Entropy coefficient for exploration
+ # Network architecture
+ "num_layers": 3, # Number of hidden layers
+ "neurons_per_layer": 256, # Neurons in each layer
+ # Training parameters
+ "seed": 42, # Random seed
+ "checkpoint_freq": 10_000, # Save checkpoint every N steps
+}
+
+# Wandb sweep configuration
+SWEEP_CONFIG = {
+ "method": "bayes",
+ "metric": {"name": "rollout/ep_rew_mean", "goal": "maximize"},
+ "parameters": {
+ "batch_size": {"values": [128, 256, 512]},
+ "n_steps": {"values": [128, 256, 512, 1024, 2048]},
+ "n_epochs": {"values": [5, 10]},
+ "gamma": {"values": [0.95, 0.99, 0.999]},
+ "initial_lr": {"values": [0.01, 0.001, 0.0001, 0.00001]},
+ "lr_decay_orders": {"values": [3, 4, 5]},
+ "ent_coef": {"values": [0.0, 0.005, 0.01, 0.02]},
+ "num_layers": {"values": [1, 3, 5, 10, 20]},
+ "neurons_per_layer": {"values": [64, 128, 256, 512, 1024]},
+ },
+}
+
+# Directories
+CHECKPOINT_DIR = os.path.join(os.path.dirname(__file__), "checkpoints")
+
+
+def train_model(run, args, cfg):
+ """Execute the training loop."""
+ # Validate configuration
+ assert (cfg["n_envs"] * cfg["n_steps"]) % cfg["batch_size"] == 0, (
+ "BATCH_SIZE must divide N_ENVS * N_STEPS"
+ )
+
+ # Set random seeds for reproducibility
+ random.seed(cfg["seed"])
+ np.random.seed(cfg["seed"])
+ torch.manual_seed(cfg["seed"])
+ torch.cuda.manual_seed_all(cfg["seed"])
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ # Set device (GPU if available)
+ device = (
+ "cuda"
+ if torch.cuda.is_available()
+ else ("mps" if torch.backends.mps.is_available() else "cpu")
+ )
+ print(f"Using device: {device}")
+
+ # Create vectorized environments
+ print(
+ f"Using reward function: {'shaped (incremental)' if cfg['use_shaped_reward'] else 'simple (sparse)'}"
+ )
+ print(f"Using map type: {cfg['map_type']}, VPs to win: {cfg['vps_to_win']}")
+ print(f"Creating {cfg['n_envs']} parallel environments...")
+ env = make_vec_env(
+ lambda: make_catan_env(cfg),
+ n_envs=cfg["n_envs"],
+ seed=cfg["seed"],
+ vec_env_cls=SubprocVecEnv, # Use subprocesses for CPU-heavy environments
+ )
+
+ # Wrap with VecNormalize for observation and reward normalization
+ print("Wrapping environments with VecNormalize (obs + reward normalization)")
+ env = VecNormalize(
+ env,
+ norm_obs=True, # Normalize observations
+ norm_reward=False, # Normalize rewards
+ clip_obs=10.0, # Clip observations to [-10, 10] after normalization
+ clip_reward=10.0, # Clip rewards to [-10, 10] after normalization
+ )
+
+ # Load or create model
+ # Create directories
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
+ if args.resume:
+ checkpoint = get_latest_checkpoint()
+ if checkpoint:
+ print(f"Loading checkpoint: {checkpoint}")
+
+ # Load VecNormalize stats if available
+ vec_normalize_path = checkpoint.replace(".zip", "_vecnormalize.pkl")
+ if os.path.exists(vec_normalize_path):
+ print(f"Loading VecNormalize stats from: {vec_normalize_path}")
+ env = VecNormalize.load(vec_normalize_path, env)
+
+ model = MaskablePPO.load(
+ checkpoint,
+ env=env,
+ device=device,
+ tensorboard_log=f"runs/{run.id}",
+ )
+ else:
+ print("No checkpoint found, starting fresh")
+ args.resume = False
+
+ if not args.resume:
+ # Configure network architecture
+ net_arch = [cfg["neurons_per_layer"]] * cfg["num_layers"]
+ policy_kwargs = dict(net_arch=net_arch)
+
+ # Create learning rate schedule
+ final_lr = compute_final_lr(cfg["initial_lr"], cfg["lr_decay_orders"])
+ lr_schedule = linear_schedule(cfg["initial_lr"], final_lr)
+
+ print(f"Creating new model with architecture: {net_arch}")
+ print(
+ f"PPO config: n_steps={cfg['n_steps']}, batch_size={cfg['batch_size']}, n_epochs={cfg['n_epochs']}, gamma={cfg['gamma']}, ent_coef={cfg['ent_coef']}"
+ )
+ print(f"Learning rate schedule: {cfg['initial_lr']:.2e} → {final_lr:.2e}")
+ model = MaskablePPO(
+ MaskableActorCriticPolicy,
+ env,
+ learning_rate=lr_schedule,
+ n_steps=cfg["n_steps"],
+ batch_size=cfg["batch_size"],
+ gamma=cfg["gamma"],
+ n_epochs=cfg["n_epochs"],
+ ent_coef=cfg["ent_coef"],
+ verbose=1,
+ tensorboard_log=f"runs/{run.id}",
+ seed=cfg["seed"],
+ policy_kwargs=policy_kwargs,
+ device=device,
+ )
+
+ # Setup checkpoint callback (also saves VecNormalize stats)
+ checkpoint_callback = VecNormalizeCheckpointCallback(
+ save_freq=cfg["checkpoint_freq"],
+ save_path=CHECKPOINT_DIR,
+ name_prefix="rl_model",
+ )
+
+ # Setup wandb callback
+ wandb_callback = WandbCallback(
+ gradient_save_freq=100,
+ model_save_path=f"models/{run.id}",
+ verbose=2,
+ )
+
+ # Train
+ print(f"\nTraining for {args.timesteps:,} timesteps")
+ print(f"With {cfg['n_envs']} parallel environments (~{cfg['n_envs']}x speedup)")
+ print(f"Wandb run: {run.url}\n")
+
+ model.learn(
+ total_timesteps=args.timesteps,
+ callback=[checkpoint_callback, wandb_callback],
+ reset_num_timesteps=not args.resume,
+ )
+
+ # Save final model and VecNormalize stats
+ final_path = os.path.join(CHECKPOINT_DIR, "final_model.zip")
+ model.save(final_path)
+
+ vec_normalize_path = os.path.join(CHECKPOINT_DIR, "final_model_vecnormalize.pkl")
+ env.save(vec_normalize_path)
+
+ print(f"\nDone! Final model: {final_path}")
+ print(f"VecNormalize stats: {vec_normalize_path}")
+
+ # Clean up
+ env.close()
+
+
+def main():
+ """Main entry point."""
+ # Parse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
+ parser.add_argument(
+ "--sweep",
+ action="store_true",
+ help="Run a wandb sweep over the sweep configuration",
+ )
+ parser.add_argument(
+ "--timesteps",
+ type=int,
+ default=100_000,
+ help=f"Number of timesteps to train for (default: {100_000:,})",
+ )
+ args = parser.parse_args()
+
+ # Login to wandb
+ wandb.login()
+
+ def run_training():
+ with wandb.init(
+ project="catan-ppo",
+ config=DEFAULT_CONFIG,
+ sync_tensorboard=True, # auto-upload sb3's tensorboard metrics
+ save_code=True, # save the code
+ ) as run:
+ train_model(run, args, dict(wandb.config))
+
+ if args.sweep:
+ sweep_id = wandb.sweep(SWEEP_CONFIG, project="catan-ppo")
+ wandb.agent(sweep_id, function=run_training)
+ else:
+ run_training()
+
+
+def get_latest_checkpoint():
+ """Find the most recent checkpoint."""
+ checkpoint_path = Path(CHECKPOINT_DIR)
+ if not checkpoint_path.exists():
+ return None
+
+ checkpoints = list(checkpoint_path.glob("rl_model_*_steps.zip"))
+ if not checkpoints:
+ return None
+
+ # Get latest by timestep number
+ latest = max(checkpoints, key=lambda p: int(p.stem.split("_")[2]))
+ return str(latest)
+
+
+def linear_schedule(initial_value, final_value):
+ def schedule(progress_remaining):
+ # progress_remaining goes from 1.0 (start) to 0.0 (end)
+ return final_value + progress_remaining * (initial_value - final_value)
+
+ return schedule
+
+
+def compute_final_lr(initial_lr, lr_decay_orders):
+ return initial_lr / (10**lr_decay_orders)
+
+
+class VecNormalizeCheckpointCallback(CheckpointCallback):
+ """Custom checkpoint callback that also saves VecNormalize statistics."""
+
+ def _on_step(self) -> bool:
+ # Save model checkpoint (parent class behavior)
+ result = super()._on_step()
+
+ # Also save VecNormalize stats if the model was just saved
+ if result and isinstance(self.model.get_env(), VecNormalize):
+ # Get the path of the most recently saved model
+ checkpoint_path = os.path.join(
+ self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps.zip"
+ )
+ vec_normalize_path = checkpoint_path.replace(".zip", "_vecnormalize.pkl")
+
+ # Save VecNormalize statistics
+ self.model.get_env().save(vec_normalize_path)
+
+ return result
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/render_example.py b/examples/render_example.py
new file mode 100644
index 000000000..ad4ece848
--- /dev/null
+++ b/examples/render_example.py
@@ -0,0 +1,120 @@
+"""
+Example script that plays a game and renders it with video recording.
+
+Uses gymnasium's RecordVideo wrapper to automatically record gameplay videos.
+
+Usage:
+ # Record a game to video
+ python render_example.py --render-style video
+
+ # Record to custom folder
+ python render_example.py --render-style video --video-folder ./my_videos
+
+ # Save a replay to DB (prints link on completion)
+ python render_example.py --render-style db
+"""
+
+import random
+import argparse
+import gymnasium
+from gymnasium.wrappers import RecordVideo
+
+import catanatron.gym # Register the environment
+from catanatron import Color
+from catanatron.players.value import ValueFunctionPlayer
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Play and record Catanatron games")
+ parser.add_argument(
+ "--render-style",
+ type=str,
+ default="video",
+ choices=["video", "db"],
+ help="Render style to use: video (rgb_array) or db (replay link).",
+ )
+ parser.add_argument(
+ "--video-folder",
+ type=str,
+ default="./videos",
+ help="Folder to save videos (default: ./videos)",
+ )
+ parser.add_argument(
+ "--map-type",
+ type=str,
+ default="BASE",
+ choices=["BASE", "MINI"],
+ help="Map type to use (default: MINI)",
+ )
+ parser.add_argument(
+ "--vps-to-win",
+ type=int,
+ default=10,
+ help="Victory points needed to win (default: 6)",
+ )
+ args = parser.parse_args()
+
+ render_mode = "rgb_array" if args.render_style == "video" else "db"
+
+ # Create env with rendering enabled
+ env = gymnasium.make(
+ "catanatron/Catanatron-v0",
+ config={
+ "render_mode": render_mode,
+ "render_scale": 2.0,
+ "map_type": args.map_type,
+ "vps_to_win": args.vps_to_win,
+ "enemies": [
+ ValueFunctionPlayer(Color.RED),
+ ValueFunctionPlayer(Color.ORANGE),
+ ValueFunctionPlayer(Color.WHITE),
+ ],
+ },
+ )
+
+ # Wrap with RecordVideo when rendering video
+ if args.render_style == "video":
+ env = RecordVideo(
+ env,
+ video_folder=args.video_folder,
+ name_prefix="catan-game",
+ episode_trigger=lambda x: True, # Record every episode
+ )
+ print(f"Recording video to: {args.video_folder}")
+ else:
+ print("Saving replay to DB (render() will write steps).")
+
+ observation, info = env.reset()
+ done = False
+ step = 0
+
+ print(f"Starting game with {args.map_type} map, {args.vps_to_win} VPs to win...")
+
+ while not done:
+ # Get valid actions
+ valid_actions = info["valid_actions"]
+
+ # Take first valid action (random)
+ action = random.choice(valid_actions)
+
+ # Step environment
+ observation, reward, terminated, truncated, info = env.step(action)
+ done = terminated or truncated
+
+ if args.render_style == "db":
+ env.render()
+
+ step += 1
+ if step % 10 == 0:
+ print(f"Step {step}...")
+
+ print(f"Game finished after {step} steps")
+
+ if args.render_style == "video":
+ print(f"Video saved to {args.video_folder}")
+
+ env.close()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index ba06d1be2..c01d5b6e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
dependencies = ["networkx", "click", "rich"]
[project.optional-dependencies]
-gym = ["gymnasium<=0.29.1", "numpy", "pandas", "fastparquet"]
+gym = ["gymnasium<=0.29.1", "numpy", "pandas", "fastparquet", "pygame"]
web = [
"gunicorn",
"flask",
diff --git a/tests/integration_tests/test_play.py b/tests/integration_tests/test_play.py
index e360d5506..84c85348f 100644
--- a/tests/integration_tests/test_play.py
+++ b/tests/integration_tests/test_play.py
@@ -62,3 +62,32 @@ def test_csv_play():
assert len(board_tensors_df) == num_samples
assert len(main_df) == num_samples
assert len(rewards_df) == num_samples
+
+
+def test_parquet_output():
+ runner = CliRunner()
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ result = runner.invoke(
+ simulate,
+ [
+ "--num=1",
+ "--players=F,F",
+ "--output",
+ tmpdirname,
+ "--output-format",
+ "parquet",
+ ],
+ )
+ assert result.exit_code == 0
+
+ # Assert 1 parquet file is created in tmpdirname
+ files = os.listdir(tmpdirname)
+ assert len(files) == 1
+
+ file = files[0]
+ assert file.endswith(".parquet")
+ df = pd.read_parquet(os.path.join(tmpdirname, file))
+
+ assert "F_BANK_BRICK" in df.columns
+ assert "RETURN" in df.columns
+ assert "ACTION" in df.columns
diff --git a/tests/integration_tests/test_replay.py b/tests/integration_tests/test_replay.py
index 0fc43a1b1..e358ea1ad 100644
--- a/tests/integration_tests/test_replay.py
+++ b/tests/integration_tests/test_replay.py
@@ -73,3 +73,22 @@ def test_execute_action_on_copies_doesnt_conflict():
game_copy.execute(action)
game.execute(action)
+
+
+def test_seed_reproducibility():
+ # Play 10 games with the same seed, assert the action logs look the same
+ players = [
+ RandomPlayer(Color.RED),
+ RandomPlayer(Color.BLUE),
+ RandomPlayer(Color.WHITE),
+ RandomPlayer(Color.ORANGE),
+ ]
+ game = Game(players, seed=123)
+ game.play()
+ game_json = json.dumps(game, cls=GameEncoder)
+
+ for i in range(10):
+ game = Game(players, seed=123)
+ game.play()
+ game_json_copy = json.dumps(game, cls=GameEncoder)
+ assert game_json == game_json_copy
diff --git a/tests/test_gym.py b/tests/test_gym.py
index f7b5dfbe6..77f23095e 100644
--- a/tests/test_gym.py
+++ b/tests/test_gym.py
@@ -1,13 +1,21 @@
import random
+import json
import gymnasium
from gymnasium.utils.env_checker import check_env
import numpy as np
from catanatron.features import get_feature_ordering
+from catanatron.json import GameEncoder
from catanatron.models.player import Color, RandomPlayer
+from catanatron.models.enums import Action, ActionType, WHEAT, SHEEP, ORE
from catanatron.players.value import ValueFunctionPlayer
from catanatron.gym.envs.catanatron_env import CatanatronEnv
+from catanatron.gym.envs.action_space import (
+ get_action_array,
+ to_action_space,
+ from_action_space,
+)
features = get_feature_ordering(2)
@@ -84,7 +92,7 @@ def test_invalid_action_reward():
def test_custom_reward():
- def custom_reward(game, p0_color):
+ def custom_reward(action, game, p0_color):
return 123
env = gymnasium.make(
@@ -139,3 +147,138 @@ def test_mixed_rep():
observation, info = env.reset()
assert "board" in observation
assert "numeric" in observation
+
+
+def test_render_rgb_array():
+ env = gymnasium.make(
+ "catanatron/Catanatron-v0", config={"render_mode": "rgb_array"}
+ )
+ env.reset()
+ frame = env.render()
+ assert isinstance(frame, np.ndarray)
+ assert frame.shape == (800, 1000, 3)
+ assert frame.dtype == np.uint8
+ env.close()
+
+
+def test_move_robber_action_in_base_action_array():
+ """Test that a specific MOVE_ROBBER action is in the BASE action array for 2 players."""
+ player_colors = (Color.BLUE, Color.RED)
+ action_array = get_action_array(player_colors, "BASE")
+ target_action = (ActionType.MOVE_ROBBER, ((-1, 0, 1), Color.BLUE))
+ assert target_action in action_array, (
+ f"Action {target_action} not found in BASE action array for 2 players"
+ )
+
+ target_action = (ActionType.MOVE_ROBBER, ((-1, 0, 1), None))
+ assert target_action in action_array, (
+ f"Action {target_action} not found in BASE action array for 2 players"
+ )
+
+
+def test_there_are_54_build_nodes_in_base():
+ player_colors = (Color.BLUE, Color.RED)
+ action_array = get_action_array(player_colors, "BASE")
+ num_build_nodes = len(
+ [action for action in action_array if action[0] == ActionType.BUILD_SETTLEMENT]
+ )
+ assert num_build_nodes == 54
+
+
+def test_there_are_less_build_nodes_in_mini():
+ player_colors = (Color.BLUE, Color.RED)
+ action_array = get_action_array(player_colors, "MINI")
+ num_build_nodes = len(
+ [action for action in action_array if action[0] == ActionType.BUILD_SETTLEMENT]
+ )
+ assert num_build_nodes == 24
+
+
+def test_outside_tiles_not_in_mini():
+ player_colors = (Color.BLUE, Color.RED)
+ action_array = get_action_array(player_colors, "MINI")
+ target_action = (ActionType.MOVE_ROBBER, ((0, 2, -2), Color.BLUE))
+ assert target_action not in action_array, (
+ f"Action {target_action} found in MINI action array for 2 players"
+ )
+
+
+def test_action_space_conversion_roundtrip():
+ """Test converting actions to action space integers and back."""
+ player_colors = (Color.BLUE, Color.RED)
+ map_type = "BASE"
+
+ # Create various test actions
+ # Note: For PLAY_YEAR_OF_PLENTY with 2 resources, they must be in RESOURCES order
+ # RESOURCES = ['WOOD', 'BRICK', 'SHEEP', 'WHEAT', 'ORE']
+ test_actions = [
+ Action(Color.BLUE, ActionType.ROLL, None),
+ Action(Color.BLUE, ActionType.DISCARD, None),
+ Action(Color.BLUE, ActionType.BUILD_SETTLEMENT, 10),
+ Action(Color.BLUE, ActionType.BUILD_CITY, 5),
+ Action(Color.BLUE, ActionType.BUILD_ROAD, (0, 1)),
+ Action(Color.BLUE, ActionType.BUY_DEVELOPMENT_CARD, None),
+ Action(Color.BLUE, ActionType.PLAY_KNIGHT_CARD, None),
+ Action(Color.BLUE, ActionType.PLAY_YEAR_OF_PLENTY, (SHEEP, WHEAT)),
+ Action(Color.BLUE, ActionType.PLAY_YEAR_OF_PLENTY, (WHEAT, ORE)),
+ Action(Color.BLUE, ActionType.PLAY_YEAR_OF_PLENTY, (ORE,)),
+ Action(Color.BLUE, ActionType.PLAY_MONOPOLY, WHEAT),
+ Action(Color.BLUE, ActionType.PLAY_ROAD_BUILDING, None),
+ Action(Color.BLUE, ActionType.MOVE_ROBBER, ((-1, 0, 1), Color.RED)),
+ Action(Color.BLUE, ActionType.MOVE_ROBBER, ((0, -1, 1), None)),
+ Action(
+ Color.BLUE, ActionType.MARITIME_TRADE, (WHEAT, WHEAT, WHEAT, WHEAT, ORE)
+ ),
+ Action(
+ Color.BLUE, ActionType.MARITIME_TRADE, (SHEEP, SHEEP, SHEEP, None, WHEAT)
+ ),
+ Action(Color.BLUE, ActionType.MARITIME_TRADE, (ORE, ORE, None, None, WHEAT)),
+ Action(Color.BLUE, ActionType.END_TURN, None),
+ ]
+
+ for action in test_actions:
+ # Convert to action space integer
+ action_int = to_action_space(action, player_colors, map_type)
+
+ # Convert back from action space
+ recovered_action = from_action_space(
+ action_int, action.color, player_colors, map_type
+ )
+
+ # Assert they are the same
+ assert recovered_action == action, (
+ f"Action conversion failed: {action} -> {action_int} -> {recovered_action}"
+ )
+
+
+def test_gym_reproducibility():
+ # Play a game with the same seed, and ensure the game is the same
+ env = gymnasium.make(
+ "catanatron/Catanatron-v0",
+ config={
+ "enemies": [
+ ValueFunctionPlayer(Color.RED),
+ ]
+ },
+ )
+ observation, info = env.reset(seed=123)
+ env.action_space.seed(123)
+ game = env.unwrapped.game
+ center_tile = game.state.board.map.land_tiles[(0, 0, 0)]
+ assert center_tile.resource == ORE
+ assert center_tile.number == 8
+
+ done = False
+ reward = 0
+ while not done:
+ action_mask = env.action_masks()
+ valid_indices = np.flatnonzero(action_mask)
+ action = random.choice(valid_indices)
+
+ observation, reward, terminated, truncated, info = env.step(action)
+ done = terminated or truncated
+ game = env.unwrapped.game
+ game_json = json.loads(json.dumps(game, cls=GameEncoder))
+ env.close()
+
+ assert game_json["state_index"] == 126