Skip to content

Commit 46b6567

Browse files
authored
Merge pull request #23 from HiLab-git/dev
Dev
2 parents 5cbbca7 + 733d93c commit 46b6567

24 files changed

+330
-101
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ Run the following command to install the current released version of PyMIC:
3535
pip install PYMIC
3636
```
3737

38-
Alternatively, you can download the source code and add the path of pymic to the `PYTHONPATH` environment variable.
38+
Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
39+
40+
```bash
41+
python setup.py install
42+
```
3943

4044
## Examples
4145
[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples

pymic/io/image_read_write.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None):
7979
img = sitk.GetImageFromArray(data)
8080
if(reference_name is not None):
8181
img_ref = sitk.ReadImage(reference_name)
82-
img.CopyInformation(img_ref)
82+
#img.CopyInformation(img_ref)
83+
img.SetSpacing(img_ref.GetSpacing())
84+
img.SetOrigin(img_ref.GetOrigin())
85+
img.SetDirection(img_ref.GetDirection())
8386
sitk.WriteImage(img, image_name)
8487

8588
def save_array_as_rgb_image(data, image_name):

pymic/loss/cls/ce.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def __init__(self, params):
3737
def forward(self, loss_input_dict):
3838
predict = loss_input_dict['prediction']
3939
labels = loss_input_dict['ground_truth']
40-
predict = nn.Sigmoid()(predict)
40+
# for numeric stability
41+
predict = nn.Sigmoid()(predict) * 0.999 + 5e-4
4142
loss = - labels * torch.log(predict) - (1 - labels) * torch.log( 1 - predict)
4243
loss = loss.mean()
4344
return loss

pymic/loss/loss_dict_seg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import torch.nn as nn
44
from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
55
from pymic.loss.seg.dice import DiceLoss, MultiScaleDiceLoss
6-
from pymic.loss.seg.dice import DiceWithCrossEntropyLoss, NoiseRobustDiceLoss
6+
from pymic.loss.seg.dice import FocalDiceLoss, NoiseRobustDiceLoss
77
from pymic.loss.seg.exp_log import ExpLogLoss
88
from pymic.loss.seg.mse import MSELoss, MAELoss
99

1010
SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
1111
'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
1212
'DiceLoss': DiceLoss,
1313
'MultiScaleDiceLoss': MultiScaleDiceLoss,
14-
'DiceWithCrossEntropyLoss': DiceWithCrossEntropyLoss,
14+
'FocalDiceLoss': FocalDiceLoss,
1515
'NoiseRobustDiceLoss': NoiseRobustDiceLoss,
1616
'ExpLogLoss': ExpLogLoss,
1717
'MSELoss': MSELoss,

pymic/loss/seg/ce.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
class CrossEntropyLoss(nn.Module):
99
def __init__(self, params):
1010
super(CrossEntropyLoss, self).__init__()
11-
self.enable_pix_weight = params['CrossEntropyLoss_Enable_Pixel_Weight'.lower()]
12-
self.enable_cls_weight = params['CrossEntropyLoss_Enable_Class_Weight'.lower()]
11+
self.enable_pix_weight = params.get('CrossEntropyLoss_Enable_Pixel_Weight'.lower(), False)
12+
self.enable_cls_weight = params.get('CrossEntropyLoss_Enable_Class_Weight'.lower(), False)
1313

1414
def forward(self, loss_input_dict):
1515
predict = loss_input_dict['prediction']
@@ -25,6 +25,8 @@ def forward(self, loss_input_dict):
2525
predict = reshape_tensor_to_2D(predict)
2626
soft_y = reshape_tensor_to_2D(soft_y)
2727

28+
# for numeric stability
29+
predict = predict * 0.999 + 5e-4
2830
ce = - soft_y* torch.log(predict)
2931
if(self.enable_cls_weight):
3032
if(cls_w is None):

pymic/loss/seg/combined.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ def __init__(self, params, loss_dict):
2222
def forward(self, loss_input_dict):
2323
loss_value = 0.0
2424
for i in range(len(self.loss_list)):
25-
loss_value = self.loss_weight[i] + self.loss_list[i](loss_input_dict)
26-
return loss_value
25+
loss_value += self.loss_weight[i]*self.loss_list[i](loss_input_dict)
26+
return loss_value

pymic/loss/seg/dice.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,46 @@ def forward(self, loss_input_dict):
4040
dice_loss = 1.0 - average_dice
4141
return dice_loss
4242

43-
class DiceWithCrossEntropyLoss(nn.Module):
44-
def __init__(self, params):
45-
super(DiceWithCrossEntropyLoss, self).__init__()
46-
self.enable_pix_weight = params['DiceWithCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
47-
self.enable_cls_weight = params['DiceWithCrossEntropyLoss_Enable_Class_Weight'.lower()]
48-
self.ce_weight = params['DiceWithCrossEntropyLoss_CE_Weight'.lower()]
49-
dice_params = {'DiceLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
50-
'DiceLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
51-
ce_params = {'CrossEntropyLoss_Enable_Pixel_Weight'.lower(): self.enable_pix_weight,
52-
'CrossEntropyLoss_Enable_Class_Weight'.lower(): self.enable_cls_weight}
53-
self.dice_loss = DiceLoss(dice_params)
54-
self.ce_loss = CrossEntropyLoss(ce_params)
43+
class FocalDiceLoss(nn.Module):
44+
"""
45+
focal Dice according to the following paper:
46+
Pei Wang and Albert C. S. Chung, Focal Dice Loss and Image Dilation for
47+
Brain Tumor Segmentation, 2018
48+
"""
49+
def __init__(self, params = None):
50+
super(FocalDiceLoss, self).__init__()
51+
self.beta = params['FocalDiceLoss_beta'.lower()] #beta should be >=1.0
5552

5653
def forward(self, loss_input_dict):
57-
loss1 = self.dice_loss(loss_input_dict)
58-
loss2 = self.ce_loss(loss_input_dict)
59-
loss = loss1 + self.ce_weight * loss2
60-
return loss
54+
predict = loss_input_dict['prediction']
55+
soft_y = loss_input_dict['ground_truth']
56+
img_w = loss_input_dict['image_weight']
57+
pix_w = loss_input_dict['pixel_weight']
58+
cls_w = loss_input_dict['class_weight']
59+
softmax = loss_input_dict['softmax']
60+
61+
if(isinstance(predict, (list, tuple))):
62+
predict = predict[0]
63+
tensor_dim = len(predict.size())
64+
if(softmax):
65+
predict = nn.Softmax(dim = 1)(predict)
66+
predict = reshape_tensor_to_2D(predict)
67+
soft_y = reshape_tensor_to_2D(soft_y)
68+
69+
# combien pixel weight and image weight
70+
if(tensor_dim == 5):
71+
img_w = img_w[:, None, None, None, None]
72+
else:
73+
img_w = img_w[:, None, None, None]
74+
pix_w = pix_w * img_w
75+
pix_w = reshape_tensor_to_2D(pix_w)
76+
dice_score = get_classwise_dice(predict, soft_y, pix_w)
77+
78+
dice_score = torch.pow(dice_score, 1.0 / self.beta)
79+
weighted_dice = dice_score * cls_w
80+
average_dice = weighted_dice.sum() / cls_w.sum()
81+
dice_loss = 1.0 - average_dice
82+
return dice_loss
6183

6284
class MultiScaleDiceLoss(nn.Module):
6385
def __init__(self, params):

pymic/net/cls/torch_pretrained_net.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,30 @@ def get_parameters_to_update(self):
7373
else:
7474
raise(ValueError("update_layers can only be 0 (all layers) " +
7575
"or -1 (the last layer)"))
76+
77+
class MobileNetV2(nn.Module):
78+
def __init__(self, params):
79+
super(MobileNetV2, self).__init__()
80+
self.params = params
81+
net_name = params['net_type']
82+
cls_num = params['class_num']
83+
in_chns = params['input_chns']
84+
self.pretrain = params['pretrain']
85+
self.update_layers = params.get('update_layers', 0)
86+
self.net = models.mobilenet_v2(pretrained = self.pretrain)
87+
88+
# replace the last layer
89+
num_ftrs = self.net.last_channel
90+
self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
91+
92+
def forward(self, x):
93+
return self.net(x)
94+
95+
def get_parameters_to_update(self):
96+
if(self.pretrain == False or self.update_layers == 0):
97+
return self.net.parameters()
98+
elif(self.update_layers == -1):
99+
return self.net.classifier[-1].parameters()
100+
else:
101+
raise(ValueError("update_layers can only be 0 (all layers) " +
102+
"or -1 (the last layer)"))

pymic/net/net_dict_cls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44

55
TorchClsNetDict = {
66
'resnet18': ResNet18,
7-
'vgg16': VGG16
7+
'vgg16': VGG16,
8+
'mobilenetv2':MobileNetV2
89
}

pymic/net_run/agent_cls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def infer(self):
278278
checkpoint = torch.load(checkpoint_name, map_location = device)
279279
self.net.load_state_dict(checkpoint['model_state_dict'])
280280

281-
if(self.config['testing']['evaluation_mode'] == True):
281+
if(self.config['testing'].get('evaluation_mode', True)):
282282
self.net.eval()
283283

284284
output_csv = self.config['testing']['output_csv']

0 commit comments

Comments
 (0)