diff --git a/requirements.txt b/requirements.txt index f5678e05..2d29b61c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ keras_applications>=1.0.7,<=1.0.8 image-classifiers==1.0.0 -efficientnet==1.0.0 +efficientnet==1.1.0 diff --git a/segmentation_models/backbones/inception_resnet_v2.py b/segmentation_models/backbones/inception_resnet_v2.py index 5f6cebb6..499c7279 100644 --- a/segmentation_models/backbones/inception_resnet_v2.py +++ b/segmentation_models/backbones/inception_resnet_v2.py @@ -44,6 +44,7 @@ def conv2d_bn(x, strides=1, padding='same', activation='relu', + activation_dtype=None, use_bias=False, name=None): """Utility function to apply conv + BN. @@ -54,6 +55,8 @@ def conv2d_bn(x, strides: strides in `Conv2D`. padding: padding mode in `Conv2D`. activation: activation in `Conv2D`. + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. use_bias: whether to use a bias in `Conv2D`. name: name of the ops; will become `name + '_ac'` for the activation and `name + '_bn'` for the batch norm layer. @@ -74,11 +77,14 @@ def conv2d_bn(x, name=bn_name)(x) if activation is not None: ac_name = None if name is None else name + '_ac' - x = layers.Activation(activation, name=ac_name)(x) + if activation_dtype is None: + x = layers.Activation(activation, name=ac_name)(x) + else: + x = layers.Activation(activation, name=ac_name, dtype=activation_dtype)(x) return x -def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): +def inception_resnet_block(x, scale, block_type, block_idx, activation='relu', activation_dtype=None): """Adds a Inception-ResNet block. This function builds 3 types of Inception-ResNet blocks mentioned in the paper, controlled by the `block_type` argument (which is the @@ -108,6 +114,8 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): (see [activations](../activations.md)). 
When `activation=None`, no activation is applied (i.e., "linear" activation: `a(x) = x`). + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. # Returns Output tensor for the block. # Raises @@ -115,24 +123,24 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): `'block17'` or `'block8'`. """ if block_type == 'block35': - branch_0 = conv2d_bn(x, 32, 1) - branch_1 = conv2d_bn(x, 32, 1) - branch_1 = conv2d_bn(branch_1, 32, 3) - branch_2 = conv2d_bn(x, 32, 1) - branch_2 = conv2d_bn(branch_2, 48, 3) - branch_2 = conv2d_bn(branch_2, 64, 3) + branch_0 = conv2d_bn(x, 32, 1, activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 32, 1, activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 32, 3, activation_dtype=activation_dtype) + branch_2 = conv2d_bn(x, 32, 1, activation_dtype=activation_dtype) + branch_2 = conv2d_bn(branch_2, 48, 3, activation_dtype=activation_dtype) + branch_2 = conv2d_bn(branch_2, 64, 3, activation_dtype=activation_dtype) branches = [branch_0, branch_1, branch_2] elif block_type == 'block17': - branch_0 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(x, 128, 1) - branch_1 = conv2d_bn(branch_1, 160, [1, 7]) - branch_1 = conv2d_bn(branch_1, 192, [7, 1]) + branch_0 = conv2d_bn(x, 192, 1, activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 128, 1, activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 160, [1, 7], activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 192, [7, 1], activation_dtype=activation_dtype) branches = [branch_0, branch_1] elif block_type == 'block8': - branch_0 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(x, 192, 1) - branch_1 = conv2d_bn(branch_1, 224, [1, 3]) - branch_1 = conv2d_bn(branch_1, 256, [3, 1]) + branch_0 = conv2d_bn(x, 192, 1, activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 192, 1, activation_dtype=activation_dtype) + branch_1 = 
conv2d_bn(branch_1, 224, [1, 3], activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 256, [3, 1], activation_dtype=activation_dtype) branches = [branch_0, branch_1] else: raise ValueError('Unknown Inception-ResNet block type. ' @@ -148,14 +156,18 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'): 1, activation=None, use_bias=True, - name=block_name + '_conv') + name=block_name + '_conv', + activation_dtype=activation_dtype) x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale, output_shape=backend.int_shape(x)[1:], arguments={'scale': scale}, name=block_name)([x, up]) if activation is not None: - x = layers.Activation(activation, name=block_name + '_ac')(x) + if activation_dtype is None: + x = layers.Activation(activation, name=block_name + '_ac')(x) + else: + x = layers.Activation(activation, name=block_name + '_ac', dtype=activation_dtype)(x) return x @@ -165,6 +177,7 @@ def InceptionResNetV2(include_top=True, input_shape=None, pooling=None, classes=1000, + activation_dtype=None, **kwargs): """Instantiates the Inception-ResNet v2 architecture. Optionally loads weights pre-trained on ImageNet. 
@@ -234,23 +247,23 @@ def InceptionResNetV2(include_top=True, img_input = input_tensor # Stem block: 35 x 35 x 192 - x = conv2d_bn(img_input, 32, 3, strides=2, padding='same') - x = conv2d_bn(x, 32, 3, padding='same') - x = conv2d_bn(x, 64, 3, padding='same') - x = layers.MaxPooling2D(3, strides=2, padding='same')(x) - x = conv2d_bn(x, 80, 1, padding='same') - x = conv2d_bn(x, 192, 3, padding='same') - x = layers.MaxPooling2D(3, strides=2, padding='same')(x) + x = conv2d_bn(img_input, 32, 3, strides=2, padding='same', activation_dtype=activation_dtype) + x = conv2d_bn(x, 32, 3, padding='same', activation_dtype=activation_dtype) + x = conv2d_bn(x, 64, 3, padding='same', activation_dtype=activation_dtype) + x = layers.MaxPooling2D(3, strides=2, padding='same')(x) + x = conv2d_bn(x, 80, 1, padding='same', activation_dtype=activation_dtype) + x = conv2d_bn(x, 192, 3, padding='same', activation_dtype=activation_dtype) + x = layers.MaxPooling2D(3, strides=2, padding='same')(x) # Mixed 5b (Inception-A block): 35 x 35 x 320 - branch_0 = conv2d_bn(x, 96, 1, padding='same') - branch_1 = conv2d_bn(x, 48, 1, padding='same') - branch_1 = conv2d_bn(branch_1, 64, 5, padding='same') - branch_2 = conv2d_bn(x, 64, 1, padding='same') - branch_2 = conv2d_bn(branch_2, 96, 3, padding='same') - branch_2 = conv2d_bn(branch_2, 96, 3, padding='same') - branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x) - branch_pool = conv2d_bn(branch_pool, 64, 1, padding='same') + branch_0 = conv2d_bn(x, 96, 1, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 48, 1, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 64, 5, padding='same', activation_dtype=activation_dtype) + branch_2 = conv2d_bn(x, 64, 1, padding='same', activation_dtype=activation_dtype) + branch_2 = conv2d_bn(branch_2, 96, 3, padding='same', activation_dtype=activation_dtype) + branch_2 =
conv2d_bn(branch_2, 96, 3, padding='same', activation_dtype=activation_dtype) + branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x) + branch_pool = conv2d_bn(branch_pool, 64, 1, padding='same', activation_dtype=activation_dtype) branches = [branch_0, branch_1, branch_2, branch_pool] channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3 x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches) @@ -263,11 +276,11 @@ def InceptionResNetV2(include_top=True, block_idx=block_idx) # Mixed 6a (Reduction-A block): 17 x 17 x 1088 - branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same') - branch_1 = conv2d_bn(x, 256, 1, padding='same') - branch_1 = conv2d_bn(branch_1, 256, 3, padding='same') - branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same') - branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) + branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 256, 1, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 256, 3, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='same', activation_dtype=activation_dtype) + branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) branches = [branch_0, branch_1, branch_pool] x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches) @@ -279,14 +292,14 @@ def InceptionResNetV2(include_top=True, block_idx=block_idx) # Mixed 7a (Reduction-B block): 8 x 8 x 2080 - branch_0 = conv2d_bn(x, 256, 1, padding='same') - branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same') - branch_1 = conv2d_bn(x, 256, 1, padding='same') - branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='same') - branch_2 = conv2d_bn(x, 256, 1, padding='same') - branch_2 = conv2d_bn(branch_2, 288, 3, padding='same') - branch_2 =
conv2d_bn(branch_2, 320, 3, strides=2, padding='same') - branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) + branch_0 = conv2d_bn(x, 256, 1, padding='same', activation_dtype=activation_dtype) + branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(x, 256, 1, padding='same', activation_dtype=activation_dtype) + branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='same', activation_dtype=activation_dtype) + branch_2 = conv2d_bn(x, 256, 1, padding='same', activation_dtype=activation_dtype) + branch_2 = conv2d_bn(branch_2, 288, 3, padding='same', activation_dtype=activation_dtype) + branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='same', activation_dtype=activation_dtype) + branch_pool = layers.MaxPooling2D(3, strides=2, padding='same')(x) branches = [branch_0, branch_1, branch_2, branch_pool] x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches) @@ -295,20 +308,26 @@ def InceptionResNetV2(include_top=True, x = inception_resnet_block(x, scale=0.2, block_type='block8', - block_idx=block_idx) + block_idx=block_idx, + activation_dtype=activation_dtype) x = inception_resnet_block(x, scale=1., activation=None, block_type='block8', - block_idx=10) + block_idx=10, + activation_dtype=activation_dtype) # Final convolution block: 8 x 8 x 1536 - x = conv2d_bn(x, 1536, 1, name='conv_7b') + x = conv2d_bn(x, 1536, 1, name='conv_7b', activation_dtype=activation_dtype) if include_top: # Classification block x = layers.GlobalAveragePooling2D(name='avg_pool')(x) - x = layers.Dense(classes, activation='softmax', name='predictions')(x) + if activation_dtype is None: + x = layers.Dense(classes, activation='softmax', name='predictions')(x) + else: + x = layers.Dense(classes, name='dense_logits')(x) + x = layers.Activation('softmax', dtype=activation_dtype, name='predictions')(x) else: if pooling == 'avg': x =
layers.GlobalAveragePooling2D()(x) diff --git a/segmentation_models/backbones/inception_v3.py b/segmentation_models/backbones/inception_v3.py index db8b567f..543d91ac 100644 --- a/segmentation_models/backbones/inception_v3.py +++ b/segmentation_models/backbones/inception_v3.py @@ -77,6 +77,7 @@ def InceptionV3(include_top=True, input_tensor=None, input_shape=None, pooling=None, classes=1000, + activation_dtype=None, **kwargs): """Instantiates the Inception v3 architecture. @@ -109,6 +110,8 @@ def InceptionV3(include_top=True, the output of the model will be a 2D tensor. - `max` means that global max pooling will be applied. + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. classes: optional number of classes to classify images into, only to be specified if `include_top` is True, and if no `weights` argument is specified. @@ -345,7 +348,13 @@ def InceptionV3(include_top=True, if include_top: # Classification block x = layers.GlobalAveragePooling2D(name='avg_pool')(x) - x = layers.Dense(classes, activation='softmax', name='predictions')(x) + if activation_dtype is None: + x = layers.Dense(classes, activation='softmax', name='predictions')(x) + else: + #only softmax activation must be cast when using mixed precision + x = layers.Dense(classes, name='dense_logits')(x) + x = layers.Activation('softmax', dtype=activation_dtype, name='predictions')(x) + else: if pooling == 'avg': x = layers.GlobalAveragePooling2D()(x) diff --git a/segmentation_models/models/_common_blocks.py b/segmentation_models/models/_common_blocks.py index 221d83bd..1a2eb3e9 100644 --- a/segmentation_models/models/_common_blocks.py +++ b/segmentation_models/models/_common_blocks.py @@ -9,6 +9,7 @@ def Conv2dBn( data_format=None, dilation_rate=(1, 1), activation=None, + activation_dtype=None, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, @@ -62,7 +63,10 @@ def
wrapper(input_tensor): x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) if activation: - x = layers.Activation(activation, name=act_name)(x) + if activation_dtype is None: + x = layers.Activation(activation, name=act_name)(x) + else: + x = layers.Activation(activation, name=act_name, dtype=activation_dtype)(x) return x diff --git a/segmentation_models/models/fpn.py b/segmentation_models/models/fpn.py index deab7f54..3f74efdd 100644 --- a/segmentation_models/models/fpn.py +++ b/segmentation_models/models/fpn.py @@ -52,8 +52,16 @@ def DoubleConv3x3BnReLU(filters, use_batchnorm, name=None): name2 = name + 'b' def wrapper(input_tensor): - x = Conv3x3BnReLU(filters, use_batchnorm, name=name1)(input_tensor) - x = Conv3x3BnReLU(filters, use_batchnorm, name=name2)(x) + x = Conv3x3BnReLU( + filters, + use_batchnorm, + name=name1 + )(input_tensor) + x = Conv3x3BnReLU( + filters, + use_batchnorm, + name=name2 + )(x) return x return wrapper @@ -106,6 +114,7 @@ def build_fpn( segmentation_filters=128, classes=1, activation='sigmoid', + activation_dtype=None, use_batchnorm=True, aggregation='sum', dropout=None, @@ -124,22 +133,30 @@ def build_fpn( p2 = FPNBlock(pyramid_filters, stage=2)(p3, skips[3]) # add segmentation head to each - s5 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage5')(p5) - s4 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage4')(p4) - s3 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage3')(p3) - s2 = DoubleConv3x3BnReLU(segmentation_filters, use_batchnorm, name='segm_stage2')(p2) + s5 = DoubleConv3x3BnReLU(segmentation_filters, + use_batchnorm, name='segm_stage5')(p5) + s4 = DoubleConv3x3BnReLU(segmentation_filters, + use_batchnorm, name='segm_stage4')(p4) + s3 = DoubleConv3x3BnReLU(segmentation_filters, + use_batchnorm, name='segm_stage3')(p3) + s2 = DoubleConv3x3BnReLU(segmentation_filters, + use_batchnorm, name='segm_stage2')(p2) # upsampling to same resolution - s5 
= layers.UpSampling2D((8, 8), interpolation='nearest', name='upsampling_stage5')(s5) - s4 = layers.UpSampling2D((4, 4), interpolation='nearest', name='upsampling_stage4')(s4) - s3 = layers.UpSampling2D((2, 2), interpolation='nearest', name='upsampling_stage3')(s3) + s5 = layers.UpSampling2D( + (8, 8), interpolation='nearest', name='upsampling_stage5')(s5) + s4 = layers.UpSampling2D( + (4, 4), interpolation='nearest', name='upsampling_stage4')(s4) + s3 = layers.UpSampling2D( + (2, 2), interpolation='nearest', name='upsampling_stage3')(s3) # aggregating results if aggregation == 'sum': x = layers.Add(name='aggregation_sum')([s2, s3, s4, s5]) elif aggregation == 'concat': concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x = layers.Concatenate(axis=concat_axis, name='aggregation_concat')([s2, s3, s4, s5]) + x = layers.Concatenate( + axis=concat_axis, name='aggregation_concat')([s2, s3, s4, s5]) else: raise ValueError('Aggregation parameter should be in ("sum", "concat"), ' 'got {}'.format(aggregation)) @@ -148,8 +165,10 @@ def build_fpn( x = layers.SpatialDropout2D(dropout, name='pyramid_dropout')(x) # final stage - x = Conv3x3BnReLU(segmentation_filters, use_batchnorm, name='final_stage')(x) - x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear', name='final_upsampling')(x) + x = Conv3x3BnReLU(segmentation_filters, use_batchnorm, + name='final_stage')(x) + x = layers.UpSampling2D( + size=(2, 2), interpolation='bilinear', name='final_upsampling')(x) # model head (define number of output classes) x = layers.Conv2D( @@ -160,7 +179,11 @@ def build_fpn( kernel_initializer='glorot_uniform', name='head_conv', )(x) - x = layers.Activation(activation, name=activation)(x) + if activation_dtype is None: + x = layers.Activation(activation, name=activation)(x) + else: + x = layers.Activation(activation, name=activation, + dtype=activation_dtype)(x) # create keras model instance model = models.Model(input_, x) @@ -177,6 +200,7 @@ def FPN( 
input_shape=(None, None, 3), classes=21, activation='softmax', + activation_dtype=None, weights=None, encoder_weights='imagenet', encoder_freeze=False, @@ -198,6 +222,8 @@ def FPN( classes: a number of classes for output (output shape - ``(h, w, classes)``). weights: optional, path to model weights. activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. encoder_features: a list of layer numbers or names starting from top of the model. @@ -218,7 +244,8 @@ def FPN( """ global backend, layers, models, keras_utils submodule_args = filter_keras_submodules(kwargs) - backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) + backend, layers, models, keras_utils = get_submodules_from_kwargs( + submodule_args) backbone = Backbones.get_backbone( backbone_name, @@ -239,6 +266,7 @@ def FPN( use_batchnorm=pyramid_use_batchnorm, dropout=pyramid_dropout, activation=activation, + activation_dtype=activation_dtype, classes=classes, aggregation=pyramid_aggregation, ) diff --git a/segmentation_models/models/linknet.py b/segmentation_models/models/linknet.py index 74c533c9..faba695d 100644 --- a/segmentation_models/models/linknet.py +++ b/segmentation_models/models/linknet.py @@ -63,7 +63,7 @@ def wrapper(input_tensor): return wrapper -def DecoderUpsamplingX2Block(filters, stage, use_batchnorm): +def DecoderUpsamplingX2Block(filters, stage, use_batchnorm, activation_dtype=None): conv_block1_name = 'decoder_stage{}a'.format(stage) conv_block2_name = 'decoder_stage{}b'.format(stage) conv_block3_name = 'decoder_stage{}c'.format(stage) @@ -74,7 +74,8 @@ def 
DecoderUpsamplingX2Block(filters, stage, use_batchnorm): def wrapper(input_tensor, skip=None): input_filters = backend.int_shape(input_tensor)[channels_axis] - output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters + output_filters = backend.int_shape( + skip)[channels_axis] if skip is not None else filters x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor) x = layers.UpSampling2D((2, 2), name=up_name)(x) @@ -88,7 +89,7 @@ def wrapper(input_tensor, skip=None): return wrapper -def DecoderTransposeX2Block(filters, stage, use_batchnorm): +def DecoderTransposeX2Block(filters, stage, use_batchnorm, activation_dtype=None): conv_block1_name = 'decoder_stage{}a'.format(stage) transpose_name = 'decoder_stage{}b_transpose'.format(stage) bn_name = 'decoder_stage{}b_bn'.format(stage) @@ -100,9 +101,11 @@ def DecoderTransposeX2Block(filters, stage, use_batchnorm): def wrapper(input_tensor, skip=None): input_filters = backend.int_shape(input_tensor)[channels_axis] - output_filters = backend.int_shape(skip)[channels_axis] if skip is not None else filters + output_filters = backend.int_shape( + skip)[channels_axis] if skip is not None else filters - x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, name=conv_block1_name)(input_tensor) + x = Conv1x1BnReLU(input_filters // 4, use_batchnorm, + name=conv_block1_name)(input_tensor) x = layers.Conv2DTranspose( filters=input_filters // 4, kernel_size=(4, 4), @@ -114,9 +117,9 @@ def wrapper(input_tensor, skip=None): if use_batchnorm: x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) - x = layers.Activation('relu', name=relu_name)(x) - x = Conv1x1BnReLU(output_filters, use_batchnorm, name=conv_block3_name)(x) + x = Conv1x1BnReLU(output_filters, use_batchnorm, + name=conv_block3_name)(x) if skip is not None: x = layers.Add(name=add_name)([x, skip]) @@ -138,6 +141,7 @@ def build_linknet( n_upsample_blocks=5, classes=1, activation='sigmoid', + 
activation_dtype=None, use_batchnorm=True, ): input_ = backbone.input @@ -160,7 +164,8 @@ def build_linknet( else: skip = None - x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) + x = decoder_block( + decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) # model head (define number of output classes) x = layers.Conv2D( @@ -170,7 +175,10 @@ def build_linknet( use_bias=True, kernel_initializer='glorot_uniform' )(x) - x = layers.Activation(activation, name=activation)(x) + if activation_dtype is None: + x = layers.Activation(activation, name=activation)(x) + else: + x = layers.Activation(activation, name=activation, dtype=activation_dtype)(x) # create keras model instance model = models.Model(input_, x) @@ -187,6 +195,7 @@ def Linknet( input_shape=(None, None, 3), classes=1, activation='sigmoid', + activation_dtype=None, weights=None, encoder_weights='imagenet', encoder_freeze=False, @@ -210,6 +219,8 @@ def Linknet( classes: a number of classes for output (output shape - ``(h, w, classes)``). activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. weights: optional, path to model weights. encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 
@@ -234,7 +245,8 @@ def Linknet( global backend, layers, models, keras_utils submodule_args = filter_keras_submodules(kwargs) - backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) + backend, layers, models, keras_utils = get_submodules_from_kwargs( + submodule_args) if decoder_block_type == 'upsampling': decoder_block = DecoderUpsamplingX2Block @@ -262,6 +274,7 @@ def Linknet( decoder_filters=decoder_filters, classes=classes, activation=activation, + activation_dtype=activation_dtype, n_upsample_blocks=len(decoder_filters), use_batchnorm=decoder_use_batchnorm, ) diff --git a/segmentation_models/models/pspnet.py b/segmentation_models/models/pspnet.py index 001b28c9..d4d89895 100644 --- a/segmentation_models/models/pspnet.py +++ b/segmentation_models/models/pspnet.py @@ -25,14 +25,16 @@ def get_submodules(): def check_input_shape(input_shape, factor): if input_shape is None: - raise ValueError("Input shape should be a tuple of 3 integers, not None!") + raise ValueError( + "Input shape should be a tuple of 3 integers, not None!") - h, w = input_shape[:2] if backend.image_data_format() == 'channels_last' else input_shape[1:] + h, w = input_shape[:2] if backend.image_data_format( + ) == 'channels_last' else input_shape[1:] min_size = factor * 6 is_wrong_shape = ( - h % min_size != 0 or w % min_size != 0 or - h < min_size or w < min_size + h % min_size != 0 or w % min_size != 0 or + h < min_size or w < min_size ) if is_wrong_shape: @@ -66,7 +68,7 @@ def SpatialContextBlock( level, conv_filters=512, pooling_type='avg', - use_batchnorm=True, + use_batchnorm=True ): if pooling_type not in ('max', 'avg'): raise ValueError('Unsupported pooling type - `{}`.'.format(pooling_type) + @@ -81,17 +83,21 @@ def SpatialContextBlock( def wrapper(input_tensor): # extract input feature maps size (h, and w dimensions) input_shape = backend.int_shape(input_tensor) - spatial_size = input_shape[1:3] if backend.image_data_format() == 'channels_last' else 
input_shape[2:] + spatial_size = input_shape[1:3] if backend.image_data_format( + ) == 'channels_last' else input_shape[2:] # Compute the kernel and stride sizes according to how large the final feature map will be # When the kernel factor and strides are equal, then we can compute the final feature map factor # by simply dividing the current factor by the kernel or stride factor # The final feature map sizes are 1x1, 2x2, 3x3, and 6x6. - pool_size = up_size = [spatial_size[0] // level, spatial_size[1] // level] + pool_size = up_size = [spatial_size[0] // + level, spatial_size[1] // level] - x = Pooling2D(pool_size, strides=pool_size, padding='same', name=pooling_name)(input_tensor) + x = Pooling2D(pool_size, strides=pool_size, + padding='same', name=pooling_name)(input_tensor) x = Conv1x1BnReLU(conv_filters, use_batchnorm, name=conv_block_name)(x) - x = layers.UpSampling2D(up_size, interpolation='bilinear', name=upsampling_name)(x) + x = layers.UpSampling2D( + up_size, interpolation='bilinear', name=upsampling_name)(x) return x return wrapper @@ -110,6 +116,7 @@ def build_psp( final_upsampling_factor=8, classes=21, activation='softmax', + activation_dtype=None, dropout=None, ): input_ = backbone.input @@ -124,7 +131,8 @@ def build_psp( # aggregate spatial pyramid concat_axis = 3 if backend.image_data_format() == 'channels_last' else 1 - x = layers.Concatenate(axis=concat_axis, name='psp_concat')([x, x1, x2, x3, x6]) + x = layers.Concatenate(axis=concat_axis, name='psp_concat')( + [x, x1, x2, x3, x6]) x = Conv1x1BnReLU(conv_filters, use_batchnorm, name='aggregation')(x) # model regularization @@ -140,8 +148,13 @@ def build_psp( name='final_conv', )(x) - x = layers.UpSampling2D(final_upsampling_factor, name='final_upsampling', interpolation='bilinear')(x) - x = layers.Activation(activation, name=activation)(x) + x = layers.UpSampling2D(final_upsampling_factor, + name='final_upsampling', interpolation='bilinear')(x) + if activation_dtype is None: + x = 
layers.Activation(activation, name=activation)(x) + else: + x = layers.Activation(activation, name=activation, + dtype=activation_dtype)(x) model = models.Model(input_, x) @@ -157,6 +170,7 @@ def PSPNet( input_shape=(384, 384, 3), classes=21, activation='softmax', + activation_dtype=None, weights=None, encoder_weights='imagenet', encoder_freeze=False, @@ -177,6 +191,8 @@ def PSPNet( classes: a number of classes for output (output shape - ``(h, w, classes)``). activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. weights: optional, path to model weights. encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. @@ -198,7 +214,8 @@ def PSPNet( global backend, layers, models, keras_utils submodule_args = filter_keras_submodules(kwargs) - backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) + backend, layers, models, keras_utils = get_submodules_from_kwargs( + submodule_args) # control image input shape check_input_shape(input_shape, downsample_factor) @@ -220,7 +237,8 @@ def PSPNet( elif downsample_factor == 4: psp_layer_idx = feature_layers[2] else: - raise ValueError('Unsupported factor - `{}`, Use 4, 8 or 16.'.format(downsample_factor)) + raise ValueError( + 'Unsupported factor - `{}`, Use 4, 8 or 16.'.format(downsample_factor)) model = build_psp( backbone, @@ -231,6 +249,7 @@ def PSPNet( final_upsampling_factor=downsample_factor, classes=classes, activation=activation, + activation_dtype=activation_dtype, dropout=psp_dropout, ) diff --git a/segmentation_models/models/unet.py b/segmentation_models/models/unet.py index 7da2b391..f524c4a8 100644 --- a/segmentation_models/models/unet.py +++ 
b/segmentation_models/models/unet.py @@ -57,7 +57,8 @@ def wrapper(input_tensor, skip=None): x = layers.UpSampling2D(size=2, name=up_name)(input_tensor) if skip is not None: - x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip]) + x = layers.Concatenate( + axis=concat_axis, name=concat_name)([x, skip]) x = Conv3x3BnReLU(filters, use_batchnorm, name=conv1_name)(x) x = Conv3x3BnReLU(filters, use_batchnorm, name=conv2_name)(x) @@ -91,9 +92,10 @@ def layer(input_tensor, skip=None): x = layers.BatchNormalization(axis=bn_axis, name=bn_name)(x) x = layers.Activation('relu', name=relu_name)(x) - + if skip is not None: - x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip]) + x = layers.Concatenate( + axis=concat_axis, name=concat_name)([x, skip]) x = Conv3x3BnReLU(filters, use_batchnorm, name=conv_block_name)(x) @@ -114,6 +116,7 @@ def build_unet( n_upsample_blocks=5, classes=1, activation='sigmoid', + activation_dtype=None, use_batchnorm=True, ): input_ = backbone.input @@ -136,7 +139,8 @@ def build_unet( else: skip = None - x = decoder_block(decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) + x = decoder_block( + decoder_filters[i], stage=i, use_batchnorm=use_batchnorm)(x, skip) # model head (define number of output classes) x = layers.Conv2D( @@ -147,7 +151,11 @@ def build_unet( kernel_initializer='glorot_uniform', name='final_conv', )(x) - x = layers.Activation(activation, name=activation)(x) + if activation_dtype is None: + x = layers.Activation(activation, name=activation)(x) + else: + x = layers.Activation(activation, name=activation, + dtype=activation_dtype)(x) # create keras model instance model = models.Model(input_, x) @@ -164,6 +172,7 @@ def Unet( input_shape=(None, None, 3), classes=1, activation='sigmoid', + activation_dtype=None, weights=None, encoder_weights='imagenet', encoder_freeze=False, @@ -184,6 +193,8 @@ def Unet( classes: a number of classes for output (output shape - ``(h, w, classes)``). 
activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). + activation_dtype: Optional type parameter to force activations + to be treated in certain type. Used when mixed_precision is enabled. weights: optional, path to model weights. encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. @@ -209,7 +220,8 @@ def Unet( global backend, layers, models, keras_utils submodule_args = filter_keras_submodules(kwargs) - backend, layers, models, keras_utils = get_submodules_from_kwargs(submodule_args) + backend, layers, models, keras_utils = get_submodules_from_kwargs( + submodule_args) if decoder_block_type == 'upsampling': decoder_block = DecoderUpsamplingX2Block @@ -237,6 +249,7 @@ def Unet( decoder_filters=decoder_filters, classes=classes, activation=activation, + activation_dtype=activation_dtype, n_upsample_blocks=len(decoder_filters), use_batchnorm=decoder_use_batchnorm, )