-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
131 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from src import MainWindow | ||
from src import QLearningNX | ||
|
||
root = MainWindow("new") | ||
root.mainloop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
"""A label to show epoch in current time""" | ||
|
||
from tkinter import LEFT, ttk | ||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from .frame import ButtonFrame | ||
import time | ||
|
||
|
||
class EpochLabel(ttk.Label): | ||
"""The EpochLabel inherits from ttk.Label""" | ||
|
||
def __init__(self, parent: "ButtonFrame"): | ||
"""The Initialization of this class | ||
Args: | ||
parent (ButtonFrame): The parent of this class | ||
""" | ||
super().__init__(parent, text="Epoch:0/50000") | ||
self.pack(side=LEFT, padx=5, pady=5) | ||
|
||
def update_epoch(self, new_epoch: int) -> None: | ||
"""Update new epoch | ||
Args: | ||
new_epoch (int): new epoch | ||
""" | ||
self.configure(text=f"Epoch:{new_epoch}/5000") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
"""Modules""" | ||
from .q_learning_networkx import GraphNetworkX, GraphNetworkXForQLearning | ||
from .q_learning_networkx import QLearningNX | ||
|
||
__all__: list[str] = ["GraphNetworkX", "GraphNetworkXForQLearning"] | ||
__all__: list[str] = ["QLearningNX"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,123 +1,99 @@ | ||
"""Draw Graph with networkx and Find the shortest path behind the scene""" | ||
import matplotlib.pyplot as plt | ||
"""Q-Learning with networkx""" | ||
from typing import TYPE_CHECKING | ||
import networkx as nx | ||
import numpy as np | ||
import random | ||
from numpy import matrix | ||
import time | ||
if TYPE_CHECKING: | ||
from ..Graph import MainWindow | ||
|
||
from ..utils import DEFAULT_EDGES_NETWORKX | ||
|
||
class QLearningNX: | ||
"""The QLearning with networkx""" | ||
|
||
class GraphNetworkX: | ||
"""A class using to draw Graph with NetWorkX""" | ||
|
||
def __init__(self, edges: list[tuple[int, int]] = DEFAULT_EDGES_NETWORKX): | ||
"""The Initialization of GraphNetworkX class | ||
def __init__(self, Window: 'MainWindow', G: nx.Graph, target_node_number: int): | ||
"""The Initialization of QLearningNX class | ||
Args: | ||
edges (list[tuple[int, int]], optional): List contains connections between two node. Defaults to DEFAULT_EDGES_NETWORKX. | ||
""" | ||
self.edges = edges | ||
self.G = nx.Graph() | ||
self.G.add_edges_from(self.edges) | ||
self.pos = nx.spring_layout(self.G) | ||
|
||
def draw_graph(self): | ||
"""Function to draw graph | ||
""" | ||
nx.draw_networkx_nodes(self.G, self.pos) | ||
nx.draw_networkx_edges(self.G, self.pos) | ||
nx.draw_networkx_labels(self.G, self.pos) | ||
plt.show() | ||
|
||
|
||
class GraphNetworkXForQLearning: | ||
"""A Class using to run Q-Learning Algorithm in the graph behind the scene""" | ||
|
||
def __init__(self, G: GraphNetworkX, number_nodes: int, end_node_number: int): | ||
"""The Initialization of GraphNetworkXForQLearning Class | ||
Args: | ||
G (GraphNetworkX): The Graph which this algorithm run on | ||
number_nodes (int): number of node in graph | ||
end_node_number (int): The target node | ||
Window (MainWindow): The main window | ||
G (nx.Graph): Graph | ||
target_node_number (int): target node(final node) | ||
""" | ||
self.window: 'MainWindow' = Window | ||
self.G = G | ||
self.number_nodes = number_nodes | ||
self.R: matrix = matrix(np.zeros(shape=(number_nodes, number_nodes))) | ||
for x in self.G.G[end_node_number]: | ||
self.R[x, end_node_number] = 100 | ||
self.Q: matrix = matrix(np.zeros(shape=(number_nodes, number_nodes))) | ||
self.R: matrix = np.matrix(np.zeros( | ||
shape=(self.G.number_of_nodes(), self.G.number_of_nodes()))) | ||
for x in self.G[target_node_number]: | ||
self.R[x, 10] = 100 | ||
self.Q: matrix = np.matrix(np.zeros( | ||
shape=(self.G.number_of_nodes(), self.G.number_of_nodes()))) | ||
self.Q -= 100 | ||
for node in self.G.G.nodes: | ||
for x in self.G.G[node]: | ||
for node in self.G.nodes: | ||
for x in G[node]: | ||
self.Q[node, x] = 0 | ||
self.Q[x, node] = 0 | ||
|
||
def next_number(self, start_node: int, threshold: float) -> int: | ||
"""Find next node to forward | ||
def next_number(self, start_node_number: int, threshold: float) -> int: | ||
"""Find next node number step | ||
Args: | ||
start_node (int): From this node to | ||
threshold (float): Threshold for considering of two options | ||
start_node_number (int): From this node to next node | ||
threshold (float): The threshole | ||
Returns: | ||
int: number of next node | ||
int: the next node number | ||
""" | ||
random_value = np.random.uniform(0, 1) | ||
random_value: float = random.uniform(0, 1) | ||
if random_value < threshold: | ||
sample = self.G.G[start_node] | ||
sample = list(self.G.neighbors(start_node_number)) | ||
else: | ||
sample = np.where(self.Q[start_node,] == | ||
np.max(self.Q[start_node,]))[1] | ||
sample = np.where(self.Q[start_node_number,] == np.max( | ||
self.Q[start_node_number,]))[0] | ||
next_node: int = int(np.random.choice(sample, 1)) | ||
return next_node | ||
|
||
def updateQ(self, start_node: int, next_node: int, learning_rate: float, discount: float) -> None: | ||
"""Update Q-table of in each step | ||
def update_Q(self, node1: int, node2: int, lr: float, discount: float) -> None: | ||
"""Update Q table | ||
Args: | ||
start_node (int): From this node to | ||
next_node (int): Next node espected to forwar | ||
learning_rate (float): learning rate of learning | ||
discount (float): discount of learning | ||
node1 (int): first node | ||
node2 (int): second node | ||
lr (float): learning rate | ||
discount (float): discount | ||
""" | ||
max_index = np.where(self.Q[next_node,] == | ||
np.max(self.Q[next_node,]))[1] | ||
max_index = np.where(self.Q[node2,] == np.max(self.Q[node2]))[1] | ||
if max_index.shape[0] > 1: | ||
max_index = int(np.random.choice(max_index, size=1)) | ||
max_index: int = int(np.random.choice(max_index, size=1)) | ||
else: | ||
max_index = int(max_index) | ||
max_value: int = self.Q[next_node, max_index] | ||
self.Q[start_node, next_node] = int( | ||
(1 - learning_rate)*self.Q[start_node, next_node]+learning_rate*(self.R[start_node, next_node] + discount * max_value)) | ||
max_index: int = int(max_index) | ||
max_value = self.Q[node2, max_index] | ||
self.Q[node1, node2] = int( | ||
(1-lr)*self.Q[node1, node2] + lr*(self.R[node1, node2]+discount*max_value)) | ||
|
||
def learn(self, threshold: float, learning_rate: float, discount: float) -> None: | ||
"""Learning process | ||
def run_epoch(self, threshold: float, lr: float, discount: float) -> None: | ||
"""Action in each epoch | ||
Args: | ||
threshold (float): Threshold for considering of two options | ||
learning_rate (float): learning rate of learning | ||
discount (float): discount of learning | ||
threshold (float): threshold | ||
lr (float): learning rate | ||
discount (float): discount | ||
""" | ||
for _ in range(50000): | ||
start: int = np.random.randint(0, 11) | ||
next_node: int = self.next_number(start, threshold) | ||
self.updateQ(start, next_node, learning_rate, discount) | ||
if self.epoch < 50000: | ||
self.epoch += 1 | ||
self.window.epoch_label.update_epoch(self.epoch) | ||
start_node: int = np.random.randint(0, self.G.number_of_nodes()) | ||
next_node = self.next_number(start_node, threshold) | ||
self.update_Q(start_node, next_node, lr, discount) | ||
self.window.after(100, self.run_epoch, threshold, lr, discount) | ||
|
||
def shortest_path(self, start_node: int, end_node: int) -> list[int]: | ||
"""Find shortest_path after trainning | ||
def learn(self, threshold: float, lr: float, discount: float) -> None: | ||
"""Learn algorithm | ||
Args: | ||
start_node (int): start from | ||
end_node (int): target node | ||
Returns: | ||
list[int]: list contains each node illustating to shortest path | ||
threshold (float): The threshold | ||
lr (float): learning rate | ||
discount (float): discount | ||
""" | ||
path: list[int] = [start_node] | ||
next_node = np.argmax(self.Q[start_node,]) | ||
path.append(next_node) | ||
while next_node != end_node: | ||
next_node = np.argmax(self.Q[next_node,]) | ||
path.append(next_node) | ||
return path | ||
self.epoch = 0 | ||
self.run_epoch(threshold, lr, discount) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""Modules""" | ||
from .Graph import MainWindow | ||
from .Q_Learning import QLearningNX | ||
|
||
__all__: list[str] = ["MainWindow"] | ||
__all__: list[str] = ["MainWindow", "QLearningNX"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
"""Modules""" | ||
from .constants import (CHECK_IMPORT_LATER, DEFAULT_EDGES_NETWORKX, | ||
HEIGHT_WINDOW, WIDTH_WINDOW) | ||
from .constants import ( | ||
HEIGHT_WINDOW, WIDTH_WINDOW) | ||
|
||
__all__: list[str] = ["coordinate_random", "coordinate_list_random", | ||
"WIDTH_WINDOW", "HEIGHT_WINDOW", "DEFAULT_EDGES_NETWORKX", "CHECK_IMPORT_LATER", "handle_edges_list", "random_edges_list"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,2 @@ | ||
WIDTH_WINDOW = 1500 | ||
HEIGHT_WINDOW = 750 | ||
DEFAULT_EDGES_NETWORKX: list[tuple[int, int]] = [(2, 5), (9, 1), (6, 1), (0, 6), | ||
(3, 2), (3, 7), (9, 7), (3, 5), (7, 2), (9, 7)] | ||
CHECK_IMPORT_LATER: bool = True |
Empty file.
Empty file.
File renamed without changes.