Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/optim/early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss;
    after ``patience`` consecutive epochs without improvement the
    ``early_stop`` flag is set to ``True``. Use :meth:`reset` to reuse
    the instance for a new training run.
    """

    def __init__(self, patience=5, min_delta=0.0):
        """
        Args:
            patience (int): number of consecutive non-improving epochs to
                tolerate before setting ``early_stop``.
            min_delta (float): minimum decrease in loss that counts as an
                improvement. The default of 0.0 preserves the original
                strictly-less-than behavior.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive epochs without improvement
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update the stop flag.

        Args:
            val_loss (float): validation loss for the epoch just finished.
        """
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember the new best and restart the window.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

    def reset(self):
        """Restore the initial state so the instance can be reused."""
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
13 changes: 13 additions & 0 deletions src/solver/clas_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
from ..misc import dist_utils
from ._solver import BaseSolver
from .clas_engine import evaluate, train_one_epoch
from ..optim.early_stopping import EarlyStopping


class ClasSolver(BaseSolver):
    def __init__(self, cfg):
        """Initialize the classification solver.

        Args:
            cfg: solver configuration; must expose ``early_stopping_patience``
                (populated from the ``--early_stopping_patience`` CLI flag
                via ``YAMLConfig`` — TODO confirm it is always set).
        """
        super().__init__(cfg)
        # Validation-loss-based early stopping; reset at the start of fit().
        self.early_stopping = EarlyStopping(patience=cfg.early_stopping_patience)

def fit(
self,
):
Expand All @@ -32,6 +37,7 @@ def fit(

start_time = time.time()
start_epoch = self.last_epoch + 1
self.early_stopping.reset()
for epoch in range(start_epoch, args.epochs):
if dist_utils.is_dist_available_and_initialized():
self.train_dataloader.sampler.set_epoch(epoch)
Expand Down Expand Up @@ -59,6 +65,13 @@ def fit(
module = self.ema.module if self.ema else self.model
test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device)

# Pass validation loss to EarlyStopping
val_loss = test_stats["loss"]
self.early_stopping(val_loss)
if self.early_stopping.early_stop:
print(f"Early stopping at epoch {epoch}")
break

log_stats = {
**{f"train_{k}": v for k, v in train_stats.items()},
**{f"test_{k}": v for k, v in test_stats.items()},
Expand Down
13 changes: 12 additions & 1 deletion src/solver/det_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,15 @@ def evaluate(
if "segm" in iou_types:
stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()

return stats, coco_evaluator
#Log the computed metrics
if use_wandb:
wandb.log({
"Precision": metrics["precision"],
"Recall": metrics["recall"],
"mAP@0.5": metrics["mAP@0.5"],
"mAP@0.5-0.95": metrics["mAP@0.5-0.95"],
"epoch": epoch,
})

return stats, coco_evaluator, metrics

13 changes: 13 additions & 0 deletions src/solver/det_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@
from ..misc import dist_utils, stats
from ._solver import BaseSolver
from .det_engine import evaluate, train_one_epoch
from ..optim.early_stopping import EarlyStopping


class DetSolver(BaseSolver):
    def __init__(self, cfg):
        """Initialize the detection solver.

        Args:
            cfg: solver configuration; must expose ``early_stopping_patience``
                (populated from the ``--early_stopping_patience`` CLI flag
                via ``YAMLConfig`` — TODO confirm it is always set).
        """
        super().__init__(cfg)
        # Early stopping driven by coco_eval_bbox[0] (mAP) in fit(); note the
        # EarlyStopping class treats lower values as better — NOTE(review):
        # feeding it a mAP score (higher is better) looks inverted; verify.
        self.early_stopping = EarlyStopping(patience=cfg.early_stopping_patience)

def fit(self):
self.train()
args = self.cfg
Expand Down Expand Up @@ -59,6 +64,7 @@ def fit(self):
best_stat_print = best_stat.copy()
start_time = time.time()
start_epoch = self.last_epoch + 1
self.early_stopping.reset()
for epoch in range(start_epoch, args.epochs):
self.train_dataloader.set_epoch(epoch)
# self.train_dataloader.dataset.set_epoch(epoch)
Expand Down Expand Up @@ -111,6 +117,13 @@ def fit(self):
self.use_wandb,
)

# Pass validation loss to EarlyStopping
val_loss = test_stats["coco_eval_bbox"][0]
self.early_stopping(val_loss)
if self.early_stopping.early_stop:
print(f"Early stopping at epoch {epoch}")
break

# TODO
for k in test_stats:
if self.writer and dist_utils.is_main_process():
Expand Down
48 changes: 48 additions & 0 deletions src/solver/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,65 @@ def _compute_main_metrics(self, preds):
recall = tps / (tps + fns) if (tps + fns) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
iou = np.mean(ious).item() if ious else 0

# Compute mAP@0.5 and mAP@0.5-0.95
mAP_50 = self._compute_mAP(preds, iou_threshold=0.5)
mAP_50_95 = self._compute_mAP(preds, iou_thresholds=np.arange(0.5, 1.0, 0.05))

return {
"f1": f1,
"precision": precision,
"recall": recall,
"iou": iou,
"mAP@0.5": mAP_50,
"mAP@0.5-0.95": mAP_50_95,
"TPs": tps,
"FPs": fps,
"FNs": fns,
"extended_metrics": extended_metrics,
}

def _compute_mAP(self, preds, iou_thresholds):
ap_per_class = defaultdict(list)
for pred, gt in zip(preds, self.gt):
pred_boxes = pred["boxes"]
pred_labels = pred["labels"]
gt_boxes = gt["boxes"]
gt_labels = gt["labels"]

for iou_thresh in iou_thresholds:
ious = box_iou(pred_boxes, gt_boxes)
ious_mask = ious >= iou_thresh

pred_indices, gt_indices = torch.nonzero(ious_mask, as_tuple=True)
if not pred_indices.numel():
continue

iou_values = ious[pred_indices, gt_indices]
sorted_indices = torch.argsort(-iou_values)
pred_indices = pred_indices[sorted_indices]
gt_indices = gt_indices[sorted_indices]

matched_preds = set()
matched_gts = set()
for pred_idx, gt_idx in zip(pred_indices, gt_indices):
if gt_idx.item() not in matched_gts and pred_idx.item() not in matched_preds:
matched_preds.add(pred_idx.item())
matched_gts.add(gt_idx.item())
pred_label = pred_labels[pred_idx].item()
ap_per_class[pred_label].append(1)

unmatched_preds = set(range(len(pred_boxes))) - matched_preds
for pred_idx in unmatched_preds:
pred_label = pred_labels[pred_idx].item()
ap_per_class[pred_label].append(0)

aps = []
for label, ap_list in ap_per_class.items():
if ap_list:
aps.append(np.mean(ap_list))
return np.mean(aps) if aps else 0

def _compute_matrix_multi_class(self, preds):
metrics_per_class = defaultdict(lambda: {"TPs": 0, "FPs": 0, "FNs": 0, "IoUs": []})
for pred, gt in zip(preds, self.gt):
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def main(args) -> None:
}
)

cfg = YAMLConfig(args.config, **update_dict)
cfg = YAMLConfig(args.config, **update_dict, early_stopping_patience=args.early_stopping_patience)

if args.resume or args.tuning:
if "HGNetv2" in cfg.yaml_cfg:
Expand Down Expand Up @@ -99,6 +99,7 @@ def main(args) -> None:
action="store_true",
default=False,
)
parser.add_argument("--early_stopping_patience", type=int, default=5, help="patience for early stopping")

# priority 1
parser.add_argument("-u", "--update", nargs="+", help="update yaml config")
Expand Down