Skip to content

Commit

Permalink
Merge pull request #55 from NACLab/dev
Browse files Browse the repository at this point in the history
fixed adj-acc bug in analyze_scores in metrics
  • Loading branch information
ago109 authored Jun 27, 2024
2 parents 8e2e66d + f00c63e commit 76091d9
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ngclearn/utils/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st
confusion matrix, precision, recall, misses (empty predictions/all-zero rows),
accuracy, adjusted-accuracy (counts all misses as incorrect)
"""
miss_mask = (jnp.sum(mu, axis=1, keepdims=True) == 0.) * 1.
miss_mask = (jnp.sum(mu, axis=1) == 0.) * 1.
misses = jnp.sum(miss_mask) ## how many misses?
labels = y
if extract_label_indx:
Expand All @@ -133,7 +133,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st
recall = recall_score(labels, guesses, average='macro')
## produce accuracy score measurements
guess = jnp.argmax(mu, axis=1) ## gather all model/output guesses
equality_mask = jnp.equal(guess, labels)
equality_mask = jnp.equal(guess, labels) * 1.
### compute raw accuracy
acc = jnp.sum(equality_mask) / (y.shape[0] * 1.)
### compute hit-masked accuracy (adjusted accuracy
Expand Down

0 comments on commit 76091d9

Please sign in to comment.