-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
99 lines (85 loc) · 5.1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import time
import torch
import pandas as pd
from tqdm import tqdm
from util import util
#from util.evaluator import IC15Evaluator
from util.evaluator_vis import IC15Evaluator
import cv2
import json
import copy
import numpy as np
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element
import warnings
warnings.filterwarnings("ignore")
if __name__ == '__main__':
opt = TestOptions().parse() # get test options
util.init_distributed_mode(opt)
# hard-code some parameters for test
# opt.num_threads = 0 # test code only supports num_threads = 1
# opt.batch_size = 1 # test code only supports batch_size = 1
# opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
# opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
# evaluator = IC15Evaluator(opt)
test_size = len(dataset)
print('The number of test images = %d. Testset: %s' % (test_size, opt.dataroot))
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
gen_gt = False # True
MODE = opt.phase # 'test' # val
DATASET_NAME = opt.dataset_name
model.eval()
n_correct_row, n_correct_col = 0, 0
n_total_row, n_total_col = 0, 0
n_correct_precision_row, n_correct_precision_col = 0, 0
n_total_precision_row, n_total_precision_col = 0, 0
n_correct_recall_row, n_correct_recall_col = 0, 0
n_total_recall_row, n_total_recall_col = 0, 0
for data in tqdm(dataset):
torch.cuda.synchronize()
model.set_input(data)
cls_pred_row, cls_pred_col = model.test()
row_label, col_label = data.y_row.detach().cpu().numpy(), data.y_col.detach().cpu().numpy()
cls_pred_row, cls_pred_col = cls_pred_row.detach().cpu().numpy(), cls_pred_col.detach().cpu().numpy()
row_label_3x = row_label * 3
diff_row = row_label_3x - cls_pred_row
col_label_3x = col_label * 3
diff_col = col_label_3x - cls_pred_col
n_correct_row = n_correct_row + (row_label == cls_pred_row).sum()
n_total_row = n_total_row + row_label.shape[0]
n_correct_precision_row = n_correct_precision_row + (diff_row == 2).sum()
n_total_precision_row = n_total_precision_row + (cls_pred_row == 1).sum()
n_correct_recall_row = n_correct_recall_row + (diff_row == 2).sum()
n_total_recall_row = n_total_recall_row + (row_label == 1).sum()
n_correct_col = n_correct_col + (col_label == cls_pred_col).sum()
n_total_col = n_total_col + col_label.shape[0]
n_correct_precision_col = n_correct_precision_col + (diff_col == 2).sum()
n_total_precision_col = n_total_precision_col + (cls_pred_col == 1).sum()
n_correct_recall_col = n_correct_recall_col + (diff_col == 2).sum()
n_total_recall_col = n_total_recall_col + (col_label == 1).sum()
# print(n_correct_col, n_total_col, n_correct_precision_col, n_total_precision_col, n_correct_recall_col, n_total_recall_col)
accuracy_row = n_correct_row / float(n_total_row)
precicion_row = n_correct_precision_row / float(n_total_precision_row) if n_total_precision_row != 0 else 0
recall_row = n_correct_recall_row / float(n_total_recall_row) if n_total_recall_row != 0 else 0
F1_row = 2 * precicion_row * recall_row / (precicion_row + recall_row) if precicion_row != 0 or recall_row != 0 else 0
print(n_correct_row, n_total_row, n_correct_precision_row, n_total_precision_row, n_correct_recall_row, n_total_recall_row)
accuracy_col = n_correct_col / float(n_total_col)
precicion_col = n_correct_precision_col / float(n_total_precision_col) if n_total_precision_col != 0 else 0
recall_col = n_correct_recall_col / float(n_total_recall_col) if n_total_recall_col != 0 else 0
F1_col = 2 * precicion_col * recall_col / (precicion_col + recall_col) if precicion_col != 0 or recall_col != 0 else 0
print(n_correct_col, n_total_col, n_correct_precision_col, n_total_precision_col, n_correct_recall_col, n_total_recall_col)
accuracy = (n_correct_row + n_correct_col) / (n_total_row + n_total_col)
precision = (n_correct_precision_row + n_correct_precision_col) / (n_total_precision_row + n_total_precision_col)
recall = (n_correct_recall_row + n_correct_recall_col) / (n_total_recall_row + n_total_recall_col)
F1 = 2 * precision * recall / (precision + recall)
print('accuray_total_row: %f, precision_row: %f, recall_row: %f, F1_row: %f' % (accuracy_row, precicion_row, recall_row, F1_row))
print('accuray_total_col: %f, precision_col: %f, recall_col: %f, F1_col: %f' % (accuracy_col, precicion_col, recall_col, F1_col))
print('accuray_total_all: %f, precision_all: %f, recall_all: %f, F1_all: %f' % (accuracy, precision, recall, F1))