Skip to content

Add stop train callback, torch filesystem #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class TrainConfig(BaseModel):
use_ema: bool = True
num_workers: int = 2
weight_decay: float = 1e-4
early_stopping: bool = False
early_stopping_patience: int = 10
early_stopping_min_delta: float = 0.001
early_stopping_use_ema: bool = False
tensorboard: bool = True
wandb: bool = False
project: Optional[str] = None
Expand Down
10 changes: 10 additions & 0 deletions rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ def train_from_config(self, config: TrainConfig, **kwargs):
self.callbacks["on_fit_epoch_end"].append(metrics_wandb_sink.update)
self.callbacks["on_train_end"].append(metrics_wandb_sink.close)

if config.early_stopping:
from rfdetr.util.early_stopping import EarlyStoppingCallback
early_stopping_callback = EarlyStoppingCallback(
model=self.model,
patience=config.early_stopping_patience,
min_delta=config.early_stopping_min_delta,
use_ema=config.early_stopping_use_ema
)
self.callbacks["on_fit_epoch_end"].append(early_stopping_callback.update)

self.model.train(
**all_kwargs,
callbacks=self.callbacks,
Expand Down
32 changes: 31 additions & 1 deletion rfdetr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
import shutil
from rfdetr.util.files import download_file
import os
if str(os.environ.get("USE_FILE_SYSTEM_SHARING", "False")).lower() in ["true", "1"]:
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

logger = getLogger(__name__)

Expand Down Expand Up @@ -133,10 +136,15 @@ def __init__(self, **kwargs):
self.model.backbone[0].encoder = get_peft_model(self.model.backbone[0].encoder, lora_config)
self.model = self.model.to(self.device)
self.criterion, self.postprocessors = build_criterion_and_postprocessors(args)
self.stop_early = False

def reinitialize_detection_head(self, num_classes):
    """Forward the request to the wrapped model's ``reinitialize_detection_head``.

    Args:
        num_classes: Passed through unchanged to the wrapped model.
    """
    wrapped_model = self.model
    wrapped_model.reinitialize_detection_head(num_classes)

def request_early_stop(self):
    """Set the ``stop_early`` flag and announce the pending stop on stdout."""
    notice = "Early stopping requested, will complete current epoch and stop"
    self.stop_early = True
    print(notice)

def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
currently_supported_callbacks = ["on_fit_epoch_end", "on_train_batch_start", "on_train_end"]
for key in callbacks.keys():
Expand All @@ -150,7 +158,7 @@ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
print("git:\n {}\n".format(utils.get_sha()))
print(args)
device = torch.device(args.device)

# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
Expand Down Expand Up @@ -394,6 +402,10 @@ def lr_lambda(current_step: int):
for callback in callbacks["on_fit_epoch_end"]:
callback(log_stats)

if self.stop_early:
print(f"Early stopping requested, stopping at epoch {epoch}")
break

best_is_ema = best_map_ema_5095 > best_map_5095
if best_is_ema:
shutil.copy2(output_dir / 'checkpoint_best_ema.pth', output_dir / 'checkpoint_best_total.pth')
Expand Down Expand Up @@ -736,6 +748,15 @@ def get_args_parser():
)
parser.add_argument('--lr_min_factor', default=0.0, type=float,
help='Minimum learning rate factor (as a fraction of initial lr) at the end of cosine annealing')
# Early stopping parameters
parser.add_argument('--early_stopping', action='store_true',
help='Enable early stopping based on mAP improvement')
parser.add_argument('--early_stopping_patience', default=10, type=int,
help='Number of epochs with no improvement after which training will be stopped')
parser.add_argument('--early_stopping_min_delta', default=0.001, type=float,
help='Minimum change in mAP to qualify as an improvement')
parser.add_argument('--early_stopping_use_ema', action='store_true',
help='Use EMA model metrics for early stopping')
# subparsers
subparsers = parser.add_subparsers(title='sub-commands', dest='subcommand',
description='valid subcommands', help='additional help')
Expand Down Expand Up @@ -866,6 +887,11 @@ def populate_args(
warmup_epochs=1,
lr_scheduler='step',
lr_min_factor=0.0,
# Early stopping parameters
early_stopping=True,
early_stopping_patience=10,
early_stopping_min_delta=0.001,
early_stopping_use_ema=False,
gradient_checkpointing=False,
# Additional
subcommand=None,
Expand Down Expand Up @@ -961,6 +987,10 @@ def populate_args(
warmup_epochs=warmup_epochs,
lr_scheduler=lr_scheduler,
lr_min_factor=lr_min_factor,
early_stopping=early_stopping,
early_stopping_patience=early_stopping_patience,
early_stopping_min_delta=early_stopping_min_delta,
early_stopping_use_ema=early_stopping_use_ema,
gradient_checkpointing=gradient_checkpointing,
**extra_kwargs
)
Expand Down
75 changes: 75 additions & 0 deletions rfdetr/util/early_stopping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Early stopping callback for RF-DETR training
"""

from logging import getLogger

logger = getLogger(__name__)


class EarlyStoppingCallback:
    """
    Early stopping callback that monitors mAP and requests a stop when no
    improvement larger than ``min_delta`` is observed for ``patience``
    consecutive updates.

    Args:
        model: Object notified through ``request_early_stop()`` once patience
            is exhausted. May be ``None``, in which case nothing is notified.
        patience (int): Number of epochs with no improvement to wait before stopping
        min_delta (float): Minimum change in mAP to qualify as improvement
        use_ema (bool): When both regular and EMA metrics are reported, track
            the EMA metric only instead of the better of the two
        verbose (bool): Whether to print per-epoch early stopping messages.
            Also decides whether a missing metric raises (see ``update``)
    """

    REGULAR_KEY = 'test_coco_eval_bbox'
    EMA_KEY = 'ema_test_coco_eval_bbox'

    def __init__(self, model, patience=5, min_delta=0.001, use_ema=False, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.use_ema = use_ema
        self.verbose = verbose
        self.best_map = 0.0
        self.counter = 0
        self.model = model

    @staticmethod
    def _first_metric(log_stats, key):
        """Return the first entry stored under ``key``, or ``None`` if absent."""
        stats = log_stats.get(key)
        return None if stats is None else stats[0]

    def _select_metric(self, log_stats):
        """Pick the mAP to track; returns ``(value, source)`` or ``(None, None)``."""
        regular_map = self._first_metric(log_stats, self.REGULAR_KEY)
        ema_map = self._first_metric(log_stats, self.EMA_KEY)

        if regular_map is not None and ema_map is not None:
            if self.use_ema:
                return ema_map, "EMA"
            return max(regular_map, ema_map), "max(regular, EMA)"
        if ema_map is not None:
            return ema_map, "EMA"
        if regular_map is not None:
            return regular_map, "regular"
        return None, None

    def update(self, log_stats):
        """Update early stopping state based on epoch validation metrics.

        Args:
            log_stats: Mapping of epoch statistics. The first element of
                ``test_coco_eval_bbox`` and/or ``ema_test_coco_eval_bbox``
                is taken as the mAP to track.

        Raises:
            ValueError: If neither metric is present and ``verbose`` is true.
        """
        current_map, metric_source = self._select_metric(log_stats)

        if current_map is None:
            # NOTE(review): raising is tied to `verbose` for backward
            # compatibility with existing callers; consider decoupling it.
            if self.verbose:
                raise ValueError("No valid mAP metric found!")
            # Previously this path skipped the epoch without any trace.
            logger.warning("Early stopping: no mAP metric found in epoch stats; skipping update")
            return

        if self.verbose:
            print(
                f"Early stopping: Current mAP ({metric_source}): {current_map:.4f}, "
                f"Best: {self.best_map:.4f}, Diff: {current_map - self.best_map:.4f}, "
                f"Min delta: {self.min_delta}"
            )

        if current_map > self.best_map + self.min_delta:
            self.best_map = current_map
            self.counter = 0
            logger.info("Early stopping: mAP improved to %.4f using %s metric", current_map, metric_source)
        else:
            self.counter += 1
            if self.verbose:
                print(
                    f"Early stopping: No improvement in mAP for {self.counter} epochs "
                    f"(best: {self.best_map:.4f}, current: {current_map:.4f})"
                )

        if self.counter >= self.patience:
            print(f"Early stopping triggered: No improvement above {self.min_delta} threshold for {self.patience} epochs")
            # Identity check rather than truthiness: a model object that
            # defines __len__/__bool__ may be falsy yet must still be notified.
            if self.model is not None:
                self.model.request_early_stop()
21 changes: 1 addition & 20 deletions rfdetr/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,26 +425,7 @@ def save_on_master(obj, f, *args, **kwargs):
Safely save objects, removing any callbacks that can't be pickled
"""
if is_main_process():
try:
if isinstance(obj, dict):
obj_copy = {}
for k, v in obj.items():
if k == 'args' and hasattr(v, '__dict__'):
args_dict = copy.copy(v.__dict__)
if 'callbacks' in args_dict:
del args_dict['callbacks']
obj_copy[k] = argparse.Namespace(**args_dict)
elif k != 'callbacks':
obj_copy[k] = v
obj = obj_copy

torch.save(obj, f, *args, **kwargs)
except Exception as e:
print(f"Error in safe_save_on_master: {e}")
if isinstance(obj, dict) and 'model' in obj:
print("Falling back to saving only model state_dict")
torch.save({'model': obj['model']}, f, *args, **kwargs)

torch.save(obj, f, *args, **kwargs)

def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
Expand Down