-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLetNet5MINIST.py
135 lines (122 loc) · 5.98 KB
/
LetNet5MINIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#引入相关库
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 定义超参数
EPOCH = 20 #遍历数据集次数
pre_epoch = 0 # 定义已经遍历数据集的次数
BATCH_SIZE = 64 #批处理尺寸(batch_size)
LR = 0.001 #学习率
class LeNet(nn.Module):#定义网络
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 5, 1, 2), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.BatchNorm1d(120), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.ReLU(),
nn.Linear(84, 10))
# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)#展平
x = self.fc1(x)
x = self.fc2(x)
return x
# 启用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 启用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#加载数据集
train_loader = torch.utils.data.DataLoader( # 加载训练数据
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 数据集给出的均值和标准差系数,每个数据集都不同的,都数据集提供方给出的
])),
batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader( # 加载训练数据
datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 数据集给出的均值和标准差系数,每个数据集都不同的,都数据集提供方给出的
])),
batch_size=BATCH_SIZE, shuffle=True)
model = LeNet() # 实例化一个网络对象
model = model.to(device)
criterion = nn.CrossEntropyLoss() #损失函数为交叉熵,多用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=LR) #优化方式为mini-batch momentum-SGD,并采用L2正则化(权重衰
# 训练
if __name__ == "__main__":
best_acc = 85 #2 初始化best test accuracy
print("Start Training, LetNet5-Minist!") # 定义遍历数据集的次数
with open("LetNet5-Ministacc.txt", "w") as f:
with open("LetNet5-Ministlog.txt", "w")as f2:
for epoch in range(pre_epoch, EPOCH):
print('\nEpoch: %d' % (epoch + 1))
model.train()
sum_loss = 0.0
correct = 0.0
total = 0.0
for i, data in enumerate(train_loader, 0):
# 准备数据
length = len(train_loader)
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# forward + backward
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 每训练1个batch打印一次loss和准确率
sum_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += predicted.eq(labels.data).cpu().sum()
print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% '
% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))
f2.write('%03d %05d |Loss: %.03f | Acc: %.3f%% '
% (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1), 100. * correct / total))
f2.write('\n')
f2.flush()
# 每训练完一个epoch测试一下准确率
print("Waiting Test!")
model.eval()
with torch.no_grad():
correct = 0
total = 0
for data in test_loader:
model.eval()
images, labels = data
images, labels = Variable(images), Variable(labels)
images, labels = images.to(device), labels.to(device)
outputs = model(images)
# 取得分最高的那个类 (outputs.data的索引号)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('测试分类准确率为:%.3f%%' % (100 * correct / total))
acc = 100. * correct / total
# 将每次测试结果实时写入acc.txt文件中
f.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, acc))
f.write('\n')
f.flush()
# 记录最佳测试分类准确率并写入best_acc.txt文件中
if acc > best_acc:
f3 = open("LetNet5-Ministbest_acc.txt", "w")
f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
f3.close()
best_acc = acc
print('Saving model......')
torch.save(model, 'LetNet5-Minist_%03d.pth' % (epoch + 1))
print("Training Finished, TotalEPOCH=%d" % EPOCH)