-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
122 lines (98 loc) · 3.66 KB
/
Copy pathclassifier.py
File metadata and controls
122 lines (98 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import os
import warnings
from argparse import ArgumentParser, Namespace
from typing import Dict, List, Optional, Tuple

import torch
import pytorch_lightning as pl

from src.data_helper import Data
from src.mapping_helper import StandardMap
from src.utils import (
    read_yaml,
    get_inference_folders,
    plot_labeled_data,
    subplot_labeled_data,
)
# Drop every warning attributed to pytorch_lightning (``module`` is a regex
# matched against the start of the warning's module name, so submodules such
# as ``pytorch_lightning.trainer`` are covered too).
warnings.filterwarnings("ignore", module="pytorch_lightning")

# Level 0 is ``logging.NOTSET``: the "pytorch_lightning" logger then stops
# using its own level and defers to its ancestors' effective level.
logging.getLogger("pytorch_lightning").setLevel(logging.NOTSET)

# Fix the global RNG seeds (including dataloader workers) for reproducible runs.
pl.seed_everything(42, workers=True)
def main(args: Namespace) -> None:
    """Run inference for every selected experiment folder and plot the results.

    For each folder returned by ``get_inference_folders`` this reads the
    folder's ``hparams.yaml``, applies a fixed set of parameter overrides,
    calls :func:`inference`, prints the evaluation metrics, and passes the
    ground-truth spectrum and the predicted labels to ``subplot_labeled_data``
    with ``save_path=<log_path>/results``.

    Args:
        args: Parsed command-line arguments. ``args.version`` is the list of
            integers given to ``--version``/``-v`` (``nargs="*"``); ``None`` or
            an empty list is passed on to ``get_inference_folders`` as ``None``.
    """
    # --version is declared with nargs="*", so this is a list, not a single int.
    version: Optional[List[int]] = args.version or None
    directory_path: str = "logs/paper/"
    # Alternative experiment root: "logs/improve_paper/"
    folders = get_inference_folders(directory_path, version)

    for log_path in folders:
        print(f"log_path: {log_path}")
        params_path: str = os.path.join(log_path, "hparams.yaml")
        params: dict = read_yaml(params_path)

        # Overrides merged into the saved hyperparameters by ``inference``.
        # Alternative sweep: "K": [3.0, 3.1, 3.2, 3.3, 3.4, 3.5]
        params_update: Dict = {
            "seq_len": 20,
            "K": 1.5,
            "init_points": 200,
            "steps": 100,
        }

        # Forwarded to ``Data``. An alternative that was used before:
        # StandardMap(init_points=100, steps=60, K=1.5, seed=42, sampling="random")
        map_object = None
        data = "testing_data"  # or "training_data"

        predictions, datamodule = inference(
            log_path, params, params_update, map_object=map_object, data=data
        )

        # Prefer the override; fall back to the (already merged) hyperparameters.
        K = params_update.get("K")
        if K is None:
            K = params.get("K")

        # (metric key, format spec) in print order; output matches the original line.
        metric_formats = (
            ("loss", ".2e"),
            ("accuracy", ".2f"),
            ("f1", ".2f"),
            ("precision", ".2f"),
            ("recall", ".2f"),
            ("specificity", ".2f"),
            ("balanced_accuracy", ".2f"),
        )
        summary = ", ".join(
            f"{name} = {predictions[name]:{fmt}}" for name, fmt in metric_formats
        )
        print(f"rnn_type={params.get('rnn_type')}, {K = }: {summary}")
        print()

        # Plot titles are Slovenian ("Resnica" = truth, "Napoved" = prediction);
        # the UT value shown is the balanced accuracy.
        subplot_labeled_data(
            datamodule.thetas_orig,
            datamodule.ps_orig,
            [datamodule.spectrum, predictions["predicted_labels"]],
            ["Resnica", f"Napoved (UT = {predictions['balanced_accuracy']:.2f})"],
            save_path=os.path.join(log_path, "results"),
        )
def _resolve_model_class(rnn_type: Optional[str]) -> type:
    """Return the model class for ``rnn_type``, importing only that module.

    Raises:
        ValueError: if ``rnn_type`` is not "vanillarnn", "mgu" or "resrnn".
    """
    if rnn_type == "vanillarnn":
        from src.VanillaRNN import Vanilla as Model
    elif rnn_type == "mgu":
        from src.MGU import MGU as Model
    elif rnn_type == "resrnn":
        from src.ResRNN import ResRNN as Model
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            f"Unsupported rnn_type {rnn_type!r}; "
            "expected one of 'vanillarnn', 'mgu', 'resrnn'."
        )
    return Model


def inference(
    log_path: str,
    params: Dict,
    params_update: Optional[Dict] = None,
    map_object: Optional[StandardMap] = None,
    data: str = "training_data",
) -> Tuple[Dict, Data]:
    """Load ``<log_path>/model.ckpt`` and run ``pl.Trainer.predict`` on a dataset.

    Args:
        log_path: Experiment folder containing ``model.ckpt``.
        params: Hyperparameter dict. **Mutated in place**: ``params_update`` is
            merged into it. ``params["rnn_type"]`` selects the model class,
            ``params["K"]`` is passed to ``Data`` and ``params["precision"]`` to
            ``pl.Trainer``; the whole dict is also passed to ``Data``.
        params_update: Optional overrides merged into ``params`` first (used to
            reduce the amount of loaded data).
        map_object: Forwarded to ``Data`` as ``map_object``.
        data: Forwarded to ``Data`` as ``data_path``.

    Returns:
        ``(predictions, datamodule)``: ``predictions`` is the first element of
        the list returned by ``trainer.predict`` and ``datamodule`` is the
        ``Data`` instance predictions were made on.

    Raises:
        ValueError: if ``params["rnn_type"]`` is not a supported model name.
    """
    # NOTE: overrides are applied here so they reduce the loaded data below.
    if params_update is not None:
        params.update(params_update)

    Model = _resolve_model_class(params.get("rnn_type"))
    model_path: str = os.path.join(log_path, "model.ckpt")
    # load_from_checkpoint is a classmethod: the old ``Model(**params).load_from_checkpoint``
    # built and discarded a full model (and raises TypeError on Lightning >= 2.0).
    # The model is rebuilt from the hyperparameters stored in the checkpoint, as before.
    # NOTE(review): skipping the throwaway construction no longer advances the
    # global RNG before ``Data`` is built; confirm Data sampling is unaffected.
    model = Model.load_from_checkpoint(model_path, map_location="cpu")
    model.eval()

    datamodule = Data(
        data_path=data,
        map_object=map_object,
        K=params.get("K"),
        train_size=1.0,
        binary=True,
        reduce_init_points=True,
        params=params,
    )
    trainer = pl.Trainer(
        precision=params.get("precision"),
        enable_progress_bar=True,
        logger=False,
        deterministic=True,
    )
    # Only the first element is kept; presumably the predict dataloader yields
    # a single batch whose predict_step output is a metrics dict — TODO confirm.
    predictions: dict = trainer.predict(model=model, dataloaders=datamodule)[0]
    return predictions, datamodule
if __name__ == "__main__":
    # Command line: zero or more integer run versions via --version / -v.
    cli = ArgumentParser()
    cli.add_argument("--version", "-v", type=int, default=None, nargs="*")
    main(cli.parse_args())