@@ -170,7 +170,7 @@ def _predictions_matches_labels(
170170 return labels == predicted_label
171171
172172 def _should_keep_prediction (
173- self , predicted_scores : List [OutputScore ], actual_label : str
173+ self , predicted_scores : List [OutputScore ], actual_label : OutputScore
174174 ) -> bool :
175175 # filter by class
176176 if len (self ._config .classes ) != 0 :
@@ -180,13 +180,14 @@ def _should_keep_prediction(
180180 return False
181181
182182 # filter by accuracy
183+ label_name = actual_label .label
183184 if self ._config .prediction == "all" :
184185 pass
185186 elif self ._config .prediction == "correct" :
186- if not self ._predictions_matches_labels (predicted_scores , actual_label ):
187+ if not self ._predictions_matches_labels (predicted_scores , label_name ):
187188 return False
188189 elif self ._config .prediction == "incorrect" :
189- if self ._predictions_matches_labels (predicted_scores , actual_label ):
190+ if self ._predictions_matches_labels (predicted_scores , label_name ):
190191 return False
191192 else :
192193 raise Exception (f"Invalid prediction config: { self ._config .prediction } " )
@@ -238,16 +239,16 @@ def _calculate_vis_output(
238239 predicted = predicted .cpu ().squeeze (0 )
239240
240241 if label is not None and len (label ) > 0 :
241- actual_label = OutputScore (
242- score = 0 , index = label [0 ], label = self .classes [label [0 ]]
242+ actual_label_output = OutputScore (
243+ score = 100 , index = label [0 ], label = self .classes [label [0 ]]
243244 )
244245 else :
245- actual_label = None
246+ actual_label_output = None
246247
247248 predicted_scores = self ._get_labels_from_scores (scores , predicted )
248249
249250 # Filter based on UI configuration
250- if not self ._should_keep_prediction (predicted_scores , actual_label ):
251+ if not self ._should_keep_prediction (predicted_scores , actual_label_output ):
251252 return None
252253
253254 baselines = [tuple (b ) for b in baselines ]
@@ -277,9 +278,9 @@ def _calculate_vis_output(
277278
278279 return VisualizationOutput (
279280 feature_outputs = features_per_input ,
280- actual = actual_label ,
281+ actual = actual_label_output ,
281282 predicted = predicted_scores ,
282- active_index = target if target is not None else actual_label .index ,
283+ active_index = target if target is not None else actual_label_output .index ,
283284 )
284285
285286 def _get_outputs (self ) -> List [VisualizationOutput ]:
0 commit comments