diff --git a/tensorflow_addons/losses/focal_loss.py b/tensorflow_addons/losses/focal_loss.py index 47c785a465..112af1f302 100644 --- a/tensorflow_addons/losses/focal_loss.py +++ b/tensorflow_addons/losses/focal_loss.py @@ -143,4 +143,4 @@ def sigmoid_focal_crossentropy( modulating_factor = tf.pow((1.0 - p_t), gamma) # compute the final loss and return - return tf.reduce_sum(alpha_factor * modulating_factor * ce, axis=-1) + return alpha_factor * modulating_factor * ce