Skip to content

Commit 575632f

Browse files
abc-125, samet-akcay, and rajeshgangireddy
authored and committed
🚀 feat(metric): added PGn, PBn metrics (#2889)
* added pg and pb metrics * fixed typos * Update __init__.py * Update __init__.py - removed duplicate * Update pg_pb.py * Fixed pre-commit checks --------- Co-authored-by: Samet Akcay <[email protected]> Co-authored-by: Rajesh Gangireddy <[email protected]> Signed-off-by: StarPlatinum7 <[email protected]>
1 parent 0753520 commit 575632f

File tree

3 files changed

+292
-0
lines changed

3 files changed

+292
-0
lines changed

src/anomalib/metrics/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
- ``BinaryPrecisionRecallCurve``: Computes precision-recall curves
2626
- ``Evaluator``: Combines multiple metrics for evaluation
2727
- ``MinMax``: Normalizes scores to [0,1] range
28+
- ``PBn``: Presorted bad with n% good samples misclassified
29+
- ``PGn``: Presorted good with n% bad samples missed
2830
- ``PRO``: Per-Region Overlap score
2931
- ``PIMO``: Per-Image Missed Overlap score
3032
@@ -56,6 +58,7 @@
5658
from .evaluator import Evaluator
5759
from .f1_score import F1Max, F1Score
5860
from .min_max import MinMax
61+
from .pg_pb import PBn, PGn
5962
from .pimo import AUPIMO, PIMO
6063
from .precision_recall_curve import BinaryPrecisionRecallCurve
6164
from .pro import PRO
@@ -75,6 +78,8 @@
7578
"F1Score",
7679
"ManualThreshold",
7780
"MinMax",
81+
"PGn",
82+
"PBn",
7883
"PRO",
7984
"PIMO",
8085
"AUPIMO",

src/anomalib/metrics/pg_pb.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""PGn and PBn metrics for binary image-level classification tasks.
5+
6+
This module provides two metrics for evaluating binary image-level classification performance
7+
on the assumption that bad (anomalous) samples are considered to be the positive class:
8+
9+
- ``PGn``: Presorted good with n% bad samples missed, can be interpreted as true negative rate
10+
at a fixed false negative rate (TNR@nFNR).
11+
- ``PBn``: Presorted bad with n% good samples misclassified, can be interpreted as true positive rate
12+
at a fixed false positive rate (TPR@nFPR).
13+
14+
These metrics emphasize the practical applications of anomaly detection models by showing their potential
15+
to reduce human operator workload while maintaining an acceptable level of misclassification.
16+
17+
Example:
18+
>>> from anomalib.metrics import PGn, PBn
19+
>>> from anomalib.data import ImageBatch
20+
>>> import torch
21+
>>> # Create sample batch
22+
>>> batch = ImageBatch(
23+
... image=torch.rand(4, 3, 32, 32),
24+
... pred_score=torch.tensor([0.1, 0.4, 0.35, 0.8]),
25+
... gt_label=torch.tensor([0, 0, 1, 1])
26+
... )
27+
>>> pg = PGn(fnr=0.2)
28+
>>> # Print name of the metric
29+
>>> print(pg.name)
30+
PG20
31+
>>> # Compute PGn score
32+
>>> pg.update(batch)
33+
>>> pg.compute()
34+
tensor(1.)
35+
>>> pb = PBn(fpr=0.2)
36+
>>> # Print name of the metric
37+
>>> print(pb.name)
38+
PB20
39+
>>> # Compute PBn score
40+
>>> pb.update(batch)
41+
>>> pb.compute()
42+
tensor(1.)
43+
44+
Note:
45+
Scores for both metrics range from 0 to 1, with 1 indicating perfect separation
46+
of the respective class with ``n``% or less of the other class misclassified.
47+
48+
Reference:
49+
Aimira Baitieva, Yacine Bouaouni, Alexandre Briot, Dick Ameln, Souhaiel Khalfaoui,
50+
Samet Akcay; Beyond Academic Benchmarks: Critical Analysis and Best Practices
51+
for Visual Industrial Anomaly Detection; in: Proceedings of the IEEE/CVF Conference
52+
on Computer Vision and Pattern Recognition (CVPR) Workshops, 2025, pp. 4024-4034,
53+
https://arxiv.org/abs/2503.23451
54+
"""
55+
56+
import torch
57+
from torchmetrics import Metric
58+
from torchmetrics.utilities import dim_zero_cat
59+
60+
from anomalib.metrics.base import AnomalibMetric
61+
62+
63+
class _PGn(Metric):
    """Presorted good (PGn) metric.

    Computes the true negative rate at a fixed false negative rate, treating
    bad (anomalous) samples as the positive class (``target == 1``) and good
    samples as the negative class (``target == 0``). The acceptance threshold
    is the ``fnr``-quantile of the positive scores; the score is the fraction
    of negative samples whose prediction lies strictly below that threshold.

    Args:
        fnr (float): Fixed false negative rate (bad parts misclassified), in the
            range ``[0, 1]``. Defaults to ``0.05``.
        **kwargs: Additional arguments passed to the parent ``Metric`` class.

    Attributes:
        fnr (torch.Tensor): The false negative rate stored as a float32 scalar tensor.
        name (str): ``"PG"`` followed by ``fnr`` expressed as a whole percentage,
            e.g. ``"PG20"`` for ``fnr=0.2``.

    Raises:
        ValueError: If ``fnr`` is not within ``[0, 1]``.

    Example:
        >>> from anomalib.metrics.pg_pb import _PGn
        >>> import torch
        >>> # Create sample data
        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> # Compute PGn score
        >>> pg = _PGn(fnr=0.2)
        >>> pg.update(preds, target)
        >>> pg.compute()
        tensor(1.)
    """

    def __init__(self, fnr: float = 0.05, **kwargs) -> None:
        super().__init__(**kwargs)
        # Written as a positive range test so that NaN is rejected as well.
        if not 0 <= fnr <= 1:
            msg = f"False negative rate must be in the range between 0 and 1, got {fnr}."
            raise ValueError(msg)

        self.fnr = torch.tensor(fnr, dtype=torch.float32)
        # Binary floats make e.g. 0.29 * 100 == 28.999999999999996, which a bare int()
        # would truncate to 28. Rounding away the representation noise first keeps the
        # truncating behaviour for genuinely fractional percentages (1.5% -> "PG1").
        self.name = "PG" + str(int(round(fnr * 100, 6)))

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with new values.

        Args:
            preds (torch.Tensor): predictions of the model
            target (torch.Tensor): ground truth targets
        """
        self.target.append(target)
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        """Compute the PGn score at a given false negative rate.

        Returns:
            torch.Tensor: PGn score value, with the dtype and device of the
            accumulated predictions.

        Raises:
            ValueError: If no positive samples are found (the threshold cannot
                be derived) or if no negative samples are found.
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        pos_scores = preds[target == 1]
        if pos_scores.numel() == 0:
            # torch.quantile fails on an empty input; report it like the negative case.
            msg = "No positive samples found. Cannot compute PGn score."
            raise ValueError(msg)

        neg_scores = preds[target == 0]
        if neg_scores.numel() == 0:
            msg = "No negative samples found. Cannot compute PGn score."
            raise ValueError(msg)

        # torch.quantile requires a tensor ``q`` to share the input's dtype and device,
        # while ``self.fnr`` is always created as a float32 tensor in ``__init__``.
        quantile = self.fnr.to(device=pos_scores.device, dtype=pos_scores.dtype)
        thr_accept = torch.quantile(pos_scores, quantile)

        pg = neg_scores[neg_scores < thr_accept].numel() / neg_scores.numel()

        return torch.tensor(pg, dtype=preds.dtype, device=preds.device)
133+
134+
135+
class PGn(AnomalibMetric, _PGn):  # type: ignore[misc]
    """Batch-aware variant of the ``_PGn`` metric.

    Combines ``AnomalibMetric`` with ``_PGn`` so that the metric can be used
    with Anomalib's batch processing. ``default_fields`` names the batch
    fields used by default: ``pred_score`` and ``gt_label``.
    """

    default_fields = ("pred_score", "gt_label")
143+
144+
145+
class _PBn(Metric):
    """Presorted bad (PBn) metric.

    Computes the true positive rate at a fixed false positive rate, treating
    bad (anomalous) samples as the positive class (``target == 1``) and good
    samples as the negative class (``target == 0``). The threshold is the
    ``(1 - fpr)``-quantile of the negative scores; the score is the fraction
    of positive samples whose prediction lies strictly above that threshold.

    Args:
        fpr (float): Fixed false positive rate (good parts misclassified), in the
            range ``[0, 1]``. Defaults to ``0.05``.
        **kwargs: Additional arguments passed to the parent ``Metric`` class.

    Attributes:
        fpr (torch.Tensor): The false positive rate stored as a float32 scalar tensor.
        name (str): ``"PB"`` followed by ``fpr`` expressed as a whole percentage,
            e.g. ``"PB20"`` for ``fpr=0.2``.

    Raises:
        ValueError: If ``fpr`` is not within ``[0, 1]``.

    Example:
        >>> from anomalib.metrics.pg_pb import _PBn
        >>> import torch
        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
        >>> target = torch.tensor([0, 0, 1, 1])
        >>> pb = _PBn(fpr=0.2)
        >>> pb.update(preds, target)
        >>> pb.compute()
        tensor(1.)
    """

    def __init__(self, fpr: float = 0.05, **kwargs) -> None:
        super().__init__(**kwargs)
        # Written as a positive range test so that NaN is rejected as well.
        if not 0 <= fpr <= 1:
            msg = f"False positive rate must be in the range between 0 and 1, got {fpr}."
            raise ValueError(msg)

        self.fpr = torch.tensor(fpr, dtype=torch.float32)
        # Binary floats make e.g. 0.29 * 100 == 28.999999999999996, which a bare int()
        # would truncate to 28. Rounding away the representation noise first keeps the
        # truncating behaviour for genuinely fractional percentages (1.5% -> "PB1").
        self.name = "PB" + str(int(round(fpr * 100, 6)))

        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with new values.

        Args:
            preds (torch.Tensor): predictions of the model
            target (torch.Tensor): ground truth targets
        """
        self.target.append(target)
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        """Compute the PBn score at a given false positive rate.

        Returns:
            torch.Tensor: PBn score value, with the dtype and device of the
            accumulated predictions.

        Raises:
            ValueError: If no negative samples are found (the threshold cannot
                be derived) or if no positive samples are found.
        """
        preds = dim_zero_cat(self.preds)
        target = dim_zero_cat(self.target)

        neg_scores = preds[target == 0]
        if neg_scores.numel() == 0:
            # torch.quantile fails on an empty input; report it like the positive case.
            msg = "No negative samples found. Cannot compute PBn score."
            raise ValueError(msg)

        pos_scores = preds[target == 1]
        if pos_scores.numel() == 0:
            msg = "No positive samples found. Cannot compute PBn score."
            raise ValueError(msg)

        # torch.quantile requires a tensor ``q`` to share the input's dtype and device,
        # while ``self.fpr`` is always created as a float32 tensor in ``__init__``.
        quantile = (1 - self.fpr).to(device=neg_scores.device, dtype=neg_scores.dtype)
        thr_accept = torch.quantile(neg_scores, quantile)

        pb = pos_scores[pos_scores > thr_accept].numel() / pos_scores.numel()

        return torch.tensor(pb, dtype=preds.dtype, device=preds.device)
210+
211+
212+
class PBn(AnomalibMetric, _PBn):  # type: ignore[misc]
    """Batch-aware variant of the ``_PBn`` metric.

    Combines ``AnomalibMetric`` with ``_PBn`` so that the metric can be used
    with Anomalib's batch processing. ``default_fields`` names the batch
    fields used by default: ``pred_score`` and ``gt_label``.
    """

    default_fields = ("pred_score", "gt_label")

tests/unit/metrics/test_pg_pb.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Test PGn and PBn metrics."""
5+
6+
import pytest
7+
import torch
8+
9+
from anomalib.metrics.pg_pb import _PBn as PBn
10+
from anomalib.metrics.pg_pb import _PGn as PGn
11+
12+
13+
def test_pg_basic() -> None:
    """Check the PG20 score and the metric name on a four-sample toy set."""
    scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
    targets = torch.tensor([0, 0, 1, 1])

    pg_metric = PGn(fnr=0.2)
    pg_metric.update(scores, targets)

    # Both good samples score below the threshold derived from the bad samples.
    assert pg_metric.compute() == torch.tensor(1.0)
    assert pg_metric.name == "PG20"
22+
23+
24+
def test_pb_basic() -> None:
    """Check the PB20 score and the metric name on a four-sample toy set."""
    scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
    targets = torch.tensor([0, 0, 1, 1])

    pb_metric = PBn(fpr=0.2)
    pb_metric.update(scores, targets)

    # Both bad samples score above the threshold derived from the good samples.
    assert pb_metric.compute() == torch.tensor(1.0)
    assert pb_metric.name == "PB20"
33+
34+
35+
def test_pg_invalid_fnr() -> None:
    """Check that PGn rejects false negative rates below 0 and above 1."""
    expected_message = "False negative rate must be in the range between 0 and 1"
    for out_of_range in (-0.1, 1.1):
        with pytest.raises(ValueError, match=expected_message):
            PGn(fnr=out_of_range)
41+
42+
43+
def test_pb_invalid_fpr() -> None:
    """Check that PBn rejects false positive rates below 0 and above 1."""
    expected_message = "False positive rate must be in the range between 0 and 1"
    for out_of_range in (-0.1, 1.1):
        with pytest.raises(ValueError, match=expected_message):
            PBn(fpr=out_of_range)
49+
50+
51+
def test_pg_no_negatives() -> None:
    """Check that PGn raises ValueError when every sample is positive."""
    scores = torch.tensor([0.5, 0.7])
    all_positive = torch.tensor([1, 1])

    pg_metric = PGn(fnr=0.1)
    pg_metric.update(scores, all_positive)

    with pytest.raises(ValueError, match="No negative samples found"):
        pg_metric.compute()
59+
60+
61+
def test_pb_no_positives() -> None:
    """Check that PBn raises ValueError when every sample is negative."""
    scores = torch.tensor([0.2, 0.3])
    all_negative = torch.tensor([0, 0])

    pb_metric = PBn(fpr=0.1)
    pb_metric.update(scores, all_negative)

    with pytest.raises(ValueError, match="No positive samples found"):
        pb_metric.compute()

0 commit comments

Comments
 (0)