Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/solver/det_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@
from .det_engine import train_one_epoch, evaluate


def prepare_eval_metric(metric, catId, type):
# type precision: (iou, recall, cls, area range, max dets)
# type recall: (iou, cls, area range, max dets)
if type == "precision":
metric = [metric[i][j][catId][0][-1] for i in range(len(metric)) for j in range(len(metric[i]))]
elif type == "recall":
metric = [metric[i][catId][0][-1] for i in range(len(metric))]

# Filter out values <= -1
filtered_metric = [value for value in metric if value > -1]

# Calculate mean or return NaN if empty
if filtered_metric:
return sum(filtered_metric) / len(filtered_metric)
else:
return float("nan")


class DetSolver(BaseSolver):

def fit(self, ):
Expand Down Expand Up @@ -98,12 +116,28 @@ def fit(self, ):
self.device
)

coco_eval = coco_evaluator.coco_eval["bbox"]
precisions = coco_evaluator.coco_eval["bbox"].eval['precision']
recalls = coco_evaluator.coco_eval["bbox"].eval['recall']

class_results = {}
for category_id in coco_eval.cocoGt.getCatIds():
category_info = coco_eval.cocoGt.loadCats([category_id])[0]
category_name = category_info['name']
ap = prepare_eval_metric(precisions, category_id, "precision")
ar = prepare_eval_metric(recalls, category_id, "recall")
class_results[category_id] = {'name': category_name, 'mAP': ap, 'mAR': ar}

# TODO
for k in test_stats:
if self.writer and dist_utils.is_main_process():
for i, v in enumerate(test_stats[k]):
self.writer.add_scalar(f'Test/{k}_{i}'.format(k), v, epoch)

for class_id in class_results.keys():
self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAP', class_results[class_id]["mAP"], epoch)
self.writer.add_scalar(f'Test/class_{class_results[class_id]["name"]}_mAR', class_results[class_id]["mAR"], epoch)

if k in best_stat:
best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
best_stat[k] = max(best_stat[k], test_stats[k][0])
Expand Down