Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature,Refactor] Chess improvements: fen, pgn, pixels, san #2702

Merged
merged 10 commits into from
Jan 26, 2025
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 210 additions & 28 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,51 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib.util
import io
from typing import Dict, Optional

import torch
from PIL import Image
from tensordict import TensorDict, TensorDictBase
from torchrl.data import Categorical, Composite, NonTensor, Unbounded

from torchrl.envs import EnvBase
from torchrl.envs.common import _EnvPostInit

from torchrl.envs.utils import _classproperty


class ChessEnv(EnvBase):
class _HashMeta(_EnvPostInit):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if kwargs.get("include_hash"):
from torchrl.envs import Hash

in_keys = []
out_keys = []
if instance.include_san:
in_keys.append("san")
out_keys.append("san_hash")
if instance.include_fen:
in_keys.append("fen")
out_keys.append("fen_hash")
if instance.include_pgn:
in_keys.append("pgn")
out_keys.append("pgn_hash")
return instance.append_transform(Hash(in_keys, out_keys))
return instance


class ChessEnv(EnvBase, metaclass=_HashMeta):
"""A chess environment that follows the TorchRL API.

Requires: the `chess` library. More info `here <https://python-chess.readthedocs.io/en/latest/>`__.

Args:
stateful (bool): Whether to keep track of the internal state of the board.
If False, the state will be stored in the observation and passed back
to the environment on each call. Default: ``False``.
to the environment on each call. Default: ``True``.

.. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape.
Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves,
Expand Down Expand Up @@ -90,28 +115,76 @@ class ChessEnv(EnvBase):
"""

_hash_table: Dict[int, str] = {}
_PNG_RESTART = """[Event "?"]
vmoens marked this conversation as resolved.
Show resolved Hide resolved
[Site "?"]
[Date "????.??.??"]
[Round "?"]
[White "?"]
[Black "?"]
[Result "*"]

*"""

@_classproperty
def lib(cls):
try:
import chess
import chess.pgn
except ImportError:
raise ImportError(
"The `chess` library could not be found. Make sure you installed it through `pip install chess`."
)
return chess

def __init__(self, stateful: bool = False):
def __init__(
self,
*,
stateful: bool = True,
include_san: bool = False,
include_fen: bool = False,
include_pgn: bool = False,
include_hash: bool = False,
pixels: bool = False,
):
chess = self.lib
super().__init__()
self.full_observation_spec = Composite(
hashing=Unbounded(shape=(), dtype=torch.int64),
fen=NonTensor(shape=()),
turn=Categorical(n=2, dtype=torch.bool, shape=()),
)
self.include_san = include_san
self.include_fen = include_fen
self.include_pgn = include_pgn
if include_san:
self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6")
if include_pgn:
self.full_observation_spec["pgn"] = NonTensor(
shape=(), example_data=self._PNG_RESTART
)
if include_fen:
self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any")
if not stateful and not (include_pgn or include_fen):
raise RuntimeError(
"At least one state representation (pgn or fen) must be enabled when stateful "
f"is {stateful}."
)

self.stateful = stateful

if not self.stateful:
self.full_state_spec = self.full_observation_spec.clone()

self.pixels = pixels
if pixels:
if importlib.util.find_spec("cairosvg") is None:
raise ImportError(
"Please install cairosvg to use this environment with pixel rendering."
)
if importlib.util.find_spec("torchvision") is None:
raise ImportError(
"Please install torchvision to use this environment with pixel rendering."
)
self.full_observation_spec["pixels"] = Unbounded(shape=())

self.full_action_spec = Composite(
action=Categorical(n=-1, shape=(), dtype=torch.int64)
)
Expand All @@ -132,41 +205,126 @@ def _is_done(self, board):

def _reset(self, tensordict=None):
fen = None
pgn = None
if tensordict is not None:
fen = self._get_fen(tensordict).data
dest = tensordict.empty()
if self.include_fen:
fen = self._get_fen(tensordict).data
dest = tensordict.empty()
if self.include_pgn:
fen = self._get_pgn(tensordict).data
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fen = self._get_pgn(tensordict).data
pgn = self._get_pgn(tensordict).data

Also _get_pgn doesn't exist, but I figure you realize that

dest = tensordict.empty()
else:
dest = TensorDict()

if fen is None:
if fen is None and pgn is None:
self.board.reset()
fen = self.board.fen()
if self.include_fen and fen is None:
fen = self.board.fen()
if self.include_pgn and pgn is None:
pgn = self._PNG_RESTART
else:
self.board.set_fen(fen)
if self._is_done(self.board):
raise ValueError(
"Cannot reset to a fen that is a gameover state." f" fen: {fen}"
)

hashing = hash(fen)
if fen is not None:
self.board.set_fen(fen)
if self._is_done(self.board):
raise ValueError(
"Cannot reset to a fen that is a gameover state." f" fen: {fen}"
)
elif pgn is not None:
self.board = self._pgn_to_board(pgn)

self._set_action_space()
turn = self.board.turn
return dest.set("fen", fen).set("hashing", hashing).set("turn", turn)
if self.include_san:
dest.set("san", "[SAN][START]")
if self.include_fen:
if fen is None:
fen = self.board.fen()
dest.set("fen", fen)
if self.include_pgn:
if pgn is None:
pgn = self._board_to_pgn(self.board)
dest.set("pgn", pgn)
dest.set("turn", turn)
if self.pixels:
dest.set("pixels", self._get_tensor_image(board=self.board))
return dest

_cairosvg_lib = None

@_classproperty
def _cairosvg(cls):
csvg = cls._cairosvg_lib
if csvg is None:
import cairosvg

csvg = cls._cairosvg_lib = cairosvg
return csvg

_torchvision_lib = None

@_classproperty
def _torchvision(cls):
tv = cls._torchvision_lib
if tv is None:
import torchvision

tv = cls._torchvision_lib = torchvision
return tv

@classmethod
def _get_tensor_image(cls, board):
try:
svg = board._repr_svg_()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could have another option that has a more minimal representation as well, if we want to train without vision. For instance, OpenSpiel's chess env has a representation where each square of the board is given a size 20 array of ones and zeros, which is somehow used to represent which piece is in that square and which color it is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I was using this because it's builtin - but a more minimal representation could also be cool!

# Convert SVG to PNG using cairosvg
png_data = io.BytesIO()
cls._cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to=png_data)
png_data.seek(0)
# Open the PNG image using Pillow
img = Image.open(png_data)
img = cls._torchvision.transforms.functional.pil_to_tensor(img)
except ImportError:
raise ImportError(
"Chess rendering requires cairosvg and torchvision to be installed."
)
return img

def _set_action_space(self, tensordict: TensorDict | None = None):
if not self.stateful and tensordict is not None:
fen = self._get_fen(tensordict).data
self.board.set_fen(fen)
self.action_spec.set_provisional_n(self.board.legal_moves.count())

@classmethod
def _pgn_to_board(
cls, pgn_string: str, board: "chess.Board" | None = None
) -> "chess.Board":
pgn_io = io.StringIO(pgn_string)
game = cls.lib.pgn.read_game(pgn_io)
if board is None:
board = cls.Board()
else:
board.reset()
for move in game.mainline_moves():
board.push(move)
return board

@classmethod
def _board_to_pgn(cls, board: "chess.Board") -> str:
# Create a new Game object
game = cls.lib.pgn.Game()

# Add the moves to the game
node = game
for move in board.move_stack:
node = node.add_variation(move)

# Generate the PGN string
pgn_string = str(game)
return pgn_string

@classmethod
def _get_fen(cls, tensordict):
Copy link
Collaborator

@kurtamohler kurtamohler Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this function can be removed, in favor of a direct TensorDict.get call?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that one day we might let people choose the fen / pgn key for themselves but it's true that as of now it's not incredibly useful

fen = tensordict.get("fen", None)
if fen is None:
hashing = tensordict.get("hashing", None)
if hashing is not None:
fen = cls._hash_table.get(hashing.item())
return fen

def get_legal_moves(self, tensordict=None, uci=False):
Expand Down Expand Up @@ -205,19 +363,40 @@ def _step(self, tensordict):
# action
action = tensordict.get("action")
board = self.board

if not self.stateful:
fen = self._get_fen(tensordict).data
board.set_fen(fen)
if self.include_fen:
fen = self._get_fen(tensordict).data
board.set_fen(fen)
elif self.include_pgn:
pgn = self._get_pgn(tensordict).data
self._pgn_to_board(pgn, board)
else:
raise RuntimeError(
"Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
)

action = list(board.legal_moves)[action]
san = None
if self.include_san:
san = board.san(action)
board.push(action)

self._set_action_space()

# Collect data
fen = self.board.fen()
dest = tensordict.empty()
hashing = hash(fen)
dest.set("fen", fen)
dest.set("hashing", hashing)

# Collect data
if self.include_fen:
fen = board.fen()
dest.set("fen", fen)

if self.include_pgn:
pgn = self._board_to_pgn(board)
dest.set("pgn", pgn)

if san is not None:
dest.set("san", san)

turn = torch.tensor(board.turn)
if board.is_checkmate():
Expand All @@ -226,12 +405,15 @@ def _step(self, tensordict):
reward_val = 1 if winner == self.lib.WHITE else -1
else:
reward_val = 0

reward = torch.tensor([reward_val], dtype=torch.int32)
done = self._is_done(board)
dest.set("reward", reward)
dest.set("turn", turn)
dest.set("done", [done])
dest.set("terminated", [done])
if self.pixels:
dest.set("pixels", self._get_tensor_image(board=self.board))
return dest

def _set_seed(self, *args, **kwargs):
Expand Down
Loading