-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathoptimalTemperature.py
More file actions
112 lines (96 loc) · 3.92 KB
/
optimalTemperature.py
File metadata and controls
112 lines (96 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
import os
import cv2
import glob
import torch
import torch.nn as nn
import random
from PIL import Image
import numpy as np
from erfnet import ERFNet
import os.path as osp
from argparse import ArgumentParser
from ood_metrics import fpr_at_95_tpr, calc_metrics, plot_roc, plot_pr,plot_barcode
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve, average_precision_score
from temperature_scaling import ModelWithTemperature
from dataset import cityscapesTemperature
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor, ToPILImage
from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
from dataset import VOC12
from transform import Relabel, ToLabel, Colorize
# General reproducibility: seed every RNG this script touches so repeated
# runs produce the same results.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Model dimensions (NUM_CLASSES is the number of classes ERFNet is built with).
NUM_CHANNELS = 3
NUM_CLASSES = 20

# GPU-specific reproducibility.  ``benchmark`` must stay disabled: when it is
# on, cuDNN auto-tunes and may choose different convolution algorithms from
# run to run, which defeats ``deterministic = True``.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def main():
    """Calibrate a pretrained ERFNet with temperature scaling on Cityscapes.

    Parses command-line options, builds ``ERFNet(NUM_CLASSES)``, loads the
    pretrained weights into it, wraps it in ``ModelWithTemperature`` and calls
    ``set_temperature`` with a ``DataLoader`` over ``cityscapesTemperature``.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input",
        # A one-element list so the default has the same type as a value
        # parsed through nargs="+"; ``args.input[0]`` is then the full pattern
        # rather than its first character.
        default=["/home/shyam/Mask2Former/unk-eval/RoadObsticle21/images/*.webp"],
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument('--loadDir', default="../trained_models/")
    parser.add_argument('--loadWeights', default="erfnet_pretrained.pth")
    parser.add_argument('--loadModel', default="erfnet.py")
    parser.add_argument('--subset', default="val")  # can be val or train (must have labels)
    parser.add_argument('--datadir', default="/home/shyam/ViT-Adapter/segmentation/data/cityscapes/")
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--discriminant', default="msp")
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--temperature', type=float, default=1)
    args = parser.parse_args()

    # Ensure results.txt exists.  Nothing in this function writes to it, so
    # the file is created (if missing) and closed immediately instead of
    # keeping an open handle around.
    open('results.txt', 'a').close()

    # osp.join also copes with a --loadDir given without a trailing slash.
    modelpath = osp.join(args.loadDir, args.loadModel)
    weightspath = osp.join(args.loadDir, args.loadWeights)
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = ERFNet(NUM_CLASSES)
    if not args.cpu:
        model = torch.nn.DataParallel(model).cuda()

    def load_my_state_dict(model, state_dict):
        """Copy tensors from ``state_dict`` into ``model`` in place; return ``model``.

        Keys that differ from the model's own keys only by one leading
        ``"module."`` (the prefix ``torch.nn.DataParallel`` gives parameter
        names, since it stores the wrapped network as ``.module``) are mapped
        in either direction.  Any other unmatched key is reported and skipped.
        """
        prefix = "module."
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                target = name
            elif name.startswith(prefix) and name[len(prefix):] in own_state:
                target = name[len(prefix):]  # checkpoint wrapped, model not
            elif prefix + name in own_state:
                target = prefix + name  # model wrapped, checkpoint not
            else:
                print(name, " not loaded")
                continue
            own_state[target].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath, map_location=lambda storage, loc: storage))
    print("Model and weights LOADED successfully")
    model.eval()

    model_to_optimize = ModelWithTemperature(model)
    print(args.input[0])

    # NOTE(review): ImageNet mean/std and a square 256x256 resize -- confirm
    # these match the preprocessing the ERFNet weights were trained with.
    input_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Nearest-neighbour resizing keeps label values from being interpolated.
    # NOTE(review): ToTensor() rescales uint8 images to [0, 1], which would
    # turn integer class ids into fractions; confirm cityscapesTemperature /
    # ModelWithTemperature expect that -- ToLabel (imported from transform)
    # may be what was intended here.
    target_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor(),
    ])

    # NOTE(review): the dataset root is hard-coded; --datadir (whose default
    # points elsewhere) is not used here.
    dataset_val_cityscapes = cityscapesTemperature(
        "../dataset/Cityscapes", input_transform, target_transform, subset=args.subset
    )
    loader = DataLoader(
        dataset_val_cityscapes,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=False,
    )
    model_to_optimize.set_temperature(loader)
    print("Done!")
if __name__ == "__main__":
    # Run the calibration only when executed as a script, not when imported.
    main()