Add stop train callback, torch filesystem #87
Conversation
Hi @Matvezy 👋🏻 Could you show how this feature is intended to be used? It seems that when you start training like this:

```python
from rfdetr import RFDETRBase

model = RFDETRBase()
model.train(dataset_dir=<DATASET_PATH>, epochs=10, batch_size=4, grad_accum_steps=4, lr=1e-4, output_dir=<OUTPUT_PATH>)
```

the training will run until the end, and calling …
rfdetr/main.py
Outdated
```
@@ -151,6 +159,17 @@ def train(self, callbacks: DefaultDict[str, List[Callable]], **kwargs):
        print(args)
        device = torch.device(args.device)

        # Initialize early stopping if enabled
        if args.early_stopping:
```
I think we should place the initialization of early stopping next to MetricsPlotSink and MetricsTensorBoardSink, in rfdetr/detr.py, to keep things consistent.
```python
metrics_plot_sink = MetricsPlotSink(output_dir=config.output_dir)
self.callbacks["on_fit_epoch_end"].append(metrics_plot_sink.update)
self.callbacks["on_train_end"].append(metrics_plot_sink.save)

metrics_tensor_board_sink = MetricsTensorBoardSink(output_dir=config.output_dir)
self.callbacks["on_fit_epoch_end"].append(metrics_tensor_board_sink.update)
self.callbacks["on_train_end"].append(metrics_tensor_board_sink.close)
```
@Matvezy is there any reason why we couldn't do that?
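An early-stopping callback could plug into this same sink-style registration pattern. Below is a hedged, self-contained sketch: the `EarlyStoppingCallback` class, its `patience` parameter, and the `"map"` metric key are illustrative stand-ins, not the PR's actual implementation.

```python
from collections import defaultdict

# Stand-in early-stopping callback; names and the metric key are
# illustrative assumptions, not the PR's actual class.
class EarlyStoppingCallback:
    def __init__(self, patience: int = 3):
        self.patience = patience
        self.counter = 0
        self.best_map = 0.0
        self.stop = False

    def update(self, metrics: dict) -> None:
        # Reset the patience counter on improvement, otherwise count toward the limit.
        current = metrics.get("map", 0.0)
        if current > self.best_map:
            self.best_map = current
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True

# Registration mirrors the MetricsPlotSink / MetricsTensorBoardSink pattern above.
callbacks = defaultdict(list)
early_stopping = EarlyStoppingCallback(patience=2)
callbacks["on_fit_epoch_end"].append(early_stopping.update)

# Simulated epochs: one improvement, then two stagnant epochs set the stop flag.
for metrics in [{"map": 0.30}, {"map": 0.28}, {"map": 0.29}]:
    for hook in callbacks["on_fit_epoch_end"]:
        hook(metrics)
```

The training loop would then check `early_stopping.stop` at the end of each epoch and break out when it is set.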
Got that change in there
Awesome! Let me test it and we should be good to merge this change
rfdetr/util/early_stopping.py
Outdated
```python
else:
    # No valid mAP metric found, skip early stopping check
    if self.verbose:
        print("Early stopping: No valid mAP metric found, skipping check")
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should probably just raise an exception
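Following this suggestion, the missing-metric branch could raise instead of silently skipping the check. A minimal sketch: the exception class, helper name, and the `"ema_map"`/`"map"` key names are illustrative assumptions, not the PR's actual result-dict layout.

```python
class EarlyStoppingMetricError(ValueError):
    """Raised when no usable mAP value is present in the epoch results."""

def extract_map(results: dict) -> float:
    # Prefer an EMA metric when present, fall back to the regular mAP;
    # both key names are illustrative assumptions.
    for key in ("ema_map", "map"):
        value = results.get(key)
        if value is not None:
            return float(value)
    raise EarlyStoppingMetricError("No valid mAP metric found in results")
```

Raising makes a misconfigured metrics pipeline fail loudly at the first epoch instead of quietly disabling early stopping for the whole run.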
rfdetr/util/early_stopping.py
Outdated
```diff
 self.best_map = current_map
 self.counter = 0
 if self.verbose:
-    print(f"Early stopping: mAP improved to {current_map:.4f}")
+    print(f"Early stopping: mAP improved to {current_map:.4f} using {metric_source} metric")
```
prints should be logging.logger
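Routing the messages through the standard `logging` module instead of `print` could look like this minimal sketch (the function name and message wording are illustrative, chosen to match the print above):

```python
import logging

logger = logging.getLogger(__name__)

def report_improvement(current_map: float, metric_source: str) -> str:
    # Build the message once so callers and tests can inspect exactly what was logged.
    message = f"Early stopping: mAP improved to {current_map:.4f} using {metric_source} metric"
    logger.info(message)
    return message
```

With a logger, downstream users can silence or redirect the early-stopping chatter via their own logging configuration rather than a `verbose` flag.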
let's not change default behavior
@probicheaux I set the default value of
Description
Added early stopping and an additional filesystem setting for PyTorch
Type of change
How has this change been tested? Please provide a test case or an example of how you tested the change.
Tested locally
Any specific deployment considerations
N/A
Docs