Skip to content

Commit

Permalink
[Feature] Add EnvBase.all_actions and impl for ChessEnv
Browse files Browse the repository at this point in the history
ghstack-source-id: 9ee21835f35437e856b0726019114eb81a1115bc
Pull Request resolved: #2780
  • Loading branch information
kurtamohler committed Feb 12, 2025
1 parent f5445a4 commit 29b4dbc
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
52 changes: 52 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4032,6 +4032,58 @@ def test_chess_tokenized(self):
assert "fen" in ftd["next"]
env.check_env_specs()

@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("mask_actions", [False, True])
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
    """Check that ``ChessEnv.all_actions`` agrees with ``rand_action``.

    For up to 100 steps, the actions returned by ``all_actions`` are
    compared against randomly sampled actions (only when ``mask_actions``
    is enabled), then one of the enumerated actions is picked at random and
    played. The env is reset whenever the episode is done.
    """
    stateless = not stateful
    if stateless and not (include_fen or include_pgn):
        pytest.skip("fen or pgn must be included if not stateful")

    env = ChessEnv(
        include_fen=include_fen,
        include_pgn=include_pgn,
        stateful=stateful,
        mask_actions=mask_actions,
    )
    td = env.reset()

    # Play using only actions taken from the output of `all_actions`.
    for _step in range(100):
        if stateless:
            # Reset to the initial state first, just to make sure
            # `all_actions` knows how to get the board state from the input.
            env.reset()
            all_actions = env.all_actions(td.clone())
        else:
            all_actions = env.all_actions()

        # Sample some random actions and make sure each one matches exactly
        # one of the actions from `all_actions`. This is skipped when
        # `mask_actions == False`, because `rand_action` can pick illegal
        # actions in that case.
        if mask_actions:
            # TODO: Something is wrong in `ChessEnv.rand_action` which makes
            # it fail to work properly for stateless mode. It doesn't know
            # how to correctly reset the board state to what is given in the
            # tensordict before picking an action. When this is fixed, we
            # can get rid of the two `reset`s below.
            if stateless:
                env.reset(td.clone())
            td_act = td.clone()
            for _sample in range(10):
                rand_action = env.rand_action(td_act)
                n_matches = (rand_action["action"] == all_actions["action"]).sum()
                assert n_matches == 1
            if stateless:
                env.reset()

        # Pick one enumerated action uniformly at random and play it.
        action_idx = torch.randint(0, all_actions.shape[0], ()).item()
        td = env.step(td.update(all_actions[action_idx]))["next"]

        if td["done"]:
            td = env.reset()


class TestCustomEnvs:
def test_tictactoe_env(self):
Expand Down
21 changes: 21 additions & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2764,6 +2764,27 @@ def _assert_tensordict_shape(self, tensordict: TensorDictBase) -> None:
f"got {tensordict.batch_size} and {self.batch_size}"
)

def all_actions(
    self, tensordict: Optional[TensorDictBase] = None
) -> TensorDictBase:
    """Generates all possible actions.

    Not all environments implement this function. Furthermore, it can only
    be implemented for environments that have fully discrete actions specs.

    Args:
        tensordict (TensorDictBase, optional): tensordict where the
            resulting actions should be written. This input can also be used
            to pass arguments to the reset function, in which case the
            actions will be generated for the state after reset.

    Returns:
        a tensordict object with the "action" entry updated with a batch of
        all possible actions. The actions are stacked together in the
        leading dimension.

    Raises:
        NotImplementedError: always, in this default implementation.
            Subclasses that support action enumeration override this method.
    """
    # Name the concrete class so the user can tell which env lacks support.
    raise NotImplementedError(
        f"{type(self).__name__} does not implement all_actions."
    )

def rand_action(self, tensordict: Optional[TensorDictBase] = None):
"""Performs a random action given the action_spec attribute.
Expand Down
31 changes: 30 additions & 1 deletion torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import importlib.util
import io
import pathlib
from typing import Dict
from typing import Dict, Optional

import torch
from tensordict import TensorDict, TensorDictBase
Expand Down Expand Up @@ -526,6 +526,35 @@ def get_legal_moves(self, tensordict=None, uci=False):
else:
return [board.san(move) for move in moves]

def all_actions(self, tensordict: Optional[TensorDictBase] = None):
    """Return the move indices available on the board as a batch of actions.

    In stateful mode the env's own ``self.board`` is used as-is. In
    stateless mode the board is first brought to the position described by
    ``tensordict``: from its ``"fen"`` entry when ``include_fen`` is set,
    otherwise from its ``"pgn"`` entry when ``include_pgn`` is set.

    Args:
        tensordict (TensorDictBase, optional): source of the board state in
            stateless mode, and template for the output (the result is
            built from ``tensordict.empty()``). May be omitted only when
            the env is stateful.

    Returns:
        A tensordict whose batch size is the number of indices returned by
        ``_legal_moves_to_index`` and whose ``"action"`` entry holds those
        indices.

    Raises:
        RuntimeError: if the env is stateless and neither ``include_fen``
            nor ``include_pgn`` is enabled, or if the env is stateless and
            no ``tensordict`` is given.
    """
    board = self.board

    if not self.stateful:
        if not (self.include_fen or self.include_pgn):
            raise RuntimeError(
                "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True."
            )
        # Fail with a clear message rather than an AttributeError on None:
        # a stateless env has no other source for the board position.
        if tensordict is None:
            raise RuntimeError(
                "A tensordict containing the board state must be provided to all_actions when stateful=False."
            )
        if self.include_fen:
            fen = tensordict.get("fen").data
            # NOTE: this updates the shared ``self.board`` in place.
            board.set_fen(fen)
        else:
            pgn = tensordict.get("pgn").data
            board = self._pgn_to_board(pgn, board)

    if tensordict is None:
        dest = TensorDict()
    else:
        dest = tensordict.empty()

    moves_idx = self._legal_moves_to_index(
        board=board,
        pad=False,
        return_mask=False,
    )
    # One batch element per move index; actions are stacked along dim 0.
    dest.batch_size = torch.Size([moves_idx.shape[0]])
    dest.set("action", moves_idx)
    return dest

def _step(self, tensordict):
# action
action = tensordict.get("action")
Expand Down
5 changes: 5 additions & 0 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,11 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict
return self.base_env.rand_action(tensordict)
return super().rand_action(tensordict)

def all_actions(
    self, tensordict: Optional[TensorDictBase] = None
) -> TensorDictBase:
    """Enumerate actions by delegating to the wrapped base environment.

    Args:
        tensordict (TensorDictBase, optional): forwarded unchanged to
            ``self.base_env.all_actions``.

    Returns:
        Whatever ``self.base_env.all_actions`` returns for ``tensordict``.
    """
    wrapped_env = self.base_env
    return wrapped_env.all_actions(tensordict)

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# No need to clone here because inv does it already
# tensordict = tensordict.clone(False)
Expand Down

0 comments on commit 29b4dbc

Please sign in to comment.