-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
122 lines (105 loc) · 4.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import time
from glob import glob
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
# Default compute device: the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# 统计loss和top1Acc
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
if self.count > 0:
self.avg = self.sum / self.count
def accumulate(self, val, n=1):
self.sum += val
self.count += n
if self.count > 0:
self.avg = self.sum / self.count
# 获取训练集train_loader
def get_train_loader(train_path, batch_size=256, num_workers=4):
train_dataset = torchvision.datasets.CIFAR100(root=train_path,
train=True,
download=False,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
]))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return train_loader
# 获取测试集test_loader
def get_test_loader(test_path, batch_size=256, num_workers=4):
test_dataset = torchvision.datasets.CIFAR100(root=test_path,
train=False,
download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
]))
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers)
return test_loader
# 测试网络
def eval_training(model_name,net,test_loader,device):
net.eval() #作用之一:停止dropout层生效
test_loss = 0.0
correct = 0.0
# 定义交叉熵损失函数
lossCE = nn.CrossEntropyLoss()
infer_time = 0.0
for (images, labels) in test_loader:
images, labels = images.to(device), labels.to(device)
with torch.no_grad(): #停止梯度计算
start = time.time()
outputs = net(images)
infer_time += time.time() - start
loss = lossCE(outputs, labels).to(device)
test_loss += loss.item()
_, preds = outputs.max(1)
correct += preds.eq(labels).sum()
loss = test_loss / len(test_loader.dataset)
acc = (100. * correct.float()) / len(test_loader.dataset)
print('test loss:{}, top1_acc:{}'.format(loss,acc))
print('test time:{}s'.format(infer_time))
# 保留三位有效数字
acc = round(acc.item(),3)
test_time = round(infer_time,3)
file = "result/" + model_name + "_" + str(acc) + "_"+ str(test_time) +"s.pt"
torch.save(net, file)
class MyData(torch.utils.data.Dataset):
def __init__(self,transform = None):
self.data = glob("./dataset/test_data/*.png")
self.transform = transform
def __getitem__(self, index):
imageName = self.data[index]
data = cv2.imread(imageName)
if self.transform is not None:
data = self.transform(data)
label = np.zeros(10, dtype=np.float32)
index = int(imageName[-5])
label[index] = 1
return data, torch.from_numpy(label)
def __len__(self):
return len(self.data)