-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
364 lines (306 loc) · 16.2 KB
/
train.py
File metadata and controls
364 lines (306 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import argparse
import os
import time
from utee import misc
import torch
import torch.optim as optim
from torch.autograd import Variable
from utee import make_path
from utee import wage_util
from datetime import datetime
from utee import wage_quantizer
from utee import hook
import numpy as np
import csv
from subprocess import call
from modules.quantization_cpu_np_infer import QConv2d, QLinear
import tempfile
# 额外导入梯度裁剪
import torch.nn.utils as utils
# 从models导入数据集和模型
from models import dataset
from models import VGG, DenseNet, ResNet, ResNet18cifar10
def main():
    """Train a WAGE-quantized CNN and log NeuroSim hardware-evaluation stats.

    Pipeline:
      1. Parse CLI flags (several are hard-overridden in code below).
      2. Build data loaders for cifar10 / cifar100 / imagenet / miniimagenet.
      3. Build the requested model and its loss criterion.
      4. Run a momentum-SGD training loop in which gradients are quantized
         with ``wage_quantizer.QG`` (nonideal-device aware) and weights are
         re-quantized with ``wage_quantizer.W`` after every step.
      5. Dump per-epoch weight/gradient distributions and per-layer weight
         snapshots consumed by the NeuroSim hardware simulator.
      6. Evaluate every ``--test_interval`` epochs and checkpoint the best model.

    Side effects: writes several CSV files in the CWD, a ``layer_record/``
    directory, and model checkpoints under ``--logdir``.
    """
    parser = argparse.ArgumentParser(description='PyTorch CIFAR-X Example')
    # Dataset / model / mode selection and training hyper-parameters.
    parser.add_argument('--dataset', default='cifar10', help='cifar10|cifar100|imagenet|miniimagenet')
    parser.add_argument('--model', default='VGG8', help='VGG8|DenseNet40|ResNet18|resnet18_cifar10|VGG16')
    parser.add_argument('--mode', default='WAGE', help='WAGE|FP')
    parser.add_argument('--parallelRead', type=int, default=1, help='Set parallelRead option')
    parser.add_argument('--batch_size', type=int, default=200, help='input batch size for training (default: 200)')
    parser.add_argument('--epochs', type=int, default=125, help='number of epochs to train (default: 125)')
    parser.add_argument('--grad_scale', type=float, default=1, help='learning rate for wage delta calculation')
    parser.add_argument('--seed', type=int, default=117, help='random seed (default: 117)')
    parser.add_argument('--log_interval', type=int, default=100, help='how many batches to wait before logging training status default = 100')
    parser.add_argument('--test_interval', type=int, default=1, help='how many epochs to wait before another test (default = 1)')
    parser.add_argument('--logdir', default='log/default', help='folder to save to the log')
    parser.add_argument('--decreasing_lr', default='50,100', help='decreasing strategy')
    # WAGE bit widths (weight / gradient / activation / error).
    parser.add_argument('--wl_weight', type=int, default=2)
    parser.add_argument('--wl_grad', type=int, default=8)
    parser.add_argument('--wl_activate', type=int, default=8)
    parser.add_argument('--wl_error', type=int, default=8)
    # Hardware (NeuroSim) device / array parameters.
    parser.add_argument('--inference', default=0)
    parser.add_argument('--onoffratio', default=100)
    parser.add_argument('--cellBit', default=1)
    parser.add_argument('--subArray', default=128)
    parser.add_argument('--ADCprecision', default=5)
    parser.add_argument('--vari', default=0)
    parser.add_argument('--t', default=0)
    parser.add_argument('--v', default=0)
    parser.add_argument('--detect', default=0)
    parser.add_argument('--target', default=0)
    parser.add_argument('--nonlinearityLTP', default=0.01)
    parser.add_argument('--nonlinearityLTD', default=-0.01)
    parser.add_argument('--max_level', default=64)
    parser.add_argument('--d2dVari', default=0)
    parser.add_argument('--c2cVari', default=0)
    current_time = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    args = parser.parse_args()
    # Hard-coded experiment overrides — these silently shadow whatever the
    # user passed on the command line for the flags below.
    args.wl_weight = 5
    args.wl_grad = 5
    args.cellBit = 5
    args.max_level = 32
    args.c2cVari = 0.0
    args.d2dVari = 0.0
    args.nonlinearityLTP = 0.88
    args.nonlinearityLTD = 1.04

    def clear_gpu_memory():
        # Release cached CUDA memory between phases; a no-op on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def save_checkpoint(state, filename='checkpoint.pth'):
        # NOTE(review): currently unused helper, kept for manual checkpointing.
        torch.save(state, filename)

    args.cuda = False  # training is forced onto CPU
    gamma = 0.9   # momentum coefficient for the hand-rolled velocity update
    alpha = 0.1   # effective learning-rate factor applied to raw gradients
    # CSV header for the NeuroSim latency/energy summary.
    NeuroSim_Out = np.array([["L_forward (s)", "L_activation gradient (s)", "L_weight gradient (s)", "L_weight update (s)",
                              "E_forward (J)", "E_activation gradient (J)", "E_weight gradient (J)", "E_weight update (J)",
                              "L_forward_Peak (s)", "L_activation gradient_Peak (s)", "L_weight gradient_Peak (s)", "L_weight update_Peak (s)",
                              "E_forward_Peak (J)", "E_activation gradient_Peak (J)", "E_weight gradient_Peak (J)", "E_weight update_Peak (J)",
                              "TOPS/W", "TOPS", "Peak TOPS/W", "Peak TOPS"]])
    np.savetxt("NeuroSim_Output.csv", NeuroSim_Out, delimiter=",", fmt='%s')
    if not os.path.exists('./NeuroSim_Results_Each_Epoch'):
        os.makedirs('./NeuroSim_Results_Each_Epoch')
    # Append-mode ('ab') headers: re-running the script appends a new header
    # row to any existing CSV rather than truncating it.
    with open("PythonWrapper_Output.csv", 'ab') as out:
        out_firstline = np.array([["epoch", "average loss", "accuracy"]])
        np.savetxt(out, out_firstline, delimiter=",", fmt='%s')
    with open("delta_dist.csv", 'ab') as delta_distribution:
        delta_firstline = np.array([["1_mean", "2_mean", "3_mean", "4_mean", "5_mean", "6_mean", "7_mean", "8_mean",
                                     "1_std", "2_std", "3_std", "4_std", "5_std", "6_std", "7_std", "8_std"]])
        np.savetxt(delta_distribution, delta_firstline, delimiter=",", fmt='%s')
    with open("weight_dist.csv", 'ab') as weight_distribution:
        weight_firstline = np.array([["1_mean", "2_mean", "3_mean", "4_mean", "5_mean", "6_mean", "7_mean", "8_mean",
                                      "1_std", "2_std", "3_std", "4_std", "5_std", "6_std", "7_std", "8_std"]])
        np.savetxt(weight_distribution, weight_firstline, delimiter=",", fmt='%s')
    args.logdir = os.path.join(os.path.dirname(__file__), args.logdir)
    args = make_path.makepath(args, ['log_interval', 'test_interval', 'logdir', 'epochs'])
    misc.logger.init(args.logdir, 'train_log_' + current_time)
    logger = misc.logger.info
    misc.ensure_dir(args.logdir)
    logger("=================FLAGS==================")
    for k, v in args.__dict__.items():
        logger('{}: {}'.format(k, v))
    logger("========================================")
    torch.manual_seed(args.seed)
    # --- Data loaders -----------------------------------------------------
    assert args.dataset in ['cifar10', 'cifar100', 'imagenet', 'miniimagenet'], f"Unsupported dataset: {args.dataset}"
    if args.dataset == 'cifar10':
        train_loader, test_loader = dataset.get_cifar10(batch_size=args.batch_size, num_workers=1, pin_memory=True)
    elif args.dataset == 'cifar100':
        train_loader, test_loader = dataset.get_cifar100(batch_size=args.batch_size, num_workers=1)
    elif args.dataset == 'imagenet':
        train_loader, test_loader = dataset.get_imagenet(batch_size=args.batch_size, num_workers=1)
    else:  # miniimagenet
        train_loader, test_loader = dataset.get_miniimagenet(batch_size=args.batch_size, num_workers=1)
    # --- Model and loss ---------------------------------------------------
    assert args.model in ['VGG8', 'DenseNet40', 'ResNet18', 'resnet18_cifar10', 'VGG16'], f"Unsupported model: {args.model}"
    if args.model == 'VGG8':
        model = VGG.vgg8(args=args, logger=logger)
        criterion = wage_util.SSE()
    elif args.model == 'VGG16':
        model = VGG.vgg16(args=args, logger=logger)
        criterion = wage_util.SSE()
    elif args.model == 'DenseNet40':
        model = DenseNet.densenet40(args=args, logger=logger)
        # BUGFIX: this branch previously left `criterion` undefined, raising a
        # NameError at the first loss computation. Use SSE like the other
        # WAGE-quantized models.
        criterion = wage_util.SSE()
    elif args.model == 'resnet18_cifar10':
        model = ResNet18cifar10.resnet18_cifar10(args=args, logger=logger)
        criterion = wage_util.SSE()
    else:  # 'ResNet18' (a dead `elif args.model == 'resnet18'` branch was
        # removed — lowercase 'resnet18' could never pass the assert above)
        model = ResNet.resnet18(args=args, logger=logger)
        criterion = torch.nn.CrossEntropyLoss()
    # lr=1: the real step size comes from grad_scale / QG quantization below.
    optimizer = optim.SGD(model.parameters(), lr=1)
    decreasing_lr = list(map(int, args.decreasing_lr.split(',')))
    logger('decreasing_lr: ' + str(decreasing_lr))
    best_acc, old_file = 0, None
    t_begin = time.time()
    grad_scale = args.grad_scale
    try:
        if args.cellBit != args.wl_weight:
            print("Warning: Weight precision should be the same as the cell precision!")
        # Per-parameter nonlinearity coefficients (A) for potentiation (LTP)
        # and depression (LTD), including device-to-device variation.
        # Parameters are indexed in REVERSED order, matching the update loop.
        paramALTP = {}
        paramALTD = {}
        k = 0
        for layer in list(model.parameters())[::-1]:
            d2dVariation = torch.normal(torch.zeros_like(layer), args.d2dVari * torch.ones_like(layer))
            NL_LTP = torch.ones_like(layer) * args.nonlinearityLTP + d2dVariation
            NL_LTD = torch.ones_like(layer) * args.nonlinearityLTD + d2dVariation
            paramALTP[k] = wage_quantizer.GetParamA(NL_LTP) * args.max_level
            paramALTD[k] = wage_quantizer.GetParamA(NL_LTD) * args.max_level
            k += 1
        for epoch in range(args.epochs):
            clear_gpu_memory()  # free cached memory at the start of each epoch
            model.train()
            # Fresh momentum buffers each epoch (reversed-parameter indexing).
            velocity = {}
            for i, layer in enumerate(list(model.parameters())[::-1]):
                velocity[i] = torch.zeros_like(layer)
            if epoch in decreasing_lr:
                grad_scale = grad_scale / 8.0
            logger("Training phase")
            for batch_idx, (data, target) in enumerate(train_loader):
                indx_target = target.clone()
                data, target = Variable(data), Variable(target)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                # Clip gradients to avoid explosions before quantization.
                utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                # Momentum update + WAGE gradient quantization (device model).
                # NOTE(review): assumes every parameter received a gradient;
                # a frozen parameter (grad is None) would crash here.
                for j, (name, param) in enumerate(list(model.named_parameters())[::-1]):
                    velocity[j] = gamma * velocity[j] + alpha * param.grad.data
                    param.grad.data = velocity[j]
                    param.grad.data = wage_quantizer.QG(
                        param.data,
                        args.wl_weight,
                        param.grad.data,
                        args.wl_grad,
                        grad_scale,
                        paramALTP[j],
                        paramALTD[j],
                        args.max_level,
                        args.max_level
                    )
                optimizer.step()
                # Re-quantize weights (with cycle-to-cycle variation) after the step.
                for name, param in list(model.named_parameters())[::-1]:
                    param.data = wage_quantizer.W(
                        param.data,
                        param.grad.data,
                        args.wl_weight,
                        args.c2cVari
                    )
                if batch_idx % 50 == 0:
                    clear_gpu_memory()  # periodic cache release
                if batch_idx % args.log_interval == 0 and batch_idx > 0:
                    pred = output.data.max(1)[1]
                    correct = pred.eq(indx_target).sum().item()
                    acc = float(correct) / len(data)
                    logger('Train Epoch: {} [{}/{}] Loss: {:.6f} Acc: {:.4f} lr: {:.2e}'.format(
                        epoch, batch_idx * len(data), len(train_loader.dataset),
                        loss.item(), acc, optimizer.param_groups[0]['lr']
                    ))
            # --- End-of-epoch bookkeeping --------------------------------
            elapse_time = time.time() - t_begin
            speed_epoch = elapse_time / (epoch + 1)
            speed_batch = speed_epoch / len(train_loader)
            eta = speed_epoch * args.epochs - elapse_time
            logger("Elapsed {:.2f}s, {:.2f} s/epoch, {:.2f} s/batch, eta {:.2f}s".format(
                elapse_time, speed_epoch, speed_batch, eta
            ))
            misc.model_save(model, os.path.join(args.logdir, 'latest.pth'))
            if not os.path.exists('./layer_record'):
                os.makedirs('./layer_record')
            if os.path.exists('./layer_record/trace_command.sh'):
                os.remove('./layer_record/trace_command.sh')
            # Collect per-parameter gradient/weight statistics (forward order)
            # and reconstruct pre-update weights for NeuroSim's trace.
            delta_std = np.array([])
            delta_mean = np.array([])
            w_std = np.array([])
            w_mean = np.array([])
            oldWeight = {}
            k = 0
            for name, param in list(model.named_parameters()):
                oldWeight[k] = param.data + param.grad.data
                delta_std = np.append(delta_std, torch.std(param.grad.data).cpu().numpy())
                delta_mean = np.append(delta_mean, torch.mean(param.grad.data).cpu().numpy())
                w_std = np.append(w_std, torch.std(param.data).cpu().numpy())
                w_mean = np.append(w_mean, torch.mean(param.data).cpu().numpy())
                k += 1
            # Row layout: all means first, then all stds (matches CSV headers).
            delta_mean = np.append(delta_mean, delta_std, axis=0)
            w_mean = np.append(w_mean, w_std, axis=0)
            with open("delta_dist.csv", 'ab') as delta_distribution:
                np.savetxt(delta_distribution, [delta_mean], delimiter=",", fmt='%f')
            with open("weight_dist.csv", 'ab') as weight_distribution:
                np.savetxt(weight_distribution, [w_mean], delimiter=",", fmt='%f')
            print("Weight distribution")
            print(w_mean)
            print("Delta distribution")
            print(delta_mean)
            # Dump pre-update weights of quantized layers for NeuroSim.
            # NOTE(review): `h` counts quantized LAYERS while `oldWeight` is
            # keyed by PARAMETER index (weights and biases interleaved) — the
            # two indices only line up for bias-free models; verify upstream.
            h = 0
            if hasattr(model, 'features'):
                for layer in model.features.modules():
                    if isinstance(layer, (QConv2d, QLinear)):
                        weight_file_name = f'./layer_record/weightOld{h}.csv'
                        hook.write_matrix_weight(oldWeight[h].cpu().numpy(), weight_file_name)
                        h += 1
            if hasattr(model, 'classifier'):
                for layer in model.classifier.modules():
                    if isinstance(layer, QLinear):
                        weight_file_name = f'./layer_record/weightOld{h}.csv'
                        hook.write_matrix_weight(oldWeight[h].cpu().numpy(), weight_file_name)
                        h += 1
            # --- Testing phase -------------------------------------------
            if epoch % args.test_interval == 0:
                clear_gpu_memory()
                model.eval()
                test_loss = 0
                correct = 0
                logger("Testing phase")
                with open("PythonWrapper_Output.csv", 'ab') as out:
                    for i, (data, target) in enumerate(test_loader):
                        if i == 0:
                            # Hardware-evaluation hooks only on the first batch.
                            hook_handle_list = hook.hardware_evaluation(model, args.wl_weight, args.wl_activate, args.subArray, args.parallelRead, args.model, args.mode)
                        indx_target = target.clone()
                        with torch.no_grad():
                            data, target = Variable(data), Variable(target)
                            output = model(data)
                            test_loss_i = criterion(output, target)
                            test_loss += test_loss_i.item()
                            pred = output.data.max(1)[1]
                            correct += pred.eq(indx_target).sum().item()
                        if i == 0:
                            hook.remove_hook_list(hook_handle_list)
                    test_loss = test_loss / len(test_loader)
                    acc = 100. * correct / len(test_loader.dataset)
                    logger('\tEpoch {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                        epoch, test_loss, correct, len(test_loader.dataset), acc
                    ))
                    accuracy = acc
                    np.savetxt(out, [[epoch, test_loss, accuracy]], delimiter=",", fmt='%f')
                    if acc > best_acc:
                        new_file = os.path.join(args.logdir, f'best-{epoch}.pth')
                        misc.model_save(model, new_file, old_file=old_file, verbose=True)
                        best_acc = acc
                        old_file = new_file
            clear_gpu_memory()
            # Run the NeuroSim trace produced by the hardware-evaluation hooks.
            call(["/bin/bash", "./layer_record/trace_command.sh"])
    except Exception as e:
        # Best-effort: log the failure but still report timing/best accuracy.
        import traceback
        traceback.print_exc()
    finally:
        total_elapse = time.time() - t_begin
        logger("Total Elapse: {:.2f}s, Best Result: {:.3f}%".format(total_elapse, best_acc))
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()