Skip to content

Commit 0882189

Browse files
committed
Add dqn agent
1 parent 6974fa4 commit 0882189

File tree

7 files changed

+71
-3
lines changed

7 files changed

+71
-3
lines changed

.DS_Store

0 Bytes
Binary file not shown.
2 KB
Binary file not shown.
0 Bytes
Binary file not shown.
Binary file not shown.

Deep Reinforcement Learning/Pytorch Implementations/DQN/dqn_agent.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# coding=utf-8
22
import math
33
import random
4+
import json
5+
import subprocess
46
from collections import namedtuple
57

68
import matplotlib
79
matplotlib.use('TkAgg')
810
import matplotlib.pyplot as plt
911
from matplotlib import animation
10-
# from IPython.display import display
11-
# from JSAnimation.IPython_display import display_animation
1212

1313
import numpy as np
1414
import torch
@@ -19,7 +19,6 @@
1919
from .dqn_agent_network import DqnAgentNetwork
2020
from .replay_memory import ReplayMemory
2121
from .utils import Utils
22-
# from moviepy.editor import ImageSequenceClip
2322

2423

2524
class DqnAgent:
@@ -76,6 +75,17 @@ def __init__(self,
7675

7776
self.utils = Utils()
7877

78+
gui_code = subprocess.call(["python", "gui.py"])
79+
s
80+
81+
with open("./params.json", "r") as infile:
82+
params_dict = json.load(infile)
83+
84+
self.Gamma = params_dict['Gamma']
85+
self.LearningRate = params_dict['Learning Rate']
86+
self.number_of_episodes = int(params_dict['Episodes'])
87+
self.BatchSize = int(params_dict['Batch Size'])
88+
7989
def train(self, rl_environment):
8090
"""
8191
This method trains the agent on the game environment
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from tkinter import *
2+
from tkinter import messagebox
3+
import json
4+
5+
fields = 'Episodes', 'Learning Rate', 'Gamma', 'Batch Size'
6+
7+
def fetch(entries):
8+
for entry in entries:
9+
field = entry[0]
10+
text = entry[1].get()
11+
print('%s: "%s"' % (field, text))
12+
13+
def makeform(root, fields):
14+
entries = []
15+
for field in fields:
16+
row = Frame(root)
17+
lab = Label(row, width=15, text=field, anchor='w')
18+
ent = Entry(row)
19+
row.pack(side=TOP, fill=X, padx=5, pady=5)
20+
lab.pack(side=LEFT)
21+
ent.pack(side=RIGHT, expand=YES, fill=X)
22+
entries.append((field, ent))
23+
return entries
24+
25+
def save_parameters_to_file(entries, root):
26+
params_dict = dict()
27+
28+
for entry in entries:
29+
field = entry[0]
30+
text = entry[1].get()
31+
if field == "Episodes":
32+
if int(text) < 5:
33+
messagebox.showinfo("Error", "Too few episodes. Please enter reasonable number of episodes")
34+
return
35+
elif int(text) > 100000:
36+
messagebox.showinfo("Error", "Too many episodes. Please enter reasonable number of episodes")
37+
return
38+
if text == "":
39+
messagebox.showinfo("Error", "Please enter a value for " + str(field))
40+
return
41+
params_dict[field] = float(text)
42+
43+
with open("./params.json", "w") as outfile:
44+
json.dump(params_dict, outfile)
45+
46+
messagebox.showinfo("Yayyy!", "The parameter values were successfully saved. Beginning training!")
47+
root.destroy()
48+
return
49+
50+
root = Tk()
51+
ents = makeform(root, fields)
52+
root.title("Reinforcement Learning")
53+
root.bind('<Return>', (lambda event, e = ents: fetch(e)))
54+
55+
b1 = Button(root, text = 'SUBMIT', command = (lambda e=ents: save_parameters_to_file(e, root)), bg = "blue")
56+
b1.pack(side = LEFT, padx = 5, pady = 5)
57+
root.mainloop()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"Episodes": 1000.0, "Learning Rate": 1.0, "Gamma": 1.0, "Batch Size": 1.0}

0 commit comments

Comments
 (0)