Skip to content

Commit

Permalink
Update auroc.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Jingtao-Li-CVer authored Jan 8, 2024
1 parent f19875e commit f714f83
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions metrics/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from iou_metric import SegEvaluator


def compute_auroc(epoch: int, ep_reconst, ep_gt, working_dir: str, image_level=False, save_image=False, compute_iou=True) -> float:
def compute_auroc(epoch: int, ep_amaps, ep_gt, working_dir: str, image_level=False, save_image=False, compute_iou=True) -> float:
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
Args:
epoch (int): Current epoch
ep_reconst (NDArray): Reconstructed images in a current epoch
ep_amaps (NDArray): Anomaly maps in a current epoch
ep_gt (NDArray): Ground truth masks in a current epoch
Returns:
Expand All @@ -34,7 +34,7 @@ def compute_auroc(epoch: int, ep_reconst, ep_gt, working_dir: str, image_level=F


y_score, y_true = [], []
for i, (amap, gt) in enumerate(tqdm(zip(ep_reconst, ep_gt))):
for i, (amap, gt) in enumerate(tqdm(zip(ep_amaps, ep_gt))):
anomaly_scores = amap[np.where(gt == 0)]
normal_scores = amap[np.where(gt == 1)]
y_score += anomaly_scores.tolist()
Expand All @@ -52,7 +52,7 @@ def compute_auroc(epoch: int, ep_reconst, ep_gt, working_dir: str, image_level=F
threshold = thresholds[maxindex]
evaluator = SegEvaluator(2)
evaluator.reset()
for i, (amap, gt) in enumerate(tqdm(zip(ep_reconst, ep_gt))):
for i, (amap, gt) in enumerate(tqdm(zip(ep_amaps, ep_gt))):
amap = np.where(amap > threshold, 1, 0)
amap = amap.astype(np.int8)
evaluator.add_batch(gt, amap)
Expand All @@ -70,4 +70,4 @@ def compute_auroc(epoch: int, ep_reconst, ep_gt, working_dir: str, image_level=F
plt.savefig(os.path.join(save_dir,"roc_curve.png"))
plt.close()

return scoreDF
return scoreDF

0 comments on commit f714f83

Please sign in to comment.