diff --git a/src/solver/det_solver.py b/src/solver/det_solver.py index b52b9224..656f09a3 100644 --- a/src/solver/det_solver.py +++ b/src/solver/det_solver.py @@ -18,6 +18,24 @@ from .det_engine import train_one_epoch, evaluate +def prepare_eval_metric(metric, catId, type): + # type precision: (iou, recall, cls, area range, max dets) + # type recall: (iou, cls, area range, max dets) + if type == "precision": + metric = [metric[i][j][catId][0][-1] for i in range(len(metric)) for j in range(len(metric[i]))] + elif type == "recall": + metric = [metric[i][catId][0][-1] for i in range(len(metric))] + + # Filter out values <= -1 + filtered_metric = [value for value in metric if value > -1] + + # Calculate mean or return NaN if empty + if filtered_metric: + return sum(filtered_metric) / len(filtered_metric) + else: + return float("nan") + + class DetSolver(BaseSolver): def fit(self, ): @@ -98,12 +116,28 @@ def fit(self, ): self.device ) + coco_eval = coco_evaluator.coco_eval["bbox"] + precisions = coco_evaluator.coco_eval["bbox"].eval['precision'] + recalls = coco_evaluator.coco_eval["bbox"].eval['recall'] + + class_results = {} + for category_id in coco_eval.cocoGt.getCatIds(): + category_info = coco_eval.cocoGt.loadCats([category_id])[0] + category_name = category_info['name'] + ap = prepare_eval_metric(precisions, category_id, "precision") + ar = prepare_eval_metric(recalls, category_id, "recall") + class_results[category_id] = {'name': category_name, 'mAP': ap, 'mAR': ar} + # TODO for k in test_stats: if self.writer and dist_utils.is_main_process(): for i, v in enumerate(test_stats[k]): self.writer.add_scalar(f'Test/{k}_{i}'.format(k), v, epoch) + for class_id in class_results.keys(): + self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAP', class_results[class_id]["mAP"], epoch) + self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAR', class_results[class_id]["mAR"], epoch) + if k in best_stat: best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch'] best_stat[k] = max(best_stat[k], test_stats[k][0])