# ------------------------------------------------------------------------------
# Copyright (c) Tencent
# Licensed under the GPLv3 License.
# Created by Kai Ma ([email protected])
# ------------------------------------------------------------------------------

from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

from easydict import EasyDict
import os
import numpy as np

__C = EasyDict()
cfg = __C

# Model Path
__C.MODEL_SAVE_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'save_models'))
__C.CT_MIN_MAX = [0, 5800]
__C.XRAY1_MIN_MAX = [0, 1200]
__C.XRAY2_MIN_MAX = [0, 1700]
__C.CT_MEAN_STD = [0., 1.0]
__C.XRAY1_MEAN_STD = [0., 1.0]
__C.XRAY2_MEAN_STD = [0., 1.0]
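
# A minimal sketch of how these ranges are typically consumed (an assumption,
# not defined in this file): raw intensities are clipped to *_MIN_MAX, rescaled
# to [0, 1], and then standardized with *_MEAN_STD, e.g.
#
#   ct = np.clip(ct, *cfg.CT_MIN_MAX).astype(np.float32)
#   ct = (ct - cfg.CT_MIN_MAX[0]) / (cfg.CT_MIN_MAX[1] - cfg.CT_MIN_MAX[0])
#   ct = (ct - cfg.CT_MEAN_STD[0]) / cfg.CT_MEAN_STD[1]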

'''
Network
 Generator
'''
__C.NETWORK = EasyDict()
# number of input image channels
__C.NETWORK.input_nc_G = 3
# number of output image channels
__C.NETWORK.output_nc_G = 3
# number of generator filters in the first conv layer
__C.NETWORK.ngf = 64
# selects model to use for netG
__C.NETWORK.which_model_netG = 'resnet_generator'
# instance normalization or batch normalization
__C.NETWORK.norm_G = 'instance'
# if True, no dropout in the generator
__C.NETWORK.no_dropout = False
# network initialization [normal|xavier|kaiming|orthogonal]
__C.NETWORK.init_type = 'normal'
# GAN loss type: gan | lsgan | wgan | wgan_gp
__C.NETWORK.ganloss = 'lsgan'
# number of downsampling layers in the generator
__C.NETWORK.n_downsampling = 3
# number of residual blocks in the generator
__C.NETWORK.n_blocks = 9
# activation function
__C.NETWORK.activation_type = 'relu'

'''
Network
 Discriminator
'''
# number of input image channels
__C.NETWORK.input_nc_D = 3
# number of output image channels
# __C.NETWORK.output_nc_D = 1
# number of discriminator filters in the first conv layer
__C.NETWORK.ndf = 64
# selects model to use for netD
__C.NETWORK.which_model_netD = 'basic'
# only used if which_model_netD == 'n_layers', dtype = int
__C.NETWORK.n_layers_D = 3
# instance normalization or batch normalization, dtype = str
__C.NETWORK.norm_D = 'instance3d'
# number of output channels of the discriminator network, dtype = int
__C.NETWORK.n_out_ChannelsD = 1
__C.NETWORK.pool_size = 50
__C.NETWORK.if_pool = False
__C.NETWORK.num_D = 3
# add a condition to the discriminator network (conditional GAN)
__C.NETWORK.conditional_D = False

# number of input image channels (map discriminator)
__C.NETWORK.map_input_nc_D = 1
# number of discriminator filters in the first conv layer
__C.NETWORK.map_ndf = 64
# selects model to use for netD
__C.NETWORK.map_which_model_netD = 'basic'
# only used if which_model_netD == 'n_layers'
__C.NETWORK.map_n_layers_D = 3
# instance normalization or batch normalization, dtype = str
__C.NETWORK.map_norm_D = 'instance'
# number of output channels of the discriminator network, dtype = int
__C.NETWORK.map_n_out_ChannelsD = 1
__C.NETWORK.map_pool_size = 50
__C.NETWORK.map_num_D = 3

'''
Train
'''
__C.TRAIN = EasyDict()
# initial learning rate for Adam
__C.TRAIN.lr = 0.0002
# momentum terms of Adam
__C.TRAIN.beta1 = 0.5
__C.TRAIN.beta2 = 0.9
# if True, takes images in order to make batches, otherwise takes them randomly
__C.TRAIN.serial_batches = False
__C.TRAIN.batch_size = 1
# number of threads for loading data
__C.TRAIN.nThreads = 5
# __C.TRAIN.max_epoch = 10
# learning rate policy: lambda|step|plateau
__C.TRAIN.lr_policy = 'lambda'
# number of iterations at the initial learning rate
__C.TRAIN.niter = 100
# number of iterations over which the learning rate decays linearly to zero
__C.TRAIN.niter_decay = 100
# multiply the learning rate by a gamma every lr_decay_iters iterations ('step' policy)
__C.TRAIN.lr_decay_iters = 50
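# For the 'lambda' policy, the usual rule (an assumption based on the common
# pix2pix/CycleGAN-style scheduler, not defined in this file) holds the learning
# rate constant for `niter` epochs and then decays it linearly to zero over
# `niter_decay` epochs:
#
#   lr_lambda = lambda epoch: 1.0 - max(0, epoch - cfg.TRAIN.niter) / float(cfg.TRAIN.niter_decay + 1)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)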
# frequency of printing training results to the console
__C.TRAIN.print_freq = 10
# frequency of displaying training images
__C.TRAIN.print_img_freq = 200
# frequency of saving the latest model (in iterations)
__C.TRAIN.save_latest_freq = 3000
# frequency of saving checkpoints (in epochs)
__C.TRAIN.save_epoch_freq = 5
__C.TRAIN.begin_save_epoch = 0

# whether to apply weight decay
__C.TRAIN.weight_decay_if = False

'''
TEST
'''
__C.TEST = EasyDict()
__C.TEST.howmany_in_train = 10

'''
Data
Augmentation
'''
__C.DATA_AUG = EasyDict()
__C.DATA_AUG.select_slice_num = 0
__C.DATA_AUG.fine_size = 256
__C.DATA_AUG.ct_channel = 256
__C.DATA_AUG.xray_channel = 1
__C.DATA_AUG.resize_size = 289
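# One plausible way these values are consumed (an assumption about the data
# pipeline, not defined in this file): inputs are resized to resize_size and
# then cropped to fine_size, e.g.
#
#   top = np.random.randint(0, cfg.DATA_AUG.resize_size - cfg.DATA_AUG.fine_size)
#   left = np.random.randint(0, cfg.DATA_AUG.resize_size - cfg.DATA_AUG.fine_size)
#   xray = xray[top:top + cfg.DATA_AUG.fine_size, left:left + cfg.DATA_AUG.fine_size]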

'''
2D GAN loss definition
'''
__C.TD_GAN = EasyDict()
# identity (reconstruction) loss weight
__C.TD_GAN.idt_lambda = 10.
__C.TD_GAN.idt_reduction = 'elementwise_mean'
__C.TD_GAN.idt_weight = 0.
__C.TD_GAN.idt_weight_range = [0., 1.]
# reconstruction loss type: 'l1' or 'mse'
__C.TD_GAN.restruction_loss = 'l1'
# perceptual (feature matching) loss weight
__C.TD_GAN.fea_m_lambda = 10.
# whether the discriminator also outputs intermediate features
__C.TD_GAN.discriminator_feature = True
# WGAN-GP gradient penalty weight
__C.TD_GAN.wgan_gp_lambda = 10.
# identity loss weight of the map
__C.TD_GAN.map_m_lambda = 0.
# 'l1' or 'mse'
__C.TD_GAN.map_m_type = 'l1'
__C.TD_GAN.fea_m_map_lambda = 10.
# number of discriminator (critic) updates per generator update
__C.TD_GAN.critic_times = 1

'''
3D GD-GAN structure definition
'''
__C.D3_GAN = EasyDict()
__C.D3_GAN.noise_len = 1000
__C.D3_GAN.input_shape = [4, 4, 4]
# __C.D3_GAN.input_shape_nc = 512
__C.D3_GAN.output_shape = [128, 128, 128]
# __C.D3_GAN.output_shape_nc = 1
__C.D3_GAN.encoder_input_shape = [128, 128]
__C.D3_GAN.encoder_input_nc = 1
__C.D3_GAN.encoder_norm = 'instance'
__C.D3_GAN.encoder_blocks = 4
__C.D3_GAN.multi_view = [1, 2, 3]
__C.D3_GAN.min_max_norm = False
__C.D3_GAN.skip_number = 1
# DoubleBlockLinearUnit activation: [low, high, k]
__C.D3_GAN.dblu = [0., 1.0, 1.0]

'''
CT GAN
'''
__C.CTGAN = EasyDict()
# Input X-ray directions, relative to the CT volume:
# 'H' = front-to-back, 'D' = up-down, 'W' = left-right.
# 'HDW' means the projection depth axis is 'H' and the image lies in the 'DW' plane.
__C.CTGAN.Xray1_Direction = 'HDW'
__C.CTGAN.Xray2_Direction = 'WDH'
# dimension order of the input CT is 'DHW' ('N' and 'C' are prepended as dims 0 and 1 during training)
__C.CTGAN.CTOrder = [0, 1, 2, 3, 4]
# permute NCDHW to NCHDW for xray1 and to NCWDH for xray2
__C.CTGAN.CTOrder_Xray1 = [0, 1, 3, 2, 4]
__C.CTGAN.CTOrder_Xray2 = [0, 1, 4, 2, 3]
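# Sketch of how these permutations are typically applied (assumes a PyTorch 5-D
# tensor in NCDHW order; the tensor name `ct` is illustrative only):
#
#   import torch
#   ct = torch.randn(1, 1, 128, 128, 128)                 # N, C, D, H, W
#   ct_for_xray1 = ct.permute(*cfg.CTGAN.CTOrder_Xray1)   # -> N, C, H, D, W
#   ct_for_xray2 = ct.permute(*cfg.CTGAN.CTOrder_Xray2)   # -> N, C, W, D, H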
# identity loss weight
__C.CTGAN.idt_lambda = 1.0
__C.CTGAN.idt_reduction = 'elementwise_mean'
__C.CTGAN.idt_weight = 0.
__C.CTGAN.idt_weight_range = [0., 1.]
# 'l1' or 'mse'
__C.CTGAN.idt_loss = 'l1'
# feature matching loss weight (discriminator features)
__C.CTGAN.feature_D_lambda = 0.
# projection loss weight
__C.CTGAN.map_projection_lambda = 0.
# 'l1' or 'mse'
__C.CTGAN.map_projection_loss = 'l1'
# GAN loss weight
__C.CTGAN.gan_lambda = 1.0
# multi-view GAN auxiliary loss weight
__C.CTGAN.auxiliary_lambda = 0.
# 'l1' or 'mse'
__C.CTGAN.auxiliary_loss = 'mse'
# map discriminator
__C.CTGAN.feature_D_map_lambda = 0.
__C.CTGAN.map_gan_lambda = 1.0

def cfg_from_yaml(filename):
    '''
    Load a config file and merge it into the default options
    :param filename: path to the YAML config file
    :return:
    '''
    import yaml
    with open(filename, 'r') as f:
        yaml_cfg = EasyDict(yaml.safe_load(f))
    _merge_a_into_b(yaml_cfg, __C)
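
# Typical usage (the file name and override values below are illustrative only):
# a YAML file mirrors the nesting of the defaults and overrides a subset of keys.
#
#   # experiment.yml
#   # TRAIN:
#   #   lr: 0.0001
#   #   batch_size: 2
#
#   cfg_from_yaml('experiment.yml')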

def print_easy_dict(easy_dict):
    print('===' * 10)
    print('====YAML Parameters')
    for k, v in easy_dict.__dict__.items():
        print('{}: {}'.format(k, v))
    print('===' * 10)

def merge_dict_and_yaml(in_dict, easy_dict):
    '''
    Flatten an EasyDict config and merge it with a plain dict into a new EasyDict;
    raises KeyError if the same key appears in both.
    '''
    if type(easy_dict) is not EasyDict:
        return in_dict
    easy_list = _easy_dict_squeeze(easy_dict)
    for (k, v) in easy_list:
        if k in in_dict:
            raise KeyError('Duplicate key: {}'.format(k))
    out_dict = EasyDict(dict(easy_list + list(in_dict.items())))
    return out_dict
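
# A call site might look like the following (the argparse parser and variable
# names are hypothetical, shown only to illustrate the intended use):
#
#   args = parser.parse_args()
#   opt = merge_dict_and_yaml(vars(args), cfg)
#   print_easy_dict(opt)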

def _easy_dict_squeeze(easy_dict):
    if type(easy_dict) is not EasyDict:
        print('Not EasyDict!!!')
        return []

    total_list = []
    for k, v in easy_dict.items():
        # recursively flatten nested EasyDicts
        if type(v) is EasyDict:
            try:
                total_list += _easy_dict_squeeze(v)
            except Exception:
                print('Error under config key: {}'.format(k))
                raise
        else:
            total_list.append((k, v))
    return total_list

def _merge_a_into_b(a, b):
    '''
    Merge EasyDict a into EasyDict b
    :param a: source EasyDict
    :param b: destination EasyDict
    :return:
    '''
    if type(a) is not EasyDict:
        return

    for k, v in a.items():
        # check that k is a valid key of b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError('Type mismatch ({} vs. {}) '
                                 'for config key: {}'.format(type(b[k]), type(v), k))
        # recursively merge dicts
        if type(v) is EasyDict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v