def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)

    # instantiate optimizer
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    # build training loader
    train_loader = instantiate(cfg.dataloader.train)

    # create ddp model
    model = create_ddp_model(model, **cfg.train.ddp)

    # build model ema
    ema.may_build_model_ema(cfg, model)

    trainer = Trainer(
        model=model,
        dataloader=train_loader,
        optimizer=optim,
        amp=cfg.train.amp.enabled,
        clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
    )

    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
        # save model ema
        **ema.may_get_ema_checkpointer(cfg, model),
    )

    if comm.is_main_process():
        # writers = default_writers(cfg.train.output_dir, cfg.train.max_iter)
        output_dir = cfg.train.output_dir
        PathManager.mkdirs(output_dir)
        writers = [
            CommonMetricPrinter(cfg.train.max_iter),
            JSONWriter(os.path.join(output_dir, "metrics.json")),
            TensorboardXWriter(output_dir),
        ]
        if cfg.train.wandb.enabled:
            PathManager.mkdirs(cfg.train.wandb.params.dir)
            writers.append(WandbWriter(cfg))

    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            ema.EMAHook(cfg, model) if cfg.train.model_ema.enabled else None,
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                writers,
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )

    # register the validation loss hook
    val_loss = ValidationLoss(cfg)
    trainer.register_hooks([val_loss])
    # swap the order of PeriodicWriter and ValidationLoss so that the validation
    # scalars written in after_step are flushed by the writers in the same iteration
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]

    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
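For reference, the modified do_train plugs into a standard LazyConfig entry point. A rough sketch is below; it follows detectron2's lazyconfig_train_net.py pattern rather than this repository's exact script, so treat the argument handling as an assumption:

# Illustrative entry point sketch (assumes detectron2's LazyConfig tooling;
# adapt to this repo's actual tools/train_net.py).
from detectron2.config import LazyConfig
from detectron2.engine import default_argument_parser, default_setup, launch

def main(args):
    cfg = LazyConfig.load(args.config_file)
    cfg = LazyConfig.apply_overrides(cfg, args.opts)
    default_setup(cfg, args)
    if args.eval_only:
        # evaluation-only path omitted here for brevity
        pass
    else:
        do_train(args, cfg)

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )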
import copy

import torch
from detectron2.config import instantiate
from detectron2.engine import HookBase


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = copy.deepcopy(cfg)
        # Reuse the test dataloader, but build it with a training mapper so the
        # batches contain ground-truth instances (required to compute losses).
        # self.cfg.dataloader.test.keys() -> dict_keys(['dataset', 'mapper', 'num_workers', '_target_'])
        self.cfg.dataloader.test.mapper.is_train = True
        self.data_loader = instantiate(self.cfg.dataloader.test)

    def after_step(self):
        # Evaluate every quarter of eval_period (guard against a zero period).
        period = max(1, self.cfg.train.eval_period // 4)
        if self.trainer.iter % period == 0:
            self._compute_validation_loss()

    def _compute_validation_loss(self):
        assert self.trainer.model.training, "[Trainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # If you want to do something with the data, you can wrap the dataloader.
        total_loss = 0
        num_batches = 0
        loss_sums = {}  # accumulate each loss component over the validation set
        with torch.no_grad():
            for data in self.data_loader:
                # data contains Instances because mapper.is_train is True
                # If you want to do something with the losses, you can wrap the model.
                with autocast(enabled=self.trainer.amp):
                    loss_dict = self.trainer.model(data)
                    if isinstance(loss_dict, torch.Tensor):
                        losses = loss_dict
                        loss_dict = {"total_loss": loss_dict}
                    else:
                        losses = sum(loss_dict.values())
                # accumulate each loss separately
                for key, loss in loss_dict.items():
                    loss_sums[key] = loss_sums.get(key, 0) + loss.item()
                total_loss += losses.item()
                num_batches += 1

        # compute dataset-wide average losses
        avg_losses = {key: value / num_batches for key, value in loss_sums.items()} if num_batches > 0 else {}
        # store scalar values
        for key, avg_loss in avg_losses.items():
            self.trainer.storage.put_scalar(f"validation_{key}", avg_loss)
        avg_val_loss = total_loss / num_batches if num_batches > 0 else 0
        self.trainer.storage.put_scalar("validation_loss", avg_val_loss)
        print(f"Validation Loss Breakdown: {avg_losses}")  # optional logging
Dear Authors,
Thank you for maintaining an excellent repository and keeping it intuitive for learners.
I wanted to ask whether there is a better way to compute the validation loss during training.
Taken inspiration from: facebookresearch/detectron2#810, "How do I compute validation loss during training?" (comment).
Please add validation loss evaluation as a feature.
Thank you