Skip to content

Commit 7ad50dd

Browse files
committed
Update
[ghstack-poisoned]
1 parent bbfc419 commit 7ad50dd

File tree

4 files changed

+139
-55
lines changed

4 files changed

+139
-55
lines changed

test/test_env.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
from torchrl.envs.transforms.transforms import (
132132
AutoResetEnv,
133133
AutoResetTransform,
134+
Tokenizer,
134135
Transform,
135136
)
136137
from torchrl.envs.utils import (
@@ -3346,10 +3347,6 @@ def test_batched_dynamic(self, break_when_any_done):
33463347
)
33473348
del env_no_buffers
33483349
gc.collect()
3349-
# print(dummy_rollouts)
3350-
# print(rollout_no_buffers_serial)
3351-
# # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)):
3352-
# assert_allclose_td(a, b)
33533350
assert_allclose_td(
33543351
dummy_rollouts.exclude("action"),
33553352
rollout_no_buffers_serial.exclude("action"),
@@ -3463,6 +3460,8 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san
34633460
include_hash=include_hash,
34643461
include_san=include_san,
34653462
)
3463+
# Because we always use mask_actions=True
3464+
assert isinstance(env, TransformedEnv)
34663465
check_env_specs(env)
34673466
if include_hash:
34683467
if include_fen:
@@ -3560,8 +3559,8 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
35603559
)
35613560
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
35623561
td = env.reset(TensorDict({"fen": fen}))
3563-
assert td["fen"] == fen
35643562
if include_fen:
3563+
assert td["fen"] == fen
35653564
assert env.board.fen() == fen
35663565
assert td["turn"] == env.lib.WHITE
35673566
assert not td["done"]
@@ -3666,6 +3665,27 @@ def test_reward(
36663665
assert td["reward"] == expected_reward
36673666
assert td["turn"] == (not expected_turn)
36683667

3668+
def test_chess_tokenized(self):
3669+
env = ChessEnv(include_fen=True, stateful=True, include_san=True)
3670+
assert isinstance(env.observation_spec["fen"], NonTensor)
3671+
env = env.append_transform(
3672+
Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"])
3673+
)
3674+
assert isinstance(env.observation_spec["fen"], NonTensor)
3675+
env.transform.transform_output_spec(env.base_env.output_spec)
3676+
env.transform.transform_input_spec(env.base_env.input_spec)
3677+
r = env.rollout(10, return_contiguous=False)
3678+
assert "fen_tokenized" in r
3679+
assert "fen" in r
3680+
assert "fen_tokenized" in r["next"]
3681+
assert "fen" in r["next"]
3682+
ftd = env.fake_tensordict()
3683+
assert "fen_tokenized" in ftd
3684+
assert "fen" in ftd
3685+
assert "fen_tokenized" in ftd["next"]
3686+
assert "fen" in ftd["next"]
3687+
env.check_env_specs()
3688+
36693689

36703690
class TestCustomEnvs:
36713691
def test_tictactoe_env(self):

torchrl/data/tensor_specs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
50425042

50435043
def __eq__(self, other):
50445044
return (
5045-
type(self) is type(other)
5045+
type(self) == type(other)
50465046
and self.shape == other.shape
50475047
and self._device == other._device
50485048
and set(self._specs.keys()) == set(other._specs.keys())

torchrl/envs/custom/chess.py

+58-46
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77
import importlib.util
88
import io
99
import pathlib
10-
from typing import Dict, Optional
10+
from typing import Dict
1111

1212
import torch
1313
from PIL import Image
1414
from tensordict import TensorDict, TensorDictBase
15-
from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded
15+
from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded
1616

1717
from torchrl.envs import EnvBase
1818
from torchrl.envs.common import _EnvPostInit
1919

2020
from torchrl.envs.utils import _classproperty
2121

2222

23-
class _HashMeta(_EnvPostInit):
23+
class _ChessMeta(_EnvPostInit):
2424
def __call__(cls, *args, **kwargs):
2525
instance = super().__call__(*args, **kwargs)
2626
if kwargs.get("include_hash"):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
3737
if instance.include_pgn:
3838
in_keys.append("pgn")
3939
out_keys.append("pgn_hash")
40-
return instance.append_transform(Hash(in_keys, out_keys))
40+
instance = instance.append_transform(Hash(in_keys, out_keys))
41+
if kwargs.get("mask_actions", True):
42+
from torchrl.envs import ActionMask
43+
44+
instance = instance.append_transform(ActionMask())
4145
return instance
4246

4347

44-
class ChessEnv(EnvBase, metaclass=_HashMeta):
48+
class ChessEnv(EnvBase, metaclass=_ChessMeta):
4549
r"""A chess environment that follows the TorchRL API.
4650
4751
This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
6367
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
6468
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
6569
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70+
mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
71+
to the env to make sure that the actions are properly masked. Default: ``True``.
6672
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
6773
6874
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -202,16 +208,15 @@ def _legal_moves_to_index(
202208
) -> torch.Tensor:
203209
if not self.stateful:
204210
if tensordict is None:
205-
raise RuntimeError(
206-
"rand_action requires a tensordict when stateful is False."
207-
)
208-
if self.include_fen:
209-
fen = self._get_fen(tensordict)
211+
# trust the board
212+
pass
213+
elif self.include_fen:
214+
fen = tensordict.get("fen", None)
210215
fen = fen.data
211216
self.board.set_fen(fen)
212217
board = self.board
213218
elif self.include_pgn:
214-
pgn = self._get_pgn(tensordict)
219+
pgn = tensordict.get("pgn")
215220
pgn = pgn.data
216221
board = self._pgn_to_board(pgn, self.board)
217222

@@ -224,15 +229,19 @@ def _legal_moves_to_index(
224229
)
225230

226231
if return_mask:
227-
return torch.zeros(len(self.san_moves), dtype=torch.bool).index_fill_(
228-
0, indices, True
229-
)
232+
return self._move_index_to_mask(indices)
230233
if pad:
231234
indices = torch.nn.functional.pad(
232235
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
233236
)
234237
return indices
235238

239+
@classmethod
240+
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
241+
return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_(
242+
0, indices, True
243+
)
244+
236245
def __init__(
237246
self,
238247
*,
@@ -242,6 +251,7 @@ def __init__(
242251
include_pgn: bool = False,
243252
include_legal_moves: bool = False,
244253
include_hash: bool = False,
254+
mask_actions: bool = True,
245255
pixels: bool = False,
246256
):
247257
chess = self.lib
@@ -252,6 +262,7 @@ def __init__(
252262
self.include_san = include_san
253263
self.include_fen = include_fen
254264
self.include_pgn = include_pgn
265+
self.mask_actions = mask_actions
255266
self.include_legal_moves = include_legal_moves
256267
if include_legal_moves:
257268
# 218 max possible legal moves per chess board position
@@ -276,8 +287,10 @@ def __init__(
276287

277288
self.stateful = stateful
278289

279-
if not self.stateful:
280-
self.full_state_spec = self.full_observation_spec.clone()
290+
# state_spec is loosely defined as such - it's not really an issue that extra keys
291+
# can go missing but it allows us to reset the env using fen passed to the reset
292+
# method.
293+
self.full_state_spec = self.full_observation_spec.clone()
281294

282295
self.pixels = pixels
283296
if pixels:
@@ -297,16 +310,16 @@ def __init__(
297310
self.full_reward_spec = Composite(
298311
reward=Unbounded(shape=(1,), dtype=torch.float32)
299312
)
313+
if self.mask_actions:
314+
self.full_observation_spec["action_mask"] = Binary(
315+
n=len(self.san_moves), dtype=torch.bool
316+
)
317+
300318
# done spec generated automatically
301319
self.board = chess.Board()
302320
if self.stateful:
303321
self.action_spec.set_provisional_n(len(list(self.board.legal_moves)))
304322

305-
def rand_action(self, tensordict: Optional[TensorDictBase] = None):
306-
mask = self._legal_moves_to_index(tensordict, return_mask=True)
307-
self.action_spec.update_mask(mask)
308-
return super().rand_action(tensordict)
309-
310323
def _is_done(self, board):
311324
return board.is_game_over() | board.is_fifty_moves()
312325

@@ -316,11 +329,11 @@ def _reset(self, tensordict=None):
316329
if tensordict is not None:
317330
dest = tensordict.empty()
318331
if self.include_fen:
319-
fen = self._get_fen(tensordict)
332+
fen = tensordict.get("fen", None)
320333
if fen is not None:
321334
fen = fen.data
322335
elif self.include_pgn:
323-
pgn = self._get_pgn(tensordict)
336+
pgn = tensordict.get("pgn", None)
324337
if pgn is not None:
325338
pgn = pgn.data
326339
else:
@@ -360,13 +373,18 @@ def _reset(self, tensordict=None):
360373
if self.include_legal_moves:
361374
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
362375
dest.set("legal_moves", moves_idx)
376+
if self.mask_actions:
377+
dest.set("action_mask", self._move_index_to_mask(moves_idx))
378+
elif self.mask_actions:
379+
dest.set(
380+
"action_mask",
381+
self._legal_moves_to_index(
382+
board=self.board, pad=True, return_mask=True
383+
),
384+
)
385+
363386
if self.pixels:
364387
dest.set("pixels", self._get_tensor_image(board=self.board))
365-
366-
if self.stateful:
367-
mask = self._legal_moves_to_index(dest, return_mask=True)
368-
self.action_spec.update_mask(mask)
369-
370388
return dest
371389

372390
_cairosvg_lib = None
@@ -437,16 +455,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
437455
pgn_string = str(game)
438456
return pgn_string
439457

440-
@classmethod
441-
def _get_fen(cls, tensordict):
442-
fen = tensordict.get("fen", None)
443-
return fen
444-
445-
@classmethod
446-
def _get_pgn(cls, tensordict):
447-
pgn = tensordict.get("pgn", None)
448-
return pgn
449-
450458
def get_legal_moves(self, tensordict=None, uci=False):
451459
"""List the legal moves in a position.
452460
@@ -470,7 +478,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
470478
raise ValueError(
471479
"tensordict must be given since this env is not stateful"
472480
)
473-
fen = self._get_fen(tensordict).data
481+
fen = tensordict.get("fen").data
474482
board.set_fen(fen)
475483
moves = board.legal_moves
476484

@@ -488,10 +496,10 @@ def _step(self, tensordict):
488496
fen = None
489497
if not self.stateful:
490498
if self.include_fen:
491-
fen = self._get_fen(tensordict).data
499+
fen = tensordict.get("fen").data
492500
board.set_fen(fen)
493501
elif self.include_pgn:
494-
pgn = self._get_pgn(tensordict).data
502+
pgn = tensordict.get("pgn").data
495503
board = self._pgn_to_board(pgn, board)
496504
else:
497505
raise RuntimeError(
@@ -521,6 +529,15 @@ def _step(self, tensordict):
521529
if self.include_legal_moves:
522530
moves_idx = self._legal_moves_to_index(board=board, pad=True)
523531
dest.set("legal_moves", moves_idx)
532+
if self.mask_actions:
533+
dest.set("action_mask", self._move_index_to_mask(moves_idx))
534+
elif self.mask_actions:
535+
dest.set(
536+
"action_mask",
537+
self._legal_moves_to_index(
538+
board=self.board, pad=True, return_mask=True
539+
),
540+
)
524541

525542
turn = torch.tensor(board.turn)
526543
done = self._is_done(board)
@@ -540,11 +557,6 @@ def _step(self, tensordict):
540557
dest.set("terminated", [done])
541558
if self.pixels:
542559
dest.set("pixels", self._get_tensor_image(board=self.board))
543-
544-
if self.stateful:
545-
mask = self._legal_moves_to_index(dest, return_mask=True)
546-
self.action_spec.update_mask(mask)
547-
548560
return dest
549561

550562
def _set_seed(self, *args, **kwargs):

0 commit comments

Comments
 (0)