Skip to content

Commit

Permalink
use new clubs typing + mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Mar 14, 2022
1 parent eacc2c4 commit a04d007
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 50 deletions.
15 changes: 9 additions & 6 deletions clubs_gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__version__ = "0.1.3"
__version__ = "0.1.4"
__author__ = "Ferdinand Schlatt"
__license__ = "GPL-3.0"
__copyright__ = f"Copyright (c) 2021, {__author__}."
__copyright__ = f"Copyright (c) 2022, {__author__}."
__homepage__ = "https://github.com/fschlatt/clubs_gym"
__docs__ = (
"clubs is an open ai gym environment for running arbitrary poker configurations."
Expand All @@ -11,24 +11,27 @@
# This variable is injected in the __builtins__ by the build
# process. It is used to enable importing subpackages of clubs_gym when
# the binaries are not built.
__CLUBS_GYM_SETUP__ # type: ignore
__CLUBS_GYM_SETUP__
__CLUBS_GYM_SETUP__: bool = True
except NameError:
__CLUBS_GYM_SETUP__ = False

if __CLUBS_GYM_SETUP__: # type: ignore
if __CLUBS_GYM_SETUP__:
pass
else:
from . import agent, envs

from typing import Dict

import clubs

__all__ = ["agent", "envs"]
ENVS = []


def __register():
def __register() -> None:
try:
env_configs = {}
env_configs: Dict[str, clubs.configs.PokerConfig] = {}
for name, config in clubs.configs.__dict__.items():
if not name.endswith("_PLAYER"):
continue
Expand Down
7 changes: 5 additions & 2 deletions clubs_gym/agent/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import clubs


class BaseAgent:
def __init__(self):
def __init__(self) -> None:
pass

def act(self, obs):
def act(self, obs: clubs.poker.engine.ObservationDict) -> int:
raise NotImplementedError()
50 changes: 30 additions & 20 deletions clubs_gym/agent/kuhn.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,80 @@
import random

import clubs

from . import base


class NashKuhnAgent(base.BaseAgent):
def __init__(self, alpha):
def __init__(self, alpha: float) -> None:
super().__init__()
if alpha < 0 or alpha > 1 / 3:
raise ValueError(
f"invalid alpha value, expected 0 <= alpha <= 1/3, got {alpha}"
)
self.alpha = alpha

def player_1_check(self, obs):
if obs["hole_cards"][0].rank == "Q":
def player_1_check(self, obs: clubs.poker.engine.ObservationDict) -> int:
rank = obs["hole_cards"][0].rank
if rank == "Q":
if random.random() < self.alpha:
return 1
return 0
if obs["hole_cards"][0].rank == "K":
if rank == "K":
return 0
if obs["hole_cards"][0].rank == "A":
if rank == "A":
if random.random() < 3 * self.alpha:
return 1
return 0
raise ValueError("got invalid card rank, expected one of [Q, K, A] got {f.}")

def player_1_bet(self, obs):
if obs["hole_cards"][0].rank == "Q":
def player_1_bet(self, obs: clubs.poker.engine.ObservationDict) -> int:
rank = obs["hole_cards"][0].rank
if rank == "Q":
return 0
if obs["hole_cards"][0].rank == "K":
if rank == "K":
if random.random() < 1 / 3 + self.alpha:
return 1
return 0
if obs["hole_cards"][0].rank == "A":
if rank == "A":
return 1
raise ValueError("got invalid card rank, expected one of [Q, K, A] got {f.}")

def _player_1(self, obs):
def _player_1(self, obs: clubs.poker.engine.ObservationDict) -> int:
if obs["pot"] == 2:
return self.player_1_check(obs)
return self.player_1_bet(obs)

def _player_2_check(self, obs):
if obs["hole_cards"][0].rank == "Q":
def _player_2_check(self, obs: clubs.poker.engine.ObservationDict) -> int:
rank = obs["hole_cards"][0].rank
if rank == "Q":
if random.random() < 1 / 3:
return 1
return 0
if obs["hole_cards"][0].rank == "K":
if rank == "K":
return 0
if obs["hole_cards"][0].rank == "A":
if rank == "A":
return 1
raise ValueError("got invalid card rank, expected one of [Q, K, A] got {f.}")

def _player_2_bet(self, obs):
if obs["hole_cards"][0].rank == "Q":
def _player_2_bet(self, obs: clubs.poker.engine.ObservationDict) -> int:
rank = obs["hole_cards"][0].rank
if rank == "Q":
return 0
if obs["hole_cards"][0].rank == "K":
if rank == "K":
if random.random() < 1 / 3:
return 1
return 0
if obs["hole_cards"][0].rank == "A":
if rank == "A":
return 1
raise ValueError("got invalid card rank, expected one of [Q, K, A] got {f.}")

def _player_2(self, obs):
def _player_2(self, obs: clubs.poker.engine.ObservationDict) -> int:
if obs["pot"] == 2:
return self._player_2_check(obs)
return self._player_2_bet(obs)

def act(self, obs):
def act(self, obs: clubs.poker.engine.ObservationDict) -> int:
if obs["action"] == 0:
return self._player_1(obs)
return self._player_2(obs)
47 changes: 28 additions & 19 deletions clubs_gym/envs/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from typing import Dict, List, Optional, Tuple, Union
import sys
from typing import Any, Dict, List, Optional, Tuple, Union

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


import clubs
import gym
Expand All @@ -7,7 +14,7 @@
from .. import agent, error


class ClubsEnv(gym.Env):
class ClubsEnv(gym.Env): # type: ignore
"""Runs a range of different poker games depending on the
given configuration. Supports limit, no limit and pot limit
bet sizing, arbitrary deck sizes, arbitrary hole and community
Expand Down Expand Up @@ -106,8 +113,10 @@ def __init__(
num_streets: int,
blinds: Union[int, List[int]],
antes: Union[int, List[int]],
raise_sizes: Union[float, str, List[Union[float, str]]],
num_raises: Union[float, List[float]],
raise_sizes: Union[
int, Literal["pot", "inf"], List[Union[int, Literal["pot", "inf"]]]
],
num_raises: Union[int, Literal["inf"], List[Union[int, Literal["inf"]]]],
num_suits: int,
num_ranks: int,
num_hole_cards: int,
Expand Down Expand Up @@ -166,12 +175,12 @@ def __init__(
)

self.agents: Optional[Dict[int, agent.BaseAgent]] = None
self.prev_obs: Optional[Dict] = None
self.prev_obs: Optional[clubs.poker.engine.ObservationDict] = None

def __del__(self):
def __del__(self) -> None:
self.close()

def act(self, obs: dict) -> int:
def act(self, obs: clubs.poker.engine.ObservationDict) -> int:
if self.agents is None:
raise error.NoRegisteredAgentsError(
"register agents using env.register_agents(...) before"
Expand All @@ -185,32 +194,32 @@ def act(self, obs: dict) -> int:
bet = self.agents[action].act(obs)
return bet

@staticmethod
def _parse_obs(obs):
obs["hole_cards"] = obs["hole_cards"][obs["action"]]
return obs

def step(self, bet: int) -> Tuple[Dict, List[int], List[int], None]:
def step(
self, bet: int
) -> Tuple[clubs.poker.engine.ObservationDict, List[int], List[bool], None]:
obs, rewards, done = self.dealer.step(bet)
obs = self._parse_obs(obs)
if self.agents is not None:
self.prev_obs = obs
return obs, rewards, done, None

def reset(self, reset_button: bool = False, reset_stacks: bool = False) -> Dict:
def reset(
self, reset_button: bool = False, reset_stacks: bool = False
) -> clubs.poker.engine.ObservationDict:
obs = self.dealer.reset(reset_button, reset_stacks)
if self.agents is not None:
self.prev_obs = obs
return obs

def render(self, mode="human", **kwargs) -> None:
def render(self, mode: str = "human", **kwargs: Any) -> None:
self.dealer.render(mode=mode, **kwargs)

def close(self):
def close(self) -> None:
if isinstance(self.dealer.viewer, clubs.render.GraphicViewer):
self.dealer.viewer.close()

def register_agents(self, agents: Union[List, Dict]) -> None:
def register_agents(
self, agents: Union[List[agent.BaseAgent], Dict[int, agent.BaseAgent]]
) -> None:
error_msg = "invalid agent configuration, got {}, expected {}"
if not isinstance(agents, (dict, list)):
raise error.InvalidAgentConfigurationError(
Expand Down Expand Up @@ -244,7 +253,7 @@ def register_agents(self, agents: Union[List, Dict]) -> None:
self.agents = dict(zip(agent_keys, agents))


def register(configs: Dict) -> None:
def register(configs: Dict[str, clubs.configs.PokerConfig]) -> None:
"""Registers dict of clubs configs as gym environments
Parameters
Expand Down
2 changes: 1 addition & 1 deletion run_local_tests.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ isort .
black .
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=88 --statistics
mypy test clubs --config-file=mypy.ini --strict
mypy test clubs_gym --config-file=mypy.ini --strict
pytest --cov clubs_gym/ --cov-report html test/
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
try:
import builtins
except ImportError:
import __builtin__ as builtins # type: ignore
import __builtin__ as builtins

PATH_ROOT = os.path.dirname(__file__)
builtins.__CLUBS__SETUP__ = True # type: ignore
builtins.__CLUBS_GYM__SETUP__: bool = True

import clubs_gym # noqa

Expand Down

0 comments on commit a04d007

Please sign in to comment.