Skip to content

Commit c632b8b

Browse files
committed
add all the python files
1 parent c6c35a6 commit c632b8b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+10321
-2
lines changed

.gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Byte-compiled / optimized / DLL files
22
__pycache__/
3-
3+
*
4+
!*.py
45
*$py.class
56

67
# C extensions
@@ -14,7 +15,6 @@ dist/
1415
downloads/
1516
eggs/
1617
.eggs/
17-
lib/
1818
lib64/
1919
parts/
2020
sdist/

3DGAN/lib/__init__.py

Whitespace-only changes.

3DGAN/lib/config/__init__.py

Whitespace-only changes.

3DGAN/lib/config/config.py

+308
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Tencent
3+
# Licensed under the GPLv3 License.
4+
# Created by Kai Ma ([email protected])
5+
# ------------------------------------------------------------------------------
6+
7+
from __future__ import print_function
8+
from __future__ import absolute_import
9+
from __future__ import division
10+
11+
from easydict import EasyDict
12+
import os
13+
import numpy as np
14+
15+
__C = EasyDict()
16+
cfg = __C
17+
18+
# Model Path
19+
__C.MODEL_SAVE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'save_models'))
20+
__C.CT_MIN_MAX = [0, 5800]
21+
__C.XRAY1_MIN_MAX = [0, 1200]
22+
__C.XRAY2_MIN_MAX = [0, 1700]
23+
__C.CT_MEAN_STD = [0., 1.0]
24+
__C.XRAY1_MEAN_STD = [0., 1.0]
25+
__C.XRAY2_MEAN_STD = [0., 1.0]
26+
27+
'''
28+
Network
29+
Generator
30+
'''
31+
__C.NETWORK = EasyDict()
32+
# of input image channels
33+
__C.NETWORK.input_nc_G = 3
34+
# of output image channels
35+
__C.NETWORK.output_nc_G = 3
36+
# of gen filters in first conv layer
37+
__C.NETWORK.ngf = 64
38+
# selects model to use for netG
39+
__C.NETWORK.which_model_netG = 'resnet_generator'
40+
# instance normalization or batch normalization
41+
__C.NETWORK.norm_G = 'instance'
42+
# no dropout for the generator
43+
__C.NETWORK.no_dropout = False
44+
# network initialization [normal|xavier|kaiming|orthogonal]
45+
__C.NETWORK.init_type = 'normal'
46+
# gan, lsgan, wgan, wgan_gp
47+
__C.NETWORK.ganloss = 'lsgan'
48+
# down sampling
49+
__C.NETWORK.n_downsampling = 3
50+
__C.NETWORK.n_blocks = 9
51+
# activation
52+
__C.NETWORK.activation_type = 'relu'
53+
54+
'''
55+
Network
56+
Discriminator
57+
'''
58+
# of input image channels
59+
__C.NETWORK.input_nc_D = 3
60+
# of output image channels
61+
# __C.NETWORK.output_nc_D = 1
62+
# of discrim filters in first conv layer
63+
__C.NETWORK.ndf = 64
64+
# selects model to use for netD
65+
__C.NETWORK.which_model_netD = 'basic'
66+
# only used if which_model_netD==n_layers, dtype = int
67+
__C.NETWORK.n_layers_D = 3
68+
# instance normalization or batch normalization, dtype = str
69+
__C.NETWORK.norm_D = 'instance3d'
70+
# output channels of discriminator network, dtype = int
71+
__C.NETWORK.n_out_ChannelsD = 1
72+
__C.NETWORK.pool_size = 50
73+
__C.NETWORK.if_pool = False
74+
__C.NETWORK.num_D = 3
75+
# add condition to discriminator network
76+
__C.NETWORK.conditional_D = False
77+
78+
# of input image channels
79+
__C.NETWORK.map_input_nc_D = 1
80+
# of discrim filters in first conv layer
81+
__C.NETWORK.map_ndf = 64
82+
# selects model to use for netD
83+
__C.NETWORK.map_which_model_netD = 'basic'
84+
# only used if which_model_netD==n_layers
85+
__C.NETWORK.map_n_layers_D = 3
86+
# instance normalization or batch normalization, dtype = str
87+
__C.NETWORK.map_norm_D = 'instance'
88+
# output channels of discriminator network, dtype = int
89+
__C.NETWORK.map_n_out_ChannelsD = 1
90+
__C.NETWORK.map_pool_size = 50
91+
__C.NETWORK.map_num_D = 3
92+
93+
'''
94+
Train
95+
'''
96+
__C.TRAIN = EasyDict()
97+
# initial learning rate for adam
98+
__C.TRAIN.lr = 0.0002
99+
# momentum term of adam
100+
__C.TRAIN.beta1 = 0.5
101+
__C.TRAIN.beta2 = 0.9
102+
# if true, takes images in order to make batches, otherwise takes them randomly
103+
__C.TRAIN.serial_batches = False
104+
__C.TRAIN.batch_size = 1
105+
# threads for loading data
106+
__C.TRAIN.nThreads = 5
107+
# __C.TRAIN.max_epoch = 10
108+
# learning rate policy: lambda|step|plateau
109+
__C.TRAIN.lr_policy = 'lambda'
110+
# of iter at starting learning rate
111+
__C.TRAIN.niter = 100
112+
# of iter to linearly decay learning rate to zero
113+
__C.TRAIN.niter_decay = 100
114+
# multiply by a gamma every lr_decay_iters iterations
115+
__C.TRAIN.lr_decay_iters = 50
116+
# frequency of showing training results on console
117+
__C.TRAIN.print_freq = 10
118+
# frequency of showing training results on console
119+
__C.TRAIN.print_img_freq = 200
120+
# save model
121+
__C.TRAIN.save_latest_freq = 3000
122+
# save model frequent
123+
__C.TRAIN.save_epoch_freq = 5
124+
__C.TRAIN.begin_save_epoch = 0
125+
126+
__C.TRAIN.weight_decay_if = False
127+
128+
'''
129+
TEST
130+
'''
131+
__C.TEST = EasyDict()
132+
__C.TEST.howmany_in_train = 10
133+
134+
'''
135+
Data
136+
Augmentation
137+
'''
138+
__C.DATA_AUG = EasyDict()
139+
__C.DATA_AUG.select_slice_num = 0
140+
__C.DATA_AUG.fine_size = 256
141+
__C.DATA_AUG.ct_channel = 256
142+
__C.DATA_AUG.xray_channel = 1
143+
__C.DATA_AUG.resize_size = 289
144+
145+
'''
146+
2D GAN define loss
147+
'''
148+
__C.TD_GAN = EasyDict()
149+
# identity loss
150+
__C.TD_GAN.idt_lambda = 10.
151+
__C.TD_GAN.idt_reduction = 'elementwise_mean'
152+
__C.TD_GAN.idt_weight = 0.
153+
__C.TD_GAN.idt_weight_range = [0., 1.]
154+
__C.TD_GAN.restruction_loss = 'l1'
155+
# perceptual loss
156+
__C.TD_GAN.fea_m_lambda = 10.
157+
# output of discriminator
158+
__C.TD_GAN.discriminator_feature = True
159+
# wgan-gp
160+
__C.TD_GAN.wgan_gp_lambda = 10.
161+
# identity loss of map
162+
__C.TD_GAN.map_m_lambda = 0.
163+
# 'l1' or 'mse'
164+
__C.TD_GAN.map_m_type = 'l1'
165+
__C.TD_GAN.fea_m_map_lambda = 10.
166+
# Discriminator train times
167+
__C.TD_GAN.critic_times = 1
168+
169+
'''
170+
3D GD-GAN define structure
171+
'''
172+
__C.D3_GAN = EasyDict()
173+
__C.D3_GAN.noise_len = 1000
174+
__C.D3_GAN.input_shape = [4,4,4]
175+
# __C.D3_GAN.input_shape_nc = 512
176+
__C.D3_GAN.output_shape = [128,128,128]
177+
# __C.D3_GAN.output_shape_nc = 1
178+
__C.D3_GAN.encoder_input_shape = [128, 128]
179+
__C.D3_GAN.encoder_input_nc = 1
180+
__C.D3_GAN.encoder_norm = 'instance'
181+
__C.D3_GAN.encoder_blocks = 4
182+
__C.D3_GAN.multi_view = [1,2,3]
183+
__C.D3_GAN.min_max_norm = False
184+
__C.D3_GAN.skip_number = 1
185+
# DoubleBlockLinearUnit Activation [low high k]
186+
__C.D3_GAN.dblu = [0., 1.0, 1.0]
187+
188+
'''
189+
CT GAN
190+
'''
191+
__C.CTGAN = EasyDict()
192+
# input x-ray direction, 'H'-FrontBehind 'D'-UpDown 'W'-LeftRight
193+
# 'HDW' Means that deepness is 'H' and projection in plane of 'DW'
194+
# relative to CT.
195+
__C.CTGAN.Xray1_Direction = 'HDW'
196+
__C.CTGAN.Xray2_Direction = 'WDH'
197+
# dimension order of input CT is 'DHW'(should add 'NC'-01 to front when training)
198+
__C.CTGAN.CTOrder = [0, 1, 2, 3, 4]
199+
# NCHDW to xray1 and NCWDH to xray2
200+
__C.CTGAN.CTOrder_Xray1 = [0, 1, 3, 2, 4]
201+
__C.CTGAN.CTOrder_Xray2 = [0, 1, 4, 2, 3]
202+
# identity loss'weight
203+
__C.CTGAN.idt_lambda = 1.0
204+
__C.CTGAN.idt_reduction = 'elementwise_mean'
205+
__C.CTGAN.idt_weight = 0.
206+
__C.CTGAN.idt_weight_range = [0., 1.]
207+
# 'l1' or 'mse'
208+
__C.CTGAN.idt_loss = 'l1'
209+
# feature metrics loss
210+
__C.CTGAN.feature_D_lambda = 0.
211+
# projection loss'weight
212+
__C.CTGAN.map_projection_lambda = 0.
213+
# 'l1' or 'mse'
214+
__C.CTGAN.map_projection_loss = 'l1'
215+
# gan loss'weight
216+
__C.CTGAN.gan_lambda = 1.0
217+
# multiView GAN auxiliary loss
218+
__C.CTGAN.auxiliary_lambda = 0.
219+
# 'l1' or 'mse'
220+
__C.CTGAN.auxiliary_loss = 'mse'
221+
# map discriminator
222+
__C.CTGAN.feature_D_map_lambda = 0.
223+
__C.CTGAN.map_gan_lambda = 1.0
224+
225+
def cfg_from_yaml(filename):
226+
'''
227+
Load a config file and merge it into the default options
228+
:param filename:
229+
:return:
230+
'''
231+
import yaml
232+
with open(filename, 'r') as f:
233+
yaml_cfg = EasyDict(yaml.load(f))
234+
_merge_a_into_b(yaml_cfg, __C)
235+
236+
def print_easy_dict(easy_dict):
237+
print('==='*10)
238+
print('====YAML Parameters')
239+
for k,v in easy_dict.__dict__.items():
240+
print('{}: {}'.format(k, v))
241+
print('==='*10)
242+
243+
def merge_dict_and_yaml(in_dict, easy_dict):
244+
if type(easy_dict) is not EasyDict:
245+
return in_dict
246+
easy_list = _easy_dict_squeeze(easy_dict)
247+
for (k, v) in easy_list:
248+
if k in in_dict:
249+
raise KeyError('The same Key appear {}/{}'.format(k,k))
250+
out_dict = EasyDict(dict(easy_list + list(in_dict.items())))
251+
return out_dict
252+
253+
def _easy_dict_squeeze(easy_dict):
254+
if type(easy_dict) is not EasyDict:
255+
print('Not EasyDict!!!')
256+
return []
257+
258+
total_list = []
259+
for k, v in easy_dict.items():
260+
# recursively merge dicts
261+
if type(v) is EasyDict:
262+
try:
263+
total_list += _easy_dict_squeeze(v)
264+
except:
265+
print('Error under config key: {}'.format(k))
266+
raise
267+
else:
268+
total_list.append((k, v))
269+
return total_list
270+
271+
def _merge_a_into_b(a, b):
272+
'''
273+
Merge easyDict a to easyDict b
274+
:param a: from easyDict
275+
:param b: to easyDict
276+
:return:
277+
'''
278+
if type(a) is not EasyDict:
279+
return
280+
281+
for k, v in a.items():
282+
# check k in a or not
283+
if k not in b:
284+
raise KeyError('{} is not a valid config key'.format(k))
285+
286+
old_type = type(b[k])
287+
if old_type is not type(v):
288+
if isinstance(b[k], np.ndarray):
289+
v = np.array(v, dtype=b[k].type)
290+
else:
291+
raise ValueError('Type mismatch ({} vs. {})'
292+
'for config key: {}'.format(type(b[k]), type(v), k))
293+
# recursively merge dicts
294+
if type(v) is EasyDict:
295+
try:
296+
_merge_a_into_b(a[k], b[k])
297+
except:
298+
print('Error under config key: {}'.format(k))
299+
raise
300+
else:
301+
b[k] = v
302+
303+
304+
305+
306+
307+
308+

3DGAN/lib/dataset/__init__.py

Whitespace-only changes.

3DGAN/lib/dataset/alignDataSet.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Tencent
3+
# Licensed under the GPLv3 License.
4+
# Created by Kai Ma ([email protected])
5+
# ------------------------------------------------------------------------------
6+
7+
from __future__ import print_function
8+
from __future__ import absolute_import
9+
from __future__ import division
10+
11+
from lib.dataset.baseDataSet import Base_DataSet
12+
from lib.dataset.utils import *
13+
import h5py
14+
import numpy as np
15+
16+
17+
class AlignDataSet(Base_DataSet):
18+
'''
19+
DataSet For unaligned data
20+
'''
21+
def __init__(self, opt):
22+
super(AlignDataSet, self).__init__()
23+
self.opt = opt
24+
self.ext = '.h5'
25+
self.dataset_paths = get_dataset_from_txt_file(self.opt.datasetfile)
26+
self.dataset_paths = sorted(self.dataset_paths)
27+
self.dataset_size = len(self.dataset_paths)
28+
self.dir_root = self.get_data_path
29+
self.data_augmentation = self.opt.data_augmentation(opt)
30+
31+
@property
32+
def name(self):
33+
return 'AlignDataSet'
34+
35+
@property
36+
def get_data_path(self):
37+
path = os.path.join(self.opt.dataroot)
38+
return path
39+
40+
@property
41+
def num_samples(self):
42+
return self.dataset_size
43+
44+
def get_image_path(self, root, index_name):
45+
img_path = os.path.join(root, index_name, 'ct_xray_data'+self.ext)
46+
assert os.path.exists(img_path), 'Path do not exist: {}'.format(img_path)
47+
return img_path
48+
49+
def load_file(self, file_path):
50+
hdf5 = h5py.File(file_path, 'r')
51+
ct_data = np.asarray(hdf5['ct'])
52+
x_ray1 = np.asarray(hdf5['xray1'])
53+
x_ray1 = np.expand_dims(x_ray1, 0)
54+
hdf5.close()
55+
return ct_data, x_ray1
56+
57+
'''
58+
generate batch
59+
'''
60+
def pull_item(self, item):
61+
file_path = self.get_image_path(self.dir_root, self.dataset_paths[item])
62+
ct_data, x_ray1 = self.load_file(file_path)
63+
64+
# Data Augmentation
65+
ct, xray1 = self.data_augmentation([ct_data, x_ray1])
66+
67+
return ct, xray1, file_path
68+
69+
70+
71+
72+
73+

0 commit comments

Comments
 (0)