7
7
import importlib .util
8
8
import io
9
9
import pathlib
10
- from typing import Dict , Optional
10
+ from typing import Dict
11
11
12
12
import torch
13
13
from PIL import Image
14
14
from tensordict import TensorDict , TensorDictBase
15
- from torchrl .data import Bounded , Categorical , Composite , NonTensor , Unbounded
15
+ from torchrl .data import Binary , Bounded , Categorical , Composite , NonTensor , Unbounded
16
16
17
17
from torchrl .envs import EnvBase
18
18
from torchrl .envs .common import _EnvPostInit
19
19
20
20
from torchrl .envs .utils import _classproperty
21
21
22
22
23
- class _HashMeta (_EnvPostInit ):
23
+ class _ChessMeta (_EnvPostInit ):
24
24
def __call__ (cls , * args , ** kwargs ):
25
25
instance = super ().__call__ (* args , ** kwargs )
26
26
if kwargs .get ("include_hash" ):
@@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs):
37
37
if instance .include_pgn :
38
38
in_keys .append ("pgn" )
39
39
out_keys .append ("pgn_hash" )
40
- return instance .append_transform (Hash (in_keys , out_keys ))
40
+ instance = instance .append_transform (Hash (in_keys , out_keys ))
41
+ if kwargs .get ("mask_actions" , True ):
42
+ from torchrl .envs import ActionMask
43
+
44
+ instance = instance .append_transform (ActionMask ())
41
45
return instance
42
46
43
47
44
- class ChessEnv (EnvBase , metaclass = _HashMeta ):
48
+ class ChessEnv (EnvBase , metaclass = _ChessMeta ):
45
49
r"""A chess environment that follows the TorchRL API.
46
50
47
51
This environment simulates a chess game using the `chess` library. It supports various state representations
@@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
63
67
include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``.
64
68
include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``.
65
69
include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``.
70
+ mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended
71
+ to the env to make sure that the actions are properly masked. Default: ``True``.
66
72
pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``.
67
73
68
74
.. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves.
@@ -202,16 +208,15 @@ def _legal_moves_to_index(
202
208
) -> torch .Tensor :
203
209
if not self .stateful :
204
210
if tensordict is None :
205
- raise RuntimeError (
206
- "rand_action requires a tensordict when stateful is False."
207
- )
208
- if self .include_fen :
209
- fen = self ._get_fen (tensordict )
211
+ # trust the board
212
+ pass
213
+ elif self .include_fen :
214
+ fen = tensordict .get ("fen" , None )
210
215
fen = fen .data
211
216
self .board .set_fen (fen )
212
217
board = self .board
213
218
elif self .include_pgn :
214
- pgn = self . _get_pgn ( tensordict )
219
+ pgn = tensordict . get ( "pgn" )
215
220
pgn = pgn .data
216
221
board = self ._pgn_to_board (pgn , self .board )
217
222
@@ -224,15 +229,19 @@ def _legal_moves_to_index(
224
229
)
225
230
226
231
if return_mask :
227
- return torch .zeros (len (self .san_moves ), dtype = torch .bool ).index_fill_ (
228
- 0 , indices , True
229
- )
232
+ return self ._move_index_to_mask (indices )
230
233
if pad :
231
234
indices = torch .nn .functional .pad (
232
235
indices , [0 , 218 - indices .numel () + 1 ], value = len (self .san_moves )
233
236
)
234
237
return indices
235
238
239
@classmethod
def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor:
    """Convert a vector of move indices into a boolean mask over ``cls.san_moves``.

    Args:
        indices (torch.Tensor): integer indices into ``cls.san_moves``. The
            vector may be padded: elsewhere in this class legal-move index
            vectors are padded with the value ``len(san_moves)``, which is one
            past the last valid position of the mask.

    Returns:
        torch.Tensor: a ``torch.bool`` tensor of length ``len(cls.san_moves)``
        that is ``True`` at every in-range position listed in ``indices``.
    """
    n_moves = len(cls.san_moves)
    mask = torch.zeros(n_moves, dtype=torch.bool)
    # Padding entries (>= n_moves) are not real moves; ``index_fill_`` would
    # raise an out-of-range error on them, so they are dropped beforehand.
    valid_indices = indices[indices < n_moves]
    return mask.index_fill_(0, valid_indices, True)
244
+
236
245
def __init__ (
237
246
self ,
238
247
* ,
@@ -242,6 +251,7 @@ def __init__(
242
251
include_pgn : bool = False ,
243
252
include_legal_moves : bool = False ,
244
253
include_hash : bool = False ,
254
+ mask_actions : bool = True ,
245
255
pixels : bool = False ,
246
256
):
247
257
chess = self .lib
@@ -252,6 +262,7 @@ def __init__(
252
262
self .include_san = include_san
253
263
self .include_fen = include_fen
254
264
self .include_pgn = include_pgn
265
+ self .mask_actions = mask_actions
255
266
self .include_legal_moves = include_legal_moves
256
267
if include_legal_moves :
257
268
# 218 max possible legal moves per chess board position
@@ -276,8 +287,10 @@ def __init__(
276
287
277
288
self .stateful = stateful
278
289
279
- if not self .stateful :
280
- self .full_state_spec = self .full_observation_spec .clone ()
290
+ # state_spec is loosely defined as such - it's not really an issue that extra keys
291
+ # can go missing but it allows us to reset the env using fen passed to the reset
292
+ # method.
293
+ self .full_state_spec = self .full_observation_spec .clone ()
281
294
282
295
self .pixels = pixels
283
296
if pixels :
@@ -297,16 +310,16 @@ def __init__(
297
310
self .full_reward_spec = Composite (
298
311
reward = Unbounded (shape = (1 ,), dtype = torch .float32 )
299
312
)
313
+ if self .mask_actions :
314
+ self .full_observation_spec ["action_mask" ] = Binary (
315
+ n = len (self .san_moves ), dtype = torch .bool
316
+ )
317
+
300
318
# done spec generated automatically
301
319
self .board = chess .Board ()
302
320
if self .stateful :
303
321
self .action_spec .set_provisional_n (len (list (self .board .legal_moves )))
304
322
305
- def rand_action (self , tensordict : Optional [TensorDictBase ] = None ):
306
- mask = self ._legal_moves_to_index (tensordict , return_mask = True )
307
- self .action_spec .update_mask (mask )
308
- return super ().rand_action (tensordict )
309
-
310
323
def _is_done (self , board ):
311
324
return board .is_game_over () | board .is_fifty_moves ()
312
325
@@ -316,11 +329,11 @@ def _reset(self, tensordict=None):
316
329
if tensordict is not None :
317
330
dest = tensordict .empty ()
318
331
if self .include_fen :
319
- fen = self . _get_fen ( tensordict )
332
+ fen = tensordict . get ( "fen" , None )
320
333
if fen is not None :
321
334
fen = fen .data
322
335
elif self .include_pgn :
323
- pgn = self . _get_pgn ( tensordict )
336
+ pgn = tensordict . get ( "pgn" , None )
324
337
if pgn is not None :
325
338
pgn = pgn .data
326
339
else :
@@ -360,13 +373,18 @@ def _reset(self, tensordict=None):
360
373
if self .include_legal_moves :
361
374
moves_idx = self ._legal_moves_to_index (board = self .board , pad = True )
362
375
dest .set ("legal_moves" , moves_idx )
376
+ if self .mask_actions :
377
+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
378
+ elif self .mask_actions :
379
+ dest .set (
380
+ "action_mask" ,
381
+ self ._legal_moves_to_index (
382
+ board = self .board , pad = True , return_mask = True
383
+ ),
384
+ )
385
+
363
386
if self .pixels :
364
387
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
365
-
366
- if self .stateful :
367
- mask = self ._legal_moves_to_index (dest , return_mask = True )
368
- self .action_spec .update_mask (mask )
369
-
370
388
return dest
371
389
372
390
_cairosvg_lib = None
@@ -437,16 +455,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821
437
455
pgn_string = str (game )
438
456
return pgn_string
439
457
440
- @classmethod
441
- def _get_fen (cls , tensordict ):
442
- fen = tensordict .get ("fen" , None )
443
- return fen
444
-
445
- @classmethod
446
- def _get_pgn (cls , tensordict ):
447
- pgn = tensordict .get ("pgn" , None )
448
- return pgn
449
-
450
458
def get_legal_moves (self , tensordict = None , uci = False ):
451
459
"""List the legal moves in a position.
452
460
@@ -470,7 +478,7 @@ def get_legal_moves(self, tensordict=None, uci=False):
470
478
raise ValueError (
471
479
"tensordict must be given since this env is not stateful"
472
480
)
473
- fen = self . _get_fen ( tensordict ).data
481
+ fen = tensordict . get ( "fen" ).data
474
482
board .set_fen (fen )
475
483
moves = board .legal_moves
476
484
@@ -488,10 +496,10 @@ def _step(self, tensordict):
488
496
fen = None
489
497
if not self .stateful :
490
498
if self .include_fen :
491
- fen = self . _get_fen ( tensordict ).data
499
+ fen = tensordict . get ( "fen" ).data
492
500
board .set_fen (fen )
493
501
elif self .include_pgn :
494
- pgn = self . _get_pgn ( tensordict ).data
502
+ pgn = tensordict . get ( "pgn" ).data
495
503
board = self ._pgn_to_board (pgn , board )
496
504
else :
497
505
raise RuntimeError (
@@ -521,6 +529,15 @@ def _step(self, tensordict):
521
529
if self .include_legal_moves :
522
530
moves_idx = self ._legal_moves_to_index (board = board , pad = True )
523
531
dest .set ("legal_moves" , moves_idx )
532
+ if self .mask_actions :
533
+ dest .set ("action_mask" , self ._move_index_to_mask (moves_idx ))
534
+ elif self .mask_actions :
535
+ dest .set (
536
+ "action_mask" ,
537
+ self ._legal_moves_to_index (
538
+ board = self .board , pad = True , return_mask = True
539
+ ),
540
+ )
524
541
525
542
turn = torch .tensor (board .turn )
526
543
done = self ._is_done (board )
@@ -540,11 +557,6 @@ def _step(self, tensordict):
540
557
dest .set ("terminated" , [done ])
541
558
if self .pixels :
542
559
dest .set ("pixels" , self ._get_tensor_image (board = self .board ))
543
-
544
- if self .stateful :
545
- mask = self ._legal_moves_to_index (dest , return_mask = True )
546
- self .action_spec .update_mask (mask )
547
-
548
560
return dest
549
561
550
562
def _set_seed (self , * args , ** kwargs ):
0 commit comments