Skip to content

Commit

Permalink
feat!: run q-learning on graph
Browse files Browse the repository at this point in the history
  • Loading branch information
NTGNguyen committed Dec 22, 2024
1 parent d2de5d5 commit d7d2096
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 102 deletions.
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from src import MainWindow
from src import QLearningNX

# Build the application window (the string is passed as the window title)
# and hand control to the Tk event loop until the window is closed.
root = MainWindow("new")
root.mainloop()
26 changes: 24 additions & 2 deletions src/Graph/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from .graph import GraphRandomGenerate
from ..Q_Learning import QLearningNX

if TYPE_CHECKING:
from .frame import ButtonFrame, GraphFrame
from .window import MainWindow


class DrawRandomButton(ttk.Button):
"""The Button inherits from ttk.Button"""

def __init__(self, parent: 'ButtonFrame', graph_frame: 'GraphFrame'):
def __init__(self, parent: 'ButtonFrame', graph_frame: 'GraphFrame', window: 'MainWindow'):
"""The button to draw graph
Args:
Expand All @@ -25,13 +27,15 @@ def __init__(self, parent: 'ButtonFrame', graph_frame: 'GraphFrame'):
super().__init__(parent, text="Draw Random Graph", command=self.draw_random_graph)
self.pack(side=LEFT, padx=5, pady=5)
self.graph_frame: GraphFrame = graph_frame
self.window = window

def draw_random_graph(self):
"""Function to draw random graph
"""
for widget in self.graph_frame.winfo_children():
widget.destroy()
self.Gr = GraphRandomGenerate()
self.q_learning = QLearningNX(self.window, self.Gr.G, 9)
fig, ax = plt.subplots(figsize=(5, 4))
nx.draw(self.Gr.G, self.Gr.pos, ax=ax, with_labels=True, node_color=self.Gr.node_colors,
edge_color=self.Gr.edge_colors, node_size=500, font_size=10)
Expand All @@ -44,4 +48,22 @@ def draw_random_graph(self):
self.canvas_widget = self.canvas.get_tk_widget()
self.canvas_widget.pack(fill=BOTH, expand=True)
self.canvas.draw()
self.Gr.change_edge_color(1, 2)


class LearnButton(ttk.Button):
    """Button that starts Q-learning on the most recently drawn graph."""

    def __init__(self, parent: 'ButtonFrame', random_button: 'DrawRandomButton'):
        """Create the button and pack it into its parent frame.

        Args:
            parent (ButtonFrame): The frame that hosts this button.
            random_button (DrawRandomButton): The draw button whose
                ``q_learning`` attribute is read when this button is pressed.
        """
        super().__init__(parent, text="Start learn",
                         command=self.blearn)
        self.random_button = random_button
        self.pack(side=LEFT, padx=5, pady=5)

    def blearn(self):
        """Start the learning process if a graph has already been drawn.

        The draw button only assigns ``q_learning`` when a graph is drawn,
        so the attribute may not exist yet; in that case the click is
        ignored instead of raising ``AttributeError`` inside the Tk callback.
        """
        q_learning = getattr(self.random_button, "q_learning", None)
        if q_learning is None:
            # Nothing to learn on until a graph has been generated.
            return
        # Positional arguments are (threshold, learning rate, discount).
        q_learning.learn(0.5, 0.8, 0.8)
8 changes: 4 additions & 4 deletions src/Graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, default_node_color: str = 'skyblue', default_edge_color: str
self.edge_collection: plt.Collection
self.canvas: Any

def change_node_color(self, node_number: int) -> None:
def change_node_color(self, node_number: int, new_color: str) -> None:
"""Change node color with specific number of node
Args:
Expand All @@ -37,11 +37,11 @@ def change_node_color(self, node_number: int) -> None:
if not self.G:
return
node_id = list(self.G.nodes)[node_number]
self.node_colors[node_number] = "red"
self.node_colors[node_number] = new_color
self.node_collection.set_facecolor(self.node_colors)
self.canvas.draw_idle()

def change_edge_color(self, first_node_number: int, second_node_number: int) -> None:
def change_edge_color(self, first_node_number: int, second_node_number: int, new_color: str) -> None:
"""Change edge color with specific number of two nodes in connection
Args:
Expand All @@ -50,6 +50,6 @@ def change_edge_color(self, first_node_number: int, second_node_number: int) ->
"""
edge_index = list(self.G.edges).index((first_node_number, second_node_number)if (
first_node_number, second_node_number) in self.G.edges else (second_node_number, first_node_number))
self.edge_colors[edge_index] = "red"
self.edge_colors[edge_index] = new_color
self.edge_collection.set_edgecolor(self.edge_colors)
self.canvas.draw_idle()
28 changes: 28 additions & 0 deletions src/Graph/label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""A label to show epoch in current time"""

from tkinter import LEFT, ttk
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .frame import ButtonFrame
import time


class EpochLabel(ttk.Label):
    """Label that shows training progress as ``Epoch:<current>/<total>``."""

    # Total printed after the slash. Kept in one place so the initial text
    # and every later update always show the same total.
    TOTAL_EPOCHS = 50000

    def __init__(self, parent: "ButtonFrame"):
        """Create the label showing epoch 0 and pack it into its parent.

        Args:
            parent (ButtonFrame): The parent of this label
        """
        super().__init__(parent, text=self._format(0))
        self.pack(side=LEFT, padx=5, pady=5)

    def update_epoch(self, new_epoch: int) -> None:
        """Display a new current epoch.

        Args:
            new_epoch (int): The epoch number to display
        """
        self.configure(text=self._format(new_epoch))

    @classmethod
    def _format(cls, epoch: int) -> str:
        """Return the label text for the given epoch."""
        return f"Epoch:{epoch}/{cls.TOTAL_EPOCHS}"
8 changes: 6 additions & 2 deletions src/Graph/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import tkinter as tk

from ..utils import HEIGHT_WINDOW, WIDTH_WINDOW
from .button import DrawRandomButton
from .button import DrawRandomButton, LearnButton
from .frame import ButtonFrame, GraphFrame, MainFrame
from .label import EpochLabel


class MainWindow(tk.Tk):
Expand All @@ -22,4 +23,7 @@ def __init__(self, title: str):
self.graph_frame: GraphFrame = GraphFrame(self.main_frame)
self.button_frame: ButtonFrame = ButtonFrame(self.main_frame)
self.draw_random_button: DrawRandomButton = DrawRandomButton(
self.button_frame, self.graph_frame)
self.button_frame, self.graph_frame, self)
self.learn_button = LearnButton(
self.button_frame, self.draw_random_button)
self.epoch_label = EpochLabel(self.button_frame)
4 changes: 2 additions & 2 deletions src/Q_Learning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Modules"""
from .q_learning_networkx import GraphNetworkX, GraphNetworkXForQLearning
from .q_learning_networkx import QLearningNX

__all__: list[str] = ["GraphNetworkX", "GraphNetworkXForQLearning"]
__all__: list[str] = ["QLearningNX"]
148 changes: 62 additions & 86 deletions src/Q_Learning/q_learning_networkx.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,99 @@
"""Draw Graph with networkx and Find the shortest path behind the scene"""
import matplotlib.pyplot as plt
"""Q-Learning with networkx"""
from typing import TYPE_CHECKING
import networkx as nx
import numpy as np
import random
from numpy import matrix
import time
if TYPE_CHECKING:
from ..Graph import MainWindow

from ..utils import DEFAULT_EDGES_NETWORKX

class QLearningNX:
"""The QLearning with networkx"""

class GraphNetworkX:
"""A class using to draw Graph with NetWorkX"""

def __init__(self, edges: list[tuple[int, int]] = DEFAULT_EDGES_NETWORKX):
"""The Initialization of GraphNetworkX class
def __init__(self, Window: 'MainWindow', G: nx.Graph, target_node_number: int):
"""The Initialization of QLearningNX class
Args:
edges (list[tuple[int, int]], optional): List contains connections between two node. Defaults to DEFAULT_EDGES_NETWORKX.
"""
self.edges = edges
self.G = nx.Graph()
self.G.add_edges_from(self.edges)
self.pos = nx.spring_layout(self.G)

def draw_graph(self):
"""Function to draw graph
"""
nx.draw_networkx_nodes(self.G, self.pos)
nx.draw_networkx_edges(self.G, self.pos)
nx.draw_networkx_labels(self.G, self.pos)
plt.show()


class GraphNetworkXForQLearning:
"""A Class using to run Q-Learning Algorithm in the graph behind the scene"""

def __init__(self, G: GraphNetworkX, number_nodes: int, end_node_number: int):
"""The Initialization of GraphNetworkXForQLearning Class
Args:
G (GraphNetworkX): The Graph which this algorithm run on
number_nodes (int): number of node in graph
end_node_number (int): The target node
Window (MainWindow): The main window
G (nx.Graph): Graph
target_node_number (int): target node(final node)
"""
self.window: 'MainWindow' = Window
self.G = G
self.number_nodes = number_nodes
self.R: matrix = matrix(np.zeros(shape=(number_nodes, number_nodes)))
for x in self.G.G[end_node_number]:
self.R[x, end_node_number] = 100
self.Q: matrix = matrix(np.zeros(shape=(number_nodes, number_nodes)))
self.R: matrix = np.matrix(np.zeros(
shape=(self.G.number_of_nodes(), self.G.number_of_nodes())))
for x in self.G[target_node_number]:
self.R[x, 10] = 100
self.Q: matrix = np.matrix(np.zeros(
shape=(self.G.number_of_nodes(), self.G.number_of_nodes())))
self.Q -= 100
for node in self.G.G.nodes:
for x in self.G.G[node]:
for node in self.G.nodes:
for x in G[node]:
self.Q[node, x] = 0
self.Q[x, node] = 0

def next_number(self, start_node: int, threshold: float) -> int:
"""Find next node to forward
def next_number(self, start_node_number: int, threshold: float) -> int:
"""Find next node number step
Args:
start_node (int): From this node to
threshold (float): Threshold for considering of two options
start_node_number (int): From this node to next node
threshold (float): The threshole
Returns:
int: number of next node
int: the next node number
"""
random_value = np.random.uniform(0, 1)
random_value: float = random.uniform(0, 1)
if random_value < threshold:
sample = self.G.G[start_node]
sample = list(self.G.neighbors(start_node_number))
else:
sample = np.where(self.Q[start_node,] ==
np.max(self.Q[start_node,]))[1]
sample = np.where(self.Q[start_node_number,] == np.max(
self.Q[start_node_number,]))[0]
next_node: int = int(np.random.choice(sample, 1))
return next_node

def updateQ(self, start_node: int, next_node: int, learning_rate: float, discount: float) -> None:
"""Update Q-table of in each step
def update_Q(self, node1: int, node2: int, lr: float, discount: float) -> None:
"""Update Q table
Args:
start_node (int): From this node to
next_node (int): Next node espected to forwar
learning_rate (float): learning rate of learning
discount (float): discount of learning
node1 (int): first node
node2 (int): second node
lr (float): learning rate
discount (float): discount
"""
max_index = np.where(self.Q[next_node,] ==
np.max(self.Q[next_node,]))[1]
max_index = np.where(self.Q[node2,] == np.max(self.Q[node2]))[1]
if max_index.shape[0] > 1:
max_index = int(np.random.choice(max_index, size=1))
max_index: int = int(np.random.choice(max_index, size=1))
else:
max_index = int(max_index)
max_value: int = self.Q[next_node, max_index]
self.Q[start_node, next_node] = int(
(1 - learning_rate)*self.Q[start_node, next_node]+learning_rate*(self.R[start_node, next_node] + discount * max_value))
max_index: int = int(max_index)
max_value = self.Q[node2, max_index]
self.Q[node1, node2] = int(
(1-lr)*self.Q[node1, node2] + lr*(self.R[node1, node2]+discount*max_value))

def learn(self, threshold: float, learning_rate: float, discount: float) -> None:
"""Learning process
def run_epoch(self, threshold: float, lr: float, discount: float) -> None:
"""Action in each epoch
Args:
threshold (float): Threshold for considering of two options
learning_rate (float): learning rate of learning
discount (float): discount of learning
threshold (float): threshold
lr (float): learning rate
discount (float): discount
"""
for _ in range(50000):
start: int = np.random.randint(0, 11)
next_node: int = self.next_number(start, threshold)
self.updateQ(start, next_node, learning_rate, discount)
if self.epoch < 50000:
self.epoch += 1
self.window.epoch_label.update_epoch(self.epoch)
start_node: int = np.random.randint(0, self.G.number_of_nodes())
next_node = self.next_number(start_node, threshold)
self.update_Q(start_node, next_node, lr, discount)
self.window.after(100, self.run_epoch, threshold, lr, discount)

def shortest_path(self, start_node: int, end_node: int) -> list[int]:
"""Find shortest_path after trainning
def learn(self, threshold: float, lr: float, discount: float) -> None:
"""Learn algorithm
Args:
start_node (int): start from
end_node (int): target node
Returns:
list[int]: list contains each node illustating to shortest path
threshold (float): The threshold
lr (float): learning rate
discount (float): discount
"""
path: list[int] = [start_node]
next_node = np.argmax(self.Q[start_node,])
path.append(next_node)
while next_node != end_node:
next_node = np.argmax(self.Q[next_node,])
path.append(next_node)
return path
self.epoch = 0
self.run_epoch(threshold, lr, discount)
Empty file removed src/Q_Learning/q_learning_ttk.py
Empty file.
3 changes: 2 additions & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Modules"""
from .Graph import MainWindow
from .Q_Learning import QLearningNX

__all__: list[str] = ["MainWindow"]
__all__: list[str] = ["MainWindow", "QLearningNX"]
4 changes: 2 additions & 2 deletions src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Modules"""
from .constants import (CHECK_IMPORT_LATER, DEFAULT_EDGES_NETWORKX,
HEIGHT_WINDOW, WIDTH_WINDOW)
from .constants import (
HEIGHT_WINDOW, WIDTH_WINDOW)

__all__: list[str] = ["coordinate_random", "coordinate_list_random",
"WIDTH_WINDOW", "HEIGHT_WINDOW", "DEFAULT_EDGES_NETWORKX", "CHECK_IMPORT_LATER", "handle_edges_list", "random_edges_list"]
3 changes: 0 additions & 3 deletions src/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
WIDTH_WINDOW = 1500
HEIGHT_WINDOW = 750
DEFAULT_EDGES_NETWORKX: list[tuple[int, int]] = [(2, 5), (9, 1), (6, 1), (0, 6),
(3, 2), (3, 7), (9, 7), (3, 5), (7, 2), (9, 7)]
CHECK_IMPORT_LATER: bool = True
Empty file removed src/utils/coordinate.py
Empty file.
Empty file removed src/utils/edge.py
Empty file.
File renamed without changes.

0 comments on commit d7d2096

Please sign in to comment.