1
1
from abc import abstractmethod
2
2
from dataclasses import dataclass
3
- from typing import List , Set , Tuple , Union
3
+ from typing import List , Optional , Set , Tuple , Union
4
4
5
5
from sklearn .metrics import f1_score
6
6
7
7
from nucleus .annotation import AnnotationList , CategoryAnnotation
8
8
from nucleus .metrics .base import Metric , MetricResult , ScalarResult
9
+ from nucleus .metrics .filtering import ListOfAndFilters , ListOfOrAndFilters
9
10
from nucleus .metrics .filters import confidence_filter
10
11
from nucleus .prediction import CategoryPrediction , PredictionList
11
12
@@ -56,12 +57,37 @@ class CategorizationMetric(Metric):
56
57
def __init__ (
57
58
self ,
58
59
confidence_threshold : float = 0.0 ,
60
+ annotation_filters : Optional [
61
+ Union [ListOfOrAndFilters , ListOfAndFilters ]
62
+ ] = None ,
63
+ prediction_filters : Optional [
64
+ Union [ListOfOrAndFilters , ListOfAndFilters ]
65
+ ] = None ,
59
66
):
60
67
"""Initializes CategorizationMetric abstract object.
61
68
62
69
Args:
63
70
confidence_threshold: minimum confidence threshold for predictions to be taken into account for evaluation. Must be in [0, 1]. Default 0.0
71
+ annotation_filters: Filter predicates. Allowed formats are:
72
+ ListOfAndFilters where each Filter forms a chain of AND predicates.
73
+ or
74
+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
75
+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
76
+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
77
+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
78
+ (AND), forming a more selective `and` multiple field predicate.
79
+ Finally, the outermost list combines these filters as a disjunction (OR).
80
+ prediction_filters: Filter predicates. Allowed formats are:
81
+ ListOfAndFilters where each Filter forms a chain of AND predicates.
82
+ or
83
+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
84
+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
85
+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
86
+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
87
+ (AND), forming a more selective `and` multiple field predicate.
88
+ Finally, the outermost list combines these filters as a disjunction (OR).
64
89
"""
90
+ super ().__init__ (annotation_filters , prediction_filters )
65
91
assert 0 <= confidence_threshold <= 1
66
92
self .confidence_threshold = confidence_threshold
67
93
@@ -83,7 +109,7 @@ def eval(
83
109
def aggregate_score (self , results : List [CategorizationResult ]) -> ScalarResult : # type: ignore[override]
84
110
pass
85
111
86
- def __call__ (
112
+ def call_metric (
87
113
self , annotations : AnnotationList , predictions : PredictionList
88
114
) -> CategorizationResult :
89
115
if self .confidence_threshold > 0 :
@@ -139,7 +165,15 @@ class CategorizationF1(CategorizationMetric):
139
165
"""Evaluation method that matches categories and returns a CategorizationF1Result that aggregates to the F1 score"""
140
166
141
167
def __init__ (
142
- self , confidence_threshold : float = 0.0 , f1_method : str = "macro"
168
+ self ,
169
+ confidence_threshold : float = 0.0 ,
170
+ f1_method : str = "macro" ,
171
+ annotation_filters : Optional [
172
+ Union [ListOfOrAndFilters , ListOfAndFilters ]
173
+ ] = None ,
174
+ prediction_filters : Optional [
175
+ Union [ListOfOrAndFilters , ListOfAndFilters ]
176
+ ] = None ,
143
177
):
144
178
"""
145
179
Args:
@@ -169,8 +203,28 @@ def __init__(
169
203
Calculate metrics for each instance, and find their average (only
170
204
meaningful for multilabel classification where this differs from
171
205
:func:`accuracy_score`).
206
+ annotation_filters: Filter predicates. Allowed formats are:
207
+ ListOfAndFilters where each Filter forms a chain of AND predicates.
208
+ or
209
+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
210
+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
211
+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
212
+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
213
+ (AND), forming a more selective `and` multiple field predicate.
214
+ Finally, the outermost list combines these filters as a disjunction (OR).
215
+ prediction_filters: Filter predicates. Allowed formats are:
216
+ ListOfAndFilters where each Filter forms a chain of AND predicates.
217
+ or
218
+ ListOfOrAndFilters where Filters are expressed in disjunctive normal form (DNF), like
219
+ [[MetadataFilter("short_haired", "==", True), FieldFilter("label", "in", ["cat", "dog"])], ...].
220
+ DNF allows arbitrary boolean logical combinations of single field predicates. The innermost structures
221
+ each describe a single column predicate. The list of inner predicates is interpreted as a conjunction
222
+ (AND), forming a more selective `and` multiple field predicate.
223
+ Finally, the outermost list combines these filters as a disjunction (OR).
172
224
"""
173
- super ().__init__ (confidence_threshold )
225
+ super ().__init__ (
226
+ confidence_threshold , annotation_filters , prediction_filters
227
+ )
174
228
assert (
175
229
f1_method in F1_METHODS
176
230
), f"Invalid f1_method { f1_method } , expected one of { F1_METHODS } "
0 commit comments