-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
89 lines (78 loc) · 2.97 KB
/
train.py
File metadata and controls
89 lines (78 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch.utils.data import DataLoader
from torch import optim
from torchvision import transforms
import torch.nn as nn
import os
from network import U_Net
from utils import My_Dataset
from tensorboardX import SummaryWriter
from utils import MulticlassDiceLoss
from utils import loss_weight
# Pin training to the second GPU; must be set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Segmentation setup: 2 output classes (foreground/background), RGB input.
num_classes = 2
num_channels = 3
batch_size = 4
size=(256,256)
num_epochs=100
# Root directory of the training images/masks (membrane dataset).
root = "data/membrane/train"
# Whether to use CUDA (falls back to CPU when unavailable).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Helper that derives per-class / distance loss weights from masks (see utils).
weight=loss_weight()
# Image pipeline: to tensor, then normalize each RGB channel to [-1, 1].
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# The mask only needs to be converted to a tensor.
mask_transforms = transforms.ToTensor()
def train_model(model, criterion, optimizer, dataload, model_graph=None, num_epochs=10):
    """Run the training loop and save the final weights.

    Args:
        model: network to train, already moved to the module-level `device`.
        criterion: loss taking (outputs, labels) where labels are long ints.
        optimizer: optimizer over ``model.parameters()``.
        dataload: DataLoader yielding (image_batch, mask_batch) pairs.
        model_graph: optional tensorboardX ``SummaryWriter``; when provided,
            the summed per-epoch loss is logged under the "train" tag.
        num_epochs: number of full passes over the dataset.

    Side effects: writes the trained state dict to
    ``'UNet_weights_bilinear_weight.pth'`` and prints per-step/per-epoch loss.
    """
    # Loop-invariant: dataset length does not change across epochs.
    dt_size = len(dataload.dataset)
    num_batches = (dt_size - 1) // dataload.batch_size + 1
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('--' * 10)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            # Per-batch loss weights derived from the ground-truth mask.
            # NOTE(review): distance_weight's result is discarded — presumably
            # it mutates internal state of `weight`; confirm in utils.
            outputs_weight = weight.class_weight(y.numpy())
            weight.distance_weight(y.numpy())
            outputs_weight = torch.from_numpy(outputs_weight).to(device)
            inputs = x.to(device)
            labels = y.to(device).long()
            # Zero the parameter gradients before the forward pass.
            optimizer.zero_grad()
            # Forward pass; scale the raw outputs by the per-class weights.
            outputs = model(inputs)
            outputs = outputs.double().mul(outputs_weight.double())
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        # BUG FIX: model_graph defaults to None, but the original called
        # add_scalar unconditionally -> AttributeError; guard the logging.
        if model_graph is not None:
            model_graph.add_scalar("train", epoch_loss, epoch)
    torch.save(model.state_dict(), 'UNet_weights_bilinear_weight.pth')
def train():
    """Build the U-Net, loss, optimizer and data pipeline, then train.

    Uses the module-level configuration (num_channels, num_classes, root,
    size, batch_size, num_epochs, transforms). Logs per-epoch loss to
    tensorboard and saves the final weights via ``train_model``.
    """
    model = U_Net(n_channels=num_channels, n_classes=num_classes).to(device)
    model_graph = SummaryWriter(comment="UNet")
    model.train()
    # BUG FIX: nn.NLLLoss2d is deprecated and removed in modern PyTorch.
    # nn.NLLLoss accepts the same (N, C, H, W) log-probability inputs with
    # (N, H, W) integer targets, so behavior is unchanged.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    dataset = My_Dataset(root, num_classes, size, transform=img_transforms, mask_transform=mask_transforms)
    data_loaders = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, data_loaders, model_graph=model_graph, num_epochs=num_epochs)
    model_graph.close()
# Script entry point: kick off training when run directly.
if __name__ == "__main__":
    train()
    # torch.cuda.empty_cache()