diff --git a/models.py b/models.py index 717a5e8..006de43 100644 --- a/models.py +++ b/models.py @@ -47,6 +47,7 @@ def build( self, x ): super( MyNormLayer, self ).build(x) def call( self, x ): + eps = 1e-12 x1, x2 = x bs, H, W, _ = [tf.shape(x1)[i] for i in range(4)] _, h, w, _ = [tf.shape(x2)[i] for i in range(4)] @@ -55,8 +56,8 @@ def call( self, x ): concat = tf.concat([x1, x2], axis=1) x_mean = K.mean( concat, axis=1, keepdims=True ) x_std = K.std( concat, axis=1, keepdims = True ) - x1 = (x1 - x_mean) / x_std - x2 = (x2 - x_mean) / x_std + x1 = (x1 - x_mean) / (x_std + eps) + x2 = (x2 - x_mean) / (x_std + eps) x1 = tf.reshape(x1, ( bs, H, W, -1 ) ) x2 = tf.reshape(x2, ( bs, h, w, -1 ) ) return [x1, x2] \ No newline at end of file