Commit c324b83
Commit of working code
1 parent: 284c743
10 files changed: +488, -0 lines

Dockerfile (+11 lines)

FROM pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime

WORKDIR /src/
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY setup.py .
COPY tictactoe/ ./tictactoe/
RUN python setup.py install
COPY pytorch_dqn.pt .

ENTRYPOINT ["python", "-m", "tictactoe"]
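
Because the ENTRYPOINT is python -m tictactoe, arguments passed to docker run (for example --mode fit or --mode play) are appended to that command and handled by the click CLI defined in tictactoe/__main__.py below; the pytorch_dqn.pt file baked into the image is presumably the pre-trained model that play mode loads.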

pytorch_dqn.pt (binary file, 277 KB; not shown)

requirements.txt (+2 lines)

gym>=0.15.6
click>=7.0

setup.py (+11 lines)

from setuptools import setup, find_packages


setup(
    name="TicTacToe",
    version="0.0",
    description="Learn to play Tic Tac Toe",
    author="Matthew Mahowald",
    author_email="",
    packages=find_packages(),
    install_requires=[]
)

tictactoe/__init__.py

Whitespace-only changes.

tictactoe/__main__.py (+20 lines)

import click
from .fit import fit
from .play import play
import sys
import logging

logging.basicConfig(level=logging.DEBUG)

@click.command()
@click.option("--mode", default="play", help="fit or play")
def main(mode="play"):
    if mode == "fit":
        res = fit()
        sys.stdout.buffer.write(res)
        sys.stdout.flush()
    elif mode == "play":
        play()

if __name__ == "__main__":
    main()
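
In fit mode the serialized model is written to sys.stdout.buffer, so the weights file the Dockerfile copies can presumably be produced with a redirect such as python -m tictactoe --mode fit > pytorch_dqn.pt; play mode simply hands control to play() from tictactoe/play.py.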

tictactoe/env.py (+116 lines)

import numpy as np
import gym
from gym import spaces


class TicTacToe(gym.Env):

    reward_range = (-np.inf, np.inf)
    observation_space = spaces.MultiDiscrete([2 for _ in range(0, 9 * 3)])
    action_space = spaces.Discrete(9)

    """
    Board looks like:
    [0, 1, 2,
     3, 4, 5,
     6, 7, 8]
    """
    winning_streaks = [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
        [0, 3, 6],
        [1, 4, 7],
        [2, 5, 8],
        [0, 4, 8],
        [2, 4, 6],
    ]

    def __init__(self, summary: dict = None):
        super().__init__()
        if summary is None:
            summary = {
                "total games": 0,
                "ties": 0,
                "illegal moves": 0,
                "player 0 wins": 0,
                "player 1 wins": 0,
            }
        self.summary = summary

    def seed(self, seed=None):
        pass

    def _one_hot_board(self):
        if self.current_player == 0:
            return np.eye(3)[self.board].reshape(-1)
        if self.current_player == 1:
            # permute for symmetry
            return np.eye(3)[self.board][:, [0, 2, 1]].reshape(-1)

    def reset(self):
        self.current_player = 0
        self.board = np.zeros(9, dtype="int")
        return self._one_hot_board()

    def step(self, actions):
        exp = {"state": "in progress"}

        # get the current player's action
        action = actions

        reward = 0
        done = False
        # illegal move
        if self.board[action] != 0:
            reward = -10  # illegal moves are really bad
            exp = {"state": "done", "reason": "Illegal move"}
            done = True
            self.summary["total games"] += 1
            self.summary["illegal moves"] += 1
            return self._one_hot_board(), reward, done, exp

        self.board[action] = self.current_player + 1

        # check if the other player can win on the next turn:
        for streak in self.winning_streaks:
            if ((self.board[streak] == 2 - self.current_player).sum() >= 2) and (
                self.board[streak] == 0
            ).any():
                reward = -2
                exp = {
                    "state": "in progress",
                    "reason": "Player {} can lose on the next turn".format(
                        self.current_player
                    ),
                }

        # check if we won
        for streak in self.winning_streaks:
            if (self.board[streak] == self.current_player + 1).all():
                reward = 1  # player wins!
                exp = {
                    "state": "in progress",
                    "reason": "Player {} has won".format(self.current_player),
                }
                self.summary["total games"] += 1
                self.summary["player {} wins".format(self.current_player)] += 1
                done = True
        # check if we tied, which ends the game (only when no one just won,
        # so a board-filling winning move is not also counted as a tie)
        if not done and (self.board != 0).all():
            reward = 0
            exp = {
                "state": "in progress",
                "reason": "Player {} has tied".format(self.current_player),
            }
            done = True
            self.summary["total games"] += 1
            self.summary["ties"] += 1

        # move to the next player
        self.current_player = 1 - self.current_player

        return self._one_hot_board(), reward, done, exp

    def render(self, mode: str = "human"):
        print("{}|{}|{}\n-----\n{}|{}|{}\n-----\n{}|{}|{}".format(*self.board.tolist()))
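
For a quick sense of the environment's interface, here is a minimal sketch (not part of the commit; it assumes the package is installed and plays random legal moves using the same open-spot trick as select_dummy_action in fit.py below):

import numpy as np
from tictactoe.env import TicTacToe

env = TicTacToe()
obs = env.reset()  # 27-dim one-hot encoding of the board, 3 channels per square
done = False
while not done:
    # the first one-hot channel marks empty squares; sample one of them uniformly
    open_spots = obs.reshape(3, 3, 3)[:, :, 0].reshape(-1)
    action = np.random.choice(np.arange(9), p=open_spots / open_spots.sum())
    obs, reward, done, info = env.step(action)
env.render()
print(reward, info, env.summary)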

tictactoe/fit.py (+206 lines)

from .env import TicTacToe
from .model import Policy, Transition, ReplayMemory

import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from typing import Tuple
import random
import logging
import io

def fit(
    n_steps: int = 500_000,
    batch_size: int = 128,
    gamma: float = 0.99,
    eps_start: float = 1.0,
    eps_end: float = 0.1,
    eps_steps: int = 200_000,
) -> bytes:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info("Beginning training on: {}".format(device))

    target_update = int((1e-2) * n_steps)
    policy = Policy(n_inputs=3 * 9, n_outputs=9).to(device)
    target = Policy(n_inputs=3 * 9, n_outputs=9).to(device)
    target.load_state_dict(policy.state_dict())
    target.eval()

    optimizer = optim.Adam(policy.parameters(), lr=1e-3)
    memory = ReplayMemory(50_000)

    env = TicTacToe()
    state = torch.tensor([env.reset()], dtype=torch.float).to(device)
    old_summary = {
        "total games": 0,
        "ties": 0,
        "illegal moves": 0,
        "player 0 wins": 0,
        "player 1 wins": 0,
    }
    _randoms = 0
    summaries = []

    for step in range(n_steps):
        t = np.clip(step / eps_steps, 0, 1)
        eps = (1 - t) * eps_start + t * eps_end

        action, was_random = select_model_action(device, policy, state, eps)
        if was_random:
            _randoms += 1
        next_state, reward, done, _ = env.step(action.item())

        # player 2 goes
        if not done:
            next_state, _, done, _ = env.step(select_dummy_action(next_state))
            next_state = torch.tensor([next_state], dtype=torch.float).to(device)
        if done:
            next_state = None

        memory.push(state, action, next_state, torch.tensor([reward], device=device))

        state = next_state
        optimize_model(
            device=device,
            optimizer=optimizer,
            policy=policy,
            target=target,
            memory=memory,
            batch_size=batch_size,
            gamma=gamma,
        )
        if done:
            state = torch.tensor([env.reset()], dtype=torch.float).to(device)
        if step % target_update == 0:
            target.load_state_dict(policy.state_dict())
        if step % 5000 == 0:
            delta_summary = {k: env.summary[k] - old_summary[k] for k in env.summary}
            delta_summary["random actions"] = _randoms
            old_summary = {k: env.summary[k] for k in env.summary}
            logging.info("{} : {}".format(step, delta_summary))
            summaries.append(delta_summary)
            _randoms = 0

    logging.info("Complete")

    res = io.BytesIO()
    torch.save(policy.state_dict(), res)

    return res.getbuffer()


def optimize_model(
    device: torch.device,
    optimizer: optim.Optimizer,
    policy: Policy,
    target: Policy,
    memory: ReplayMemory,
    batch_size: int,
    gamma: float,
):
    """Model optimization step, copied verbatim from the Torch DQN tutorial.

    Arguments:
        device {torch.device} -- Device
        optimizer {torch.optim.Optimizer} -- Optimizer
        policy {Policy} -- Policy module
        target {Policy} -- Target module
        memory {ReplayMemory} -- Replay memory
        batch_size {int} -- Number of observations to use per batch step
        gamma {float} -- Reward discount factor
    """
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(
        tuple(map(lambda s: s is not None, batch.next_state)),
        device=device,
        dtype=torch.bool,
    )
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(batch_size, device=device)
    next_state_values[non_final_mask] = target(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values.unsqueeze(1)
    )

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()


def select_dummy_action(state: np.array) -> int:
    """Select a random (valid) move, given a board state.

    Arguments:
        state {np.array} -- Board state observation

    Returns:
        int -- Move to make.
    """
    state = state.reshape(3, 3, 3)
    open_spots = state[:, :, 0].reshape(-1)
    p = open_spots / open_spots.sum()
    return np.random.choice(np.arange(9), p=p)


def select_model_action(
    device: torch.device, model: Policy, state: torch.tensor, eps: float
) -> Tuple[torch.tensor, bool]:
    """Selects an action for the model: either using the policy, or
    by choosing a random valid action (as controlled by `eps`).

    Arguments:
        device {torch.device} -- Device
        model {Policy} -- Policy module
        state {torch.tensor} -- Current board state, as a torch tensor
        eps {float} -- Probability of choosing a random action.

    Returns:
        Tuple[torch.tensor, bool] -- The action, and a bool indicating whether
        the action is random or not.
    """

    sample = random.random()
    if sample > eps:
        return model.act(state), False
    else:
        return (
            torch.tensor(
                [[select_dummy_action(state.cpu().numpy())]],
                device=device,
                dtype=torch.long,
            ),
            True,
        )
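
The bytes returned by fit() are a torch-serialized state_dict, so they can presumably be loaded back into a fresh Policy along these lines (a sketch, not part of the commit; Policy comes from tictactoe/model.py, and the constructor arguments are copied from the calls inside fit() above):

import io
import torch
from tictactoe.fit import fit
from tictactoe.model import Policy

weights = fit(n_steps=10_000)  # a short run, just to produce some weights
model = Policy(n_inputs=3 * 9, n_outputs=9)
# map_location lets CUDA-trained weights load on a CPU-only machine
model.load_state_dict(torch.load(io.BytesIO(weights), map_location="cpu"))
model.eval()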
