-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
109 lines (92 loc) · 3.82 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
import chess
import chess.svg
from PyQt5.QtWidgets import QApplication
from screen2fen.Classifier import NNSquareClassifier, DefaultBoardClassifier
from screen2fen.MainWindow import MainWindow
from screen2fen.MessageBus import MessageBus
from screen2fen.TrainThread import TrainThread
from screen2fen.ScreenGrabThread import ScreenGrabThread
def qt_app():
    """Build and run the screen2fen Qt application.

    Wires together:
      * a ``MessageBus`` shared with ``MainWindow`` for UI events
        ('update_rectangle', 'fen_flags', 'start_training', 'stop_training'),
      * a ``TrainThread`` that retrains the square classifier, and
      * a ``ScreenGrabThread`` that classifies the selected screen region and
        emits board FEN strings via ``updated_signal``.

    Every board FEN is combined with the game-state flags held here (side to
    move, castling rights, en passant square, halfmove/fullmove counters) and
    displayed both as text and as an SVG board.

    Does not return normally: exits the process via ``sys.exit`` with the
    return code of the Qt event loop.
    """
    app = QApplication(sys.argv)
    bus = MessageBus()

    # Single source of truth for the model path: it is loaded at startup and
    # reloaded after each training run.
    # NOTE(review): relative to the current working directory.
    model_path = 'models/model.keras'
    square_classifier = NNSquareClassifier(model_path)
    board_classifier = DefaultBoardClassifier(square_classifier)

    # FEN fields describing the game state that are not part of the board
    # image; updated through the 'fen_flags' bus event.
    # NOTE(review): the standard start position uses a halfmove clock of 0;
    # 'fifty_move' defaults to '1' here — confirm intended.
    fen_flags = {
        'castle_wk': 'K',
        'castle_wq': 'Q',
        'castle_bk': 'k',
        'castle_bq': 'q',
        'enpassant': '-',
        'turn': 'w',
        'move_number': '1',
        'fifty_move': '1',
    }

    # Last FEN received from the board, kept so the displayed FEN can be
    # rebuilt when only the flags change.
    last_fen = None

    def update_rectangle(rect):
        """Bus handler: point both worker threads at the selected region."""
        # presumably rect[0]/rect[1] are the region's coordinates — TODO confirm
        screengrab_thread.set_coords(rect[0], rect[1])
        train_thread.set_coords(rect[0], rect[1])
        main.enable_train_button()
        main.hide_transparent_window()

    bus.subscribe('update_rectangle', update_rectangle)

    def fen_flags_updated(flags):
        """Bus handler: merge changed game-state flags and refresh the FEN."""
        # In-place mutation of the enclosing dict needs no ``nonlocal``.
        fen_flags.update(flags)
        update_fen()

    bus.subscribe('fen_flags', fen_flags_updated)

    def update_fen(fen=None):
        """Redraw the board image and FEN text.

        Args:
            fen: FEN (only its first, piece-placement field is used) emitted
                by the screen-grab thread. When falsy, the last FEN seen is
                reused, e.g. when only the flags changed.

        Does nothing until a board FEN has been seen at least once. If the
        combined FEN cannot be parsed, the error is shown in the message box
        instead of being raised out of the callback.
        """
        nonlocal last_fen
        if not fen:
            fen = last_fen
        if not fen:
            # Flags changed before any board was grabbed: nothing to show yet.
            return
        last_fen = fen
        full_fen = format_fen(fen, fen_flags)
        try:
            board = chess.Board(full_fen)
        except ValueError as err:
            # A user-supplied flag (e.g. a non-numeric counter) or a
            # misclassified board can yield an unparsable FEN.
            main.set_message_box_text(f"{full_fen}\nInvalid FEN: {err}")
            return
        main.set_board_image(chess.svg.board(board, size=128))
        main.set_message_box_text(full_fen)

    def format_fen(fen, flags):
        """Return *fen*'s piece-placement field followed by *flags*.

        Only the first space-separated field of *fen* is kept. Side to move,
        castling rights ('-' when no right remains), en passant square,
        halfmove clock and fullmove number all come from *flags*.
        """
        placement = fen.split(' ')[0]
        castling = (flags['castle_wk'] + flags['castle_wq']
                    + flags['castle_bk'] + flags['castle_bq']) or '-'
        # f-string (not str.join) so non-str counter values still format.
        return (f"{placement} {flags['turn']} {castling} "
                f"{flags['enpassant']} {flags['fifty_move']} "
                f"{flags['move_number']}")

    def start_training(_):
        """Bus handler: start training on the training thread."""
        train_thread.train()

    bus.subscribe('start_training', start_training)

    def stop_training(_):
        """Bus handler: pause the training thread."""
        train_thread.pause()

    bus.subscribe('stop_training', stop_training)

    def update_training_progress(progress: str):
        """Show a training progress message in the main window."""
        main.set_message_box_text(progress)

    def done_training(status: int):
        """Reload the retrained model and hand it to the screen grabber."""
        # NOTE(review): ``status`` is ignored, so the model is reloaded even
        # if training reported a failure — confirm that is intended.
        nonlocal square_classifier, board_classifier
        square_classifier = NNSquareClassifier(model_path)
        board_classifier = DefaultBoardClassifier(square_classifier)
        screengrab_thread.set_board_classifier(board_classifier)
        # presumably clear() drops cached state so the board is
        # re-classified with the new model — TODO confirm
        screengrab_thread.clear()
        main.end_training()

    train_thread = TrainThread()
    train_thread.updated_signal.connect(update_training_progress)
    train_thread.done_signal.connect(done_training)
    train_thread.start()

    screengrab_thread = ScreenGrabThread(board_classifier)
    screengrab_thread.updated_signal.connect(update_fen)
    screengrab_thread.start()

    main = MainWindow(bus)
    main.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    qt_app()