Add Python Documentation with Sphinx #1

Open · wants to merge 5 commits into base: master
68 changes: 48 additions & 20 deletions deep_rf/deep_rf_learner.py
@@ -14,7 +14,18 @@
import epsilon_method

class ExperienceTuple:
""" ExperienceTuple data structure for DeepRFLearner """
""" ExperienceTuple data structure for DeepRFLearner

A data structure for keeping track of <State, Action, Reward, Next State>
tuples. These tuples are used in learning the Q function.

Args:
state (State): a collection of frames
action (int): an action taken
        reward (float): measured reward for taking the selected action
next_state (State): the set of frames after taking the selected action

"""
def __init__(self, state, action, reward, next_state):
self.state = state
self.action = action
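For illustration, here is a minimal sketch of how these tuples might feed a replay buffer for Q-learning; the buffer and the dummy stand-in values are hypothetical, not part of the library:

import random

# Sketch only: object() stands in for real State instances.
replay_buffer = []
for step in range(100):
    state, action, reward, next_state = object(), 0, 1.0, object()
    replay_buffer.append(ExperienceTuple(state, action, reward, next_state))

batch = random.sample(replay_buffer, k=50)          # minibatch for Q updates
mean_reward = sum(et.reward for et in batch) / len(batch)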
@@ -48,21 +59,21 @@ class DeepRFLearner(object):
""" DeepRFLearner Class

Args:
game (SinglePlayerGame): single player game to learn
q_graph (QGraph): tensorflow Q graph
num_frames (int): number of frames to use as the state
reward_function (func):
            A function taking a dictionary of parameters and returning a float.
            Dict keys include:
'last_score', 'new_score', 'last_state', 'new_state', 'is_game_over'.
file_save_path:
        file_save_path (str): path to save location

Methods:
* get_next_experience_tuple
* choose_action
* evaluate_q_function
* learn_q_function
* save_tf_weights

"""

@@ -108,6 +119,7 @@ def __del__(self):


def save_tf_weights(self):
""" Save the Q-network weights """
if self.file_save_path is not None:
self._saver.save(self._sess, self.file_save_path)

@@ -123,6 +135,21 @@ def _init_training_loss_operation(self):

def learn_q_function(self, num_iterations=1000, batch_size=50,
num_training_steps=10):
""" Learn deep reinforcement learning Q function.

Train the Q function by repeatedly playing the single player game.

Args:
num_iterations (int): number of training iterations
        batch_size (int): number of experience tuples to add in each iteration
num_training_steps (int):
number of optimization steps to take in each iteration.

Returns:
None: updates the Q function

"""

# For Training Time
# Get next sample -> List of ExperienceTuples
# Get a minibatch -> partially Optimize Q for loss
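Continuing the earlier sketch: each iteration gathers a batch of experience tuples and takes several optimization steps, fitting Q toward the targets r + max Q(s') described under _get_target_values below.

# Sketch: train the Q function, then persist the weights.
learner.learn_q_function(num_iterations=1000,
                         batch_size=50,
                         num_training_steps=10)
learner.save_tf_weights()   # no-op unless file_save_path was given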
@@ -150,10 +177,10 @@ def learn_q_function(self, num_iterations=1000, batch_size=50,
def _get_target_values(self, experience_batch):
"""
Args:
        experience_batch (list): list of ExperienceTuples

Returns:
        y_target (ndarray): array of shape [batch_size] with entries r + max Q(s')
"""
rewards = np.array([et.reward for et in experience_batch])
states = [
@@ -173,8 +200,8 @@ def get_next_experience_tuple(self):

DeepRFLearner chooses an action based on the Q function and random exploration

Yields:
experience_tuple (ExperienceTuple): current state, action, reward, new_state
"""
while True:
self._game.reset()
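A brief sketch of consuming this generator directly (the learner name continues the examples above); it runs forever, resetting the game whenever an episode ends.

import itertools

# Sketch: pull a few transitions from the endless generator.
for et in itertools.islice(learner.get_next_experience_tuple(), 5):
    print(et.action, et.reward)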
@@ -210,10 +237,10 @@ def choose_action(self, state):
""" Return the action with the highest q_function value

Args:
        state (State): A State object or list of State objects

    Returns:
        actions (Action): the action or list of actions that maximize
the q_function for each state
"""
if isinstance(state, State):
@@ -235,11 +262,12 @@ def evaluate_q_function(self, state):
""" Return q_values for for given state(s)

Args:
        state (State): A State object or list of State objects

    Returns:
        q_values (ndarray): Either
            * An ndarray of size(action_list) for a State object
            * An ndarray of # States by size(action_list) for a list

"""
if isinstance(state, State):
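For completeness, a hedged sketch of querying the trained policy; the State constructor and the four-frame stack are assumptions, continuing the names from the earlier sketches.

# Sketch: query the learned policy. State(frames) is an assumed
# constructor, not the library's documented signature.
frames = [game.get_frame() for _ in range(4)]
state = State(frames)
best_action = learner.choose_action(state)      # action maximizing Q(s, a)
q_values = learner.evaluate_q_function(state)   # ndarray over action_list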
25 changes: 19 additions & 6 deletions deep_rf/q_graph.py
@@ -3,9 +3,19 @@


class QGraph(object):
"""
q_input: (tf.placeholder float [None, board_height, board_width, num_frames]) - tf placeholder for state
q_output: (tf.Tensor of action_values [None, num_actions]) - Q function output to evaluated with tf.run()
""" Data structure for the Q Function

Args:
q_input (tf.placeholder float [None, board_height, board_width, num_frames]):
tf placeholder for frame input
q_output (tf.Tensor of action_values [None, num_actions]):
            Q function output to be evaluated with sess.run()

Attributes:
        q_input (tf.placeholder): input to Q function
q_output (tf.Tensor): Q function output
graph (tf.Graph): tensorflow graph containing the Q function
var_list (list): list of names for Q function weights

"""
def __init__(self, q_input, q_output):
@@ -19,14 +29,17 @@
def default_q_graph(game, num_frames):
""" initialize Q function input & output

Parameters:
game (SinglePlayerGame): a game object
num_frames (int): number of past frames to keep in memory

Returns:
        QGraph: the default Q graph
"""

g = tf.Graph()

with g.as_default():

# input layer
q_input = tf.placeholder(dtype=tf.float32,
shape=[None, game.frame_height,
Expand Down Expand Up @@ -65,4 +78,4 @@ def default_q_graph(game, num_frames):
b_fc2 = _utils.init_fc_bias(length=len(game.action_list))
q_output = tf.matmul(h_fc1, w_fc2) + b_fc2

return QGraph(q_input, q_output)
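Besides the default graph above, a custom QGraph can be supplied as long as it obeys the documented shapes ([None, height, width, num_frames] in, [None, num_actions] out). A minimal linear sketch, with illustrative sizes and the same TF1-style API q_graph.py uses:

import tensorflow as tf

# Sketch: a hand-rolled linear Q graph (sizes are assumptions).
height, width, num_frames, num_actions = 8, 8, 4, 4
g = tf.Graph()
with g.as_default():
    q_input = tf.placeholder(tf.float32,
                             shape=[None, height, width, num_frames])
    flat = tf.reshape(q_input, [-1, height * width * num_frames])
    w = tf.Variable(tf.zeros([height * width * num_frames, num_actions]))
    b = tf.Variable(tf.zeros([num_actions]))
    q_output = tf.matmul(flat, w) + b
custom_graph = QGraph(q_input, q_output)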
40 changes: 39 additions & 1 deletion deep_rf/single_player_game.py
@@ -4,9 +4,25 @@

"""

class SinglePlayerGame:
""" A virtual class for single player games

This contains the class skeleton for a single player game.
To be used in deep reinforcement learning.

Args:
action_list (list): set of game actions
frame_width (int): non-negative size of game window width
frame_height (int): non-negative size of game window height

Attributes:
action_list (list): set of game actions
action_dict (dict): enumeration of action_list

See Also:
deep_rf.deep_rf_learner.DeepRFLearner
"""

def __init__(self, action_list, frame_height, frame_width):
self.action_list = action_list
self.action_dict = {self.action_list[i]: i for i in
@@ -16,25 +32,47 @@ def __init__(self, action_list, frame_height, frame_width):

@property
def frame_height(self):
""" frame_height (int): non-negative size of game window height """
return self._frame_height

@property
def frame_width(self):
""" frame_width (int): non-negative size of game window width """
return self._frame_width

@property
def score(self):
""" score (double): current score of game """
raise NotImplementedError('Subclass should define get_score()')

def do_action(self, action):
""" Apply player's selected action to current game.

Args:
action (Action): action to perform

Returns:
None: applies action to game
"""
raise NotImplementedError('Subclass should define do_action()')

def get_frame(self):
""" Return the pixels for the current game

Returns:
        frame (ndarray): the frame_height by frame_width game window.
"""
raise NotImplementedError('Subclass should define get_frame()')

def is_game_over(self):
""" Return whether the game has ended

Returns:
        is_game_over (bool): whether the game has ended
"""
raise NotImplementedError('Subclass should define is_game_over()')

def reset(self):
""" Start a new game. """
raise NotImplementedError('Subclass should define reset()')
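To make the interface concrete, here is a hedged sketch of a toy subclass; the game itself ("move a dot toward a goal cell") is purely illustrative, and it assumes the base __init__ stores the frame dimensions behind the properties above.

import numpy as np

# Sketch: a minimal toy game implementing the required interface.
class DotGame(SinglePlayerGame):
    def __init__(self):
        SinglePlayerGame.__init__(self, action_list=['left', 'right'],
                                  frame_height=1, frame_width=8)
        self.reset()

    @property
    def score(self):
        return self._score

    def do_action(self, action):
        step = 1 if action == 'right' else -1
        self._pos = max(0, min(self.frame_width - 1, self._pos + step))
        if self._pos == self._goal:
            self._score += 1.0

    def get_frame(self):
        frame = np.zeros((self.frame_height, self.frame_width))
        frame[0, self._pos] = 1.0
        return frame

    def is_game_over(self):
        return self._score >= 1.0

    def reset(self):
        self._pos, self._goal, self._score = 0, 5, 0.0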
