diff --git a/tensorflow_addons/metrics/matthews_correlation_coefficient.py b/tensorflow_addons/metrics/matthews_correlation_coefficient.py index 8394c95055..2e1fb25a30 100644 --- a/tensorflow_addons/metrics/matthews_correlation_coefficient.py +++ b/tensorflow_addons/metrics/matthews_correlation_coefficient.py @@ -50,13 +50,13 @@ class MatthewsCorrelationCoefficient(tf.keras.metrics.Metric): Usage: - >>> y_true = np.array([[1.0], [1.0], [1.0], [0.0]], dtype=np.float32) - >>> y_pred = np.array([[1.0], [0.0], [1.0], [1.0]], dtype=np.float32) - >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=1) + >>> y_true = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float32) + >>> y_pred = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32) + >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2) >>> metric.update_state(y_true, y_pred) >>> result = metric.result() >>> result.numpy() - array([-0.33333334], dtype=float32) + -0.33333334 """ @typechecked @@ -70,28 +70,10 @@ def __init__( """Creates a Matthews Correlation Coefficient instance.""" super().__init__(name=name, dtype=dtype) self.num_classes = num_classes - self.true_positives = self.add_weight( - "true_positives", - shape=[self.num_classes], - initializer="zeros", - dtype=self.dtype, - ) - self.false_positives = self.add_weight( - "false_positives", - shape=[self.num_classes], - initializer="zeros", - dtype=self.dtype, - ) - self.false_negatives = self.add_weight( - "false_negatives", - shape=[self.num_classes], - initializer="zeros", - dtype=self.dtype, - ) - self.true_negatives = self.add_weight( - "true_negatives", - shape=[self.num_classes], - initializer="zeros", + self.conf_mtx = self.add_weight( + "conf_mtx", + shape=(self.num_classes, self.num_classes), + initializer=tf.keras.initializers.zeros, dtype=self.dtype, ) @@ -100,43 +82,35 @@ def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, dtype=self.dtype) y_pred = tf.cast(y_pred, dtype=self.dtype) - true_positive = tf.math.count_nonzero(y_true * y_pred, 0) - # true_negative - y_true_negative = tf.math.not_equal(y_true, 1.0) - y_pred_negative = tf.math.not_equal(y_pred, 1.0) - true_negative = tf.math.count_nonzero( - tf.math.logical_and(y_true_negative, y_pred_negative), axis=0 + new_conf_mtx = tf.math.confusion_matrix( + labels=tf.argmax(y_true, 1), + predictions=tf.argmax(y_pred, 1), + num_classes=self.num_classes, + weights=sample_weight, + dtype=self.dtype, ) - # predicted sum - pred_sum = tf.math.count_nonzero(y_pred, 0) - # Ground truth label sum - true_sum = tf.math.count_nonzero(y_true, 0) - false_positive = pred_sum - true_positive - false_negative = true_sum - true_positive - - # true positive state_update - self.true_positives.assign_add(tf.cast(true_positive, self.dtype)) - # false positive state_update - self.false_positives.assign_add(tf.cast(false_positive, self.dtype)) - # false negative state_update - self.false_negatives.assign_add(tf.cast(false_negative, self.dtype)) - # true negative state_update - self.true_negatives.assign_add(tf.cast(true_negative, self.dtype)) + + self.conf_mtx.assign_add(new_conf_mtx) def result(self): - # numerator - numerator1 = self.true_positives * self.true_negatives - numerator2 = self.false_positives * self.false_negatives - numerator = numerator1 - numerator2 - # denominator - denominator1 = self.true_positives + self.false_positives - denominator2 = self.true_positives + self.false_negatives - denominator3 = self.true_negatives + self.false_positives - denominator4 = self.true_negatives + self.false_negatives - denominator = tf.math.sqrt( - denominator1 * denominator2 * denominator3 * denominator4 - ) - mcc = tf.math.divide_no_nan(numerator, denominator) + + true_sum = tf.reduce_sum(self.conf_mtx, axis=1) + pred_sum = tf.reduce_sum(self.conf_mtx, axis=0) + num_correct = tf.linalg.trace(self.conf_mtx) + num_samples = tf.reduce_sum(pred_sum) + + # covariance true-pred + cov_ytyp = num_correct * num_samples - tf.tensordot(true_sum, pred_sum, axes=1) + # covariance pred-pred + cov_ypyp = num_samples ** 2 - tf.tensordot(pred_sum, pred_sum, axes=1) + # covariance true-true + cov_ytyt = num_samples ** 2 - tf.tensordot(true_sum, true_sum, axes=1) + + mcc = cov_ytyp / tf.math.sqrt(cov_ytyt * cov_ypyp) + + if tf.math.is_nan(mcc): + mcc = tf.constant(0, dtype=self.dtype) + return mcc def get_config(self): @@ -150,5 +124,9 @@ def get_config(self): def reset_states(self): """Resets all of the metric state variables.""" - reset_value = np.zeros(self.num_classes, dtype=self.dtype) - K.batch_set_value([(v, reset_value) for v in self.variables]) + + for v in self.variables: + K.set_value( + v, + np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype), + ) diff --git a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py index 4f570adbcb..1d5ad0f062 100644 --- a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py +++ b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py @@ -18,6 +18,7 @@ import numpy as np from tensorflow_addons.metrics import MatthewsCorrelationCoefficient +from sklearn.metrics import matthews_corrcoef as sklearn_matthew def test_config(): @@ -36,30 +37,59 @@ def check_results(obj, value): def test_binary_classes(): - gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) - preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) + gt_label = tf.constant( + [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 + ) + preds = tf.constant( + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32 + ) # Initialize - mcc = MatthewsCorrelationCoefficient(1) + mcc = MatthewsCorrelationCoefficient(2) # Update mcc.update_state(gt_label, preds) # Check results check_results(mcc, [-0.33333334]) +# See issue #2339 def test_multiple_classes(): - gt_label = tf.constant( - [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], - dtype=tf.float32, + gt_label = np.array( + [ + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + ] ) - preds = tf.constant( - [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]], - dtype=tf.float32, + preds = np.array( + [ + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + ] ) + tensor_gt_label = tf.constant(gt_label, dtype=tf.float32) + tensor_preds = tf.constant(preds, dtype=tf.float32) # Initialize mcc = MatthewsCorrelationCoefficient(3) - mcc.update_state(gt_label, preds) - # Check results - check_results(mcc, [-0.33333334, 1.0, 0.57735026]) + # Update + mcc.update_state(tensor_gt_label, tensor_preds) + # Check results by comparing to results of scikit-learn matthew implementation. + sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1)) + check_results(mcc, sklearn_result) # Keras model API check @@ -80,9 +110,13 @@ def test_keras_model(): def test_reset_states_graph(): - gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) - preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) - mcc = MatthewsCorrelationCoefficient(1) + gt_label = tf.constant( + [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 + ) + preds = tf.constant( + [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], dtype=tf.float32 + ) + mcc = MatthewsCorrelationCoefficient(2) mcc.update_state(gt_label, preds) @tf.function