diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index f48472f84..0ab3a078e 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -122,7 +122,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st confusion matrix, precision, recall, misses (empty predictions/all-zero rows), accuracy, adjusted-accuracy (counts all misses as incorrect) """ - miss_mask = (jnp.sum(mu, axis=1, keepdims=True) == 0.) * 1. + miss_mask = (jnp.sum(mu, axis=1) == 0.) * 1. misses = jnp.sum(miss_mask) ## how many misses? labels = y if extract_label_indx: @@ -133,7 +133,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st recall = recall_score(labels, guesses, average='macro') ## produce accuracy score measurements guess = jnp.argmax(mu, axis=1) ## gather all model/output guesses - equality_mask = jnp.equal(guess, labels) + equality_mask = jnp.equal(guess, labels) * 1. ### compute raw accuracy acc = jnp.sum(equality_mask) / (y.shape[0] * 1.) ### compute hit-masked accuracy (adjusted accuracy