-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathPitu_test.py
102 lines (81 loc) · 3.23 KB
/
Pitu_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import random
from lib.models.seg_hrnet import HighResolutionNet
from lib.datasets.pituitary import PitDataset
from lib.utils.utils import create_logger
from lib.core.function_v4 import test
from lib.config import update_config
from lib.config import config
import torch.optim
import torch
import numpy as np
import timeit
import pprint
import argparse
# Seed every RNG the evaluation pipeline may touch so repeated runs are
# reproducible.
seed = 2
random.seed(seed)  # Python `random` module.
np.random.seed(seed)  # NumPy global RNG.
# torch.manual_seed seeds the CPU generator (and, per the PyTorch docs, all
# devices); manual_seed_all is spelled out because the model is wrapped in
# DataParallel and may run on several GPUs. It is a lazy no-op without CUDA.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Prefer deterministic cuDNN kernels over autotuned (possibly
# non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def parse_args():
    """Build the CLI for this evaluation script and apply it to ``config``.

    Options:
        --cfg         experiment configuration file (hard-coded default).
        --model       path to a trained checkpoint (hard-coded default).
        --local_rank  integer, default 0; accepted but not read in this file.
        opts          all remaining tokens, used to modify config options.

    Returns:
        The parsed ``argparse.Namespace``. The global ``config`` has already
        been updated from it via ``update_config`` when this returns.
    """
    # NOTE(review): the default --cfg is the fold1 / epoch500 experiment while
    # the default --model is a fold5 / epoch350 checkpoint; confirm the two
    # are really meant to be evaluated together.
    default_cfg = r'./experiments/pituitary/seg_hrnet_w48_train_736x1280_sgd_lr1e-2_bs_6_epoch500_4loss_2stage_v4_fold1.yaml'
    default_model = r'/workspace/zhmao/data/HRNet_with_DICE_BD_Wing_FL_2stage/pituitary/seg_hrnet_w48_train_736x1280_sgd_lr1e-2_bs_6_epoch350_4loss_2stage_v4_fold5/train_best_mIoU.pth'

    cli = argparse.ArgumentParser(
        description='Train segmentation and landmark detection network')
    cli.add_argument('--cfg', type=str, default=default_cfg,
                     help='experiment configure file name')
    cli.add_argument('--model', type=str, default=default_model,
                     help='trained model file')
    cli.add_argument("--local_rank", type=int, default=0)
    cli.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                     help="Modify config options using the command-line")

    namespace = cli.parse_args()
    update_config(config, namespace)
    return namespace
def main():
    """Evaluate a trained HRNet checkpoint on the pituitary dataset.

    Parses the CLI (which also updates the global ``config``), builds a
    ``HighResolutionNet`` initialised from ``config.MODEL.PRETRAINED``,
    overlays the checkpoint given by ``--model`` when one is set, and runs
    ``test`` over ``PitDataset(config, is_train=False)`` with
    ``final_output_dir`` passed as ``sv_dir``. Elapsed whole minutes are
    logged at the end.
    """
    args = parse_args()

    logger, final_output_dir, _ = create_logger(config, args.cfg, 'test')
    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # build model
    model = HighResolutionNet(config)
    model.init_weights(config.MODEL.PRETRAINED)
    if args.model:
        logger.info('=> loading trained model {}'.format(args.model))
        # map_location remaps stored tensors onto `device`, so a checkpoint
        # saved on GPU can still be loaded when falling back to CPU.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        trained_dict = torch.load(args.model, map_location=device)
        model_dict = model.state_dict()
        # Keep only entries whose names exist in the model. Report the rest:
        # a key-name mismatch (e.g. a 'module.' prefix) would otherwise
        # silently leave those parameters at their initial weights.
        matched = {k: v for k, v in trained_dict.items() if k in model_dict}
        skipped = sorted(k for k in trained_dict if k not in model_dict)
        logger.info('=> loaded %d/%d model tensors from checkpoint',
                    len(matched), len(model_dict))
        if skipped:
            logger.warning('=> ignored %d checkpoint keys not in model, e.g. %s',
                           len(skipped), skipped[:5])
        model_dict.update(matched)
        model.load_state_dict(model_dict)
    else:
        logger.info('No trained model is found, you are using pretrained model')
    model = model.to(device)
    model = torch.nn.DataParallel(model)

    # prepare data
    test_dataset = PitDataset(config, is_train=False)
    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.WORKERS,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=device.type == 'cuda')

    start = timeit.default_timer()
    test(testloader, model, sv_dir=final_output_dir, device=device)
    end = timeit.default_timer()
    # int() truncates toward zero, matching the original np.int32 cast.
    logger.info('Mins: %d', int((end - start) / 60))
    logger.info('Done')
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()