from .env import TicTacToe
from .model import Policy, Transition, ReplayMemory

import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from typing import Tuple
import random
import logging
import io


def fit(
    n_steps: int = 500_000,
    batch_size: int = 128,
    gamma: float = 0.99,
    eps_start: float = 1.0,
    eps_end: float = 0.1,
    eps_steps: int = 200_000,
) -> bytes:
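    """Train a DQN policy to play TicTacToe against a random opponent.

    Arguments:
        n_steps {int} -- Number of environment steps to train for
        batch_size {int} -- Number of transitions per optimization batch
        gamma {float} -- Reward discount factor
        eps_start {float} -- Initial epsilon for epsilon-greedy exploration
        eps_end {float} -- Final epsilon after annealing
        eps_steps {int} -- Number of steps over which epsilon is annealed linearly

    Returns:
        bytes -- The trained policy's state_dict, serialized with torch.save
    """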
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info("Beginning training on: {}".format(device))

    # Sync the target network after every 1% of the total training steps.
    target_update = int((1e-2) * n_steps)
    policy = Policy(n_inputs=3 * 9, n_outputs=9).to(device)
    target = Policy(n_inputs=3 * 9, n_outputs=9).to(device)
    target.load_state_dict(policy.state_dict())
    target.eval()

    optimizer = optim.Adam(policy.parameters(), lr=1e-3)
    memory = ReplayMemory(50_000)

    env = TicTacToe()
    state = torch.tensor([env.reset()], dtype=torch.float).to(device)
    old_summary = {
        "total games": 0,
        "ties": 0,
        "illegal moves": 0,
        "player 0 wins": 0,
        "player 1 wins": 0,
    }
    _randoms = 0
    summaries = []

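    # Training loop: the policy network controls player 0, while a random
    # "dummy" opponent plays player 1. The agent's move and the opponent's
    # reply are folded into a single transition in replay memory.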
    for step in range(n_steps):
        # Linearly anneal epsilon from eps_start to eps_end over eps_steps steps.
        t = np.clip(step / eps_steps, 0, 1)
        eps = (1 - t) * eps_start + t * eps_end

        action, was_random = select_model_action(device, policy, state, eps)
        if was_random:
            _randoms += 1
        next_state, reward, done, _ = env.step(action.item())

        # Player 1 (the random opponent) moves if the game is not over.
        if not done:
            next_state, _, done, _ = env.step(select_dummy_action(next_state))
            next_state = torch.tensor([next_state], dtype=torch.float).to(device)
        if done:
            next_state = None

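        # Store the transition; next_state is None for terminal states, so the
        # bootstrapped Q-target in optimize_model reduces to the immediate reward.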
        memory.push(state, action, next_state, torch.tensor([reward], device=device))

        state = next_state
        optimize_model(
            device=device,
            optimizer=optimizer,
            policy=policy,
            target=target,
            memory=memory,
            batch_size=batch_size,
            gamma=gamma,
        )
        if done:
            state = torch.tensor([env.reset()], dtype=torch.float).to(device)
        # Periodically copy the policy weights into the target network.
        if step % target_update == 0:
            target.load_state_dict(policy.state_dict())
        # Every 5000 steps, log how game outcomes changed since the last report.
        if step % 5000 == 0:
            delta_summary = {k: env.summary[k] - old_summary[k] for k in env.summary}
            delta_summary["random actions"] = _randoms
            old_summary = {k: env.summary[k] for k in env.summary}
            logging.info("{} : {}".format(step, delta_summary))
            summaries.append(delta_summary)
            _randoms = 0

    logging.info("Complete")

    # Serialize the trained policy weights to an in-memory buffer and return
    # the raw bytes (getvalue() matches the declared `bytes` return type).
    res = io.BytesIO()
    torch.save(policy.state_dict(), res)

    return res.getvalue()


def optimize_model(
    device: torch.device,
    optimizer: optim.Optimizer,
    policy: Policy,
    target: Policy,
    memory: ReplayMemory,
    batch_size: int,
    gamma: float,
):
    """Model optimization step, copied verbatim from the Torch DQN tutorial.

    Arguments:
        device {torch.device} -- Device
        optimizer {torch.optim.Optimizer} -- Optimizer
        policy {Policy} -- Policy module
        target {Policy} -- Target module
        memory {ReplayMemory} -- Replay memory
        batch_size {int} -- Number of observations to use per batch step
        gamma {float} -- Reward discount factor
    """
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(
        tuple(map(lambda s: s is not None, batch.next_state)),
        device=device,
        dtype=torch.bool,
    )
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(batch_size, device=device)
    next_state_values[non_final_mask] = target(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values.unsqueeze(1)
    )

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()


def select_dummy_action(state: np.ndarray) -> int:
    """Select a random (valid) move, given a board state.

    Arguments:
        state {np.ndarray} -- Board state observation

    Returns:
        int -- Move to make.
    """
    # The flat observation reshapes to (3, 3, 3); the plane state[:, :, 0]
    # flags open squares, so sample uniformly among them.
    state = state.reshape(3, 3, 3)
    open_spots = state[:, :, 0].reshape(-1)
    p = open_spots / open_spots.sum()
    return int(np.random.choice(np.arange(9), p=p))


def select_model_action(
    device: torch.device, model: Policy, state: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, bool]:
    """Select an action for the model: either using the policy, or by
    choosing a random valid action (as controlled by `eps`).

    Arguments:
        device {torch.device} -- Device
        model {Policy} -- Policy module
        state {torch.Tensor} -- Current board state, as a torch tensor
        eps {float} -- Probability of choosing a random action

    Returns:
        Tuple[torch.Tensor, bool] -- The action, and a bool indicating
            whether the action was chosen at random.
    """
    # Epsilon-greedy: exploit the policy with probability (1 - eps),
    # otherwise fall back to a random valid move.
    sample = random.random()
    if sample > eps:
        return model.act(state), False
    else:
        return (
            torch.tensor(
                [[select_dummy_action(state.cpu().numpy())]],
                device=device,
                dtype=torch.long,
            ),
            True,
        )

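# A minimal usage sketch (assumes this module lives in a package, e.g. as
# `ttt/train.py` next to `env.py` and `model.py`; the package name is hypothetical):
#
#     from ttt.train import fit
#
#     weights = fit(n_steps=10_000)
#     with open("policy.pt", "wb") as f:
#         f.write(weights)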