diff --git a/src/optim/early_stopping.py b/src/optim/early_stopping.py new file mode 100644 index 00000000..6508ea88 --- /dev/null +++ b/src/optim/early_stopping.py @@ -0,0 +1,20 @@ +class EarlyStopping: + def __init__(self, patience=5): + self.patience = patience + self.counter = 0 + self.best_loss = float('inf') + self.early_stop = False + + def __call__(self, val_loss): + if val_loss < self.best_loss: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + self.early_stop = True + + def reset(self): + self.counter = 0 + self.best_loss = float('inf') + self.early_stop = False diff --git a/src/solver/clas_solver.py b/src/solver/clas_solver.py index aefb3d0a..d6004962 100644 --- a/src/solver/clas_solver.py +++ b/src/solver/clas_solver.py @@ -14,9 +14,14 @@ from ..misc import dist_utils from ._solver import BaseSolver from .clas_engine import evaluate, train_one_epoch +from ..optim.early_stopping import EarlyStopping class ClasSolver(BaseSolver): + def __init__(self, cfg): + super().__init__(cfg) + self.early_stopping = EarlyStopping(patience=cfg.early_stopping_patience) + def fit( self, ): @@ -32,6 +37,7 @@ def fit( start_time = time.time() start_epoch = self.last_epoch + 1 + self.early_stopping.reset() for epoch in range(start_epoch, args.epochs): if dist_utils.is_dist_available_and_initialized(): self.train_dataloader.sampler.set_epoch(epoch) @@ -59,6 +65,13 @@ def fit( module = self.ema.module if self.ema else self.model test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device) + # Pass validation loss to EarlyStopping + val_loss = test_stats["loss"] + self.early_stopping(val_loss) + if self.early_stopping.early_stop: + print(f"Early stopping at epoch {epoch}") + break + log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"test_{k}": v for k, v in test_stats.items()}, diff --git a/src/solver/det_engine.py b/src/solver/det_engine.py index a35b2edf..0c0a6f5f 100644 
--- a/src/solver/det_engine.py +++ b/src/solver/det_engine.py @@ -238,4 +238,15 @@ def evaluate( if "segm" in iou_types: stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() - return stats, coco_evaluator + # Log the computed metrics + if use_wandb: + wandb.log({ + "Precision": metrics["precision"], + "Recall": metrics["recall"], + "mAP@0.5": metrics["mAP@0.5"], + "mAP@0.5-0.95": metrics["mAP@0.5-0.95"], + "epoch": epoch, + }) + + return stats, coco_evaluator, metrics + diff --git a/src/solver/det_solver.py b/src/solver/det_solver.py index 08e8c19a..e4de6dc5 100644 --- a/src/solver/det_solver.py +++ b/src/solver/det_solver.py @@ -15,9 +15,14 @@ from ..misc import dist_utils, stats from ._solver import BaseSolver from .det_engine import evaluate, train_one_epoch +from ..optim.early_stopping import EarlyStopping class DetSolver(BaseSolver): + def __init__(self, cfg): + super().__init__(cfg) + self.early_stopping = EarlyStopping(patience=cfg.early_stopping_patience) + def fit(self): self.train() args = self.cfg @@ -59,6 +64,7 @@ def fit(self): best_stat_print = best_stat.copy() start_time = time.time() start_epoch = self.last_epoch + 1 + self.early_stopping.reset() for epoch in range(start_epoch, args.epochs): self.train_dataloader.set_epoch(epoch) # self.train_dataloader.dataset.set_epoch(epoch) @@ -111,6 +117,13 @@ def fit(self): self.use_wandb, ) + # coco_eval_bbox[0] is mAP@0.5-0.95 (higher is better); negate it so EarlyStopping's minimize logic counts rising mAP as improvement + val_loss = -test_stats["coco_eval_bbox"][0] + self.early_stopping(val_loss) + if self.early_stopping.early_stop: + print(f"Early stopping at epoch {epoch}") + break + # TODO for k in test_stats: if self.writer and dist_utils.is_main_process(): diff --git a/src/solver/validator.py b/src/solver/validator.py index e38308ce..5f431fe0 100644 --- a/src/solver/validator.py +++ b/src/solver/validator.py @@ -70,17 +70,65 @@ def _compute_main_metrics(self, preds): recall = tps / (tps + fns) if (tps + fns) > 0 else 0 f1 = 2 * (precision * recall) / (precision + recall) 
if (precision + recall) > 0 else 0 iou = np.mean(ious).item() if ious else 0 + + # Compute mAP@0.5 and mAP@0.5-0.95 + mAP_50 = self._compute_mAP(preds, iou_thresholds=[0.5]) + mAP_50_95 = self._compute_mAP(preds, iou_thresholds=np.arange(0.5, 1.0, 0.05)) + return { "f1": f1, "precision": precision, "recall": recall, "iou": iou, + "mAP@0.5": mAP_50, + "mAP@0.5-0.95": mAP_50_95, "TPs": tps, "FPs": fps, "FNs": fns, "extended_metrics": extended_metrics, } + def _compute_mAP(self, preds, iou_thresholds): + ap_per_class = defaultdict(list) + for pred, gt in zip(preds, self.gt): + pred_boxes = pred["boxes"] + pred_labels = pred["labels"] + gt_boxes = gt["boxes"] + gt_labels = gt["labels"] + + for iou_thresh in iou_thresholds: + ious = box_iou(pred_boxes, gt_boxes) + ious_mask = ious >= iou_thresh + + pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True) + if not pred_indices.numel(): + continue + + iou_values = ious[pred_indices, gt_indices] + sorted_indices = torch.argsort(-iou_values) + pred_indices = pred_indices[sorted_indices] + gt_indices = gt_indices[sorted_indices] + + matched_preds = set() + matched_gts = set() + for pred_idx, gt_idx in zip(pred_indices, gt_indices): + if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds: + matched_preds.add(pred_idx.item()) + matched_gts.add(gt_idx.item()) + pred_label = pred_labels[pred_idx].item() + ap_per_class[pred_label].append(1) + + unmatched_preds = set(range(len(pred_boxes))) - matched_preds + for pred_idx in unmatched_preds: + pred_label = pred_labels[pred_idx].item() + ap_per_class[pred_label].append(0) + + aps = [] + for label, ap_list in ap_per_class.items(): + if ap_list: + aps.append(np.mean(ap_list)) + return np.mean(aps) if aps else 0 + def _compute_matrix_multi_class(self, preds): metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []}) for pred, gt in zip(preds, self.gt): diff --git a/train.py b/train.py index d4aa4518..450fc681 100644 --- a/train.py 
+++ b/train.py @@ -57,7 +57,7 @@ def main(args) -> None: } ) - cfg = YAMLConfig(args.config, **update_dict) + cfg = YAMLConfig(args.config, **update_dict, early_stopping_patience=args.early_stopping_patience) if args.resume or args.tuning: if "HGNetv2" in cfg.yaml_cfg: @@ -99,6 +99,7 @@ def main(args) -> None: action="store_true", default=False, ) + parser.add_argument("--early_stopping_patience", type=int, default=5, help="patience for early stopping") # priority 1 parser.add_argument("-u", "--update", nargs="+", help="update yaml config")