How to evaluate validation loss during training. #373

Open

Kunaldargan opened this issue Feb 2, 2025 · 0 comments

Kunaldargan commented Feb 2, 2025

Dear Authors,
Thank you for maintaining an excellent repository and keeping it intuitive for learners.

I wanted to confirm whether there is a better way to compute validation loss during training than the approach below.

How do I compute validation loss during training?

The code below takes inspiration from facebookresearch/detectron2#810 (comment).


def do_train(args, cfg):
    """
    Args:
        cfg: an object with the following attributes:
            model: instantiate to a module
            dataloader.{train,test}: instantiate to dataloaders
            dataloader.evaluator: instantiate to evaluator for test set
            optimizer: instantiate to an optimizer
            lr_multiplier: instantiate to a fvcore scheduler
            train: other misc config defined in `configs/common/train.py`, including:
                output_dir (str)
                init_checkpoint (str)
                amp.enabled (bool)
                max_iter (int)
                eval_period, log_period (int)
                device (str)
                checkpointer (dict)
                ddp (dict)
    """
    model = instantiate(cfg.model)
    logger = logging.getLogger("detectron2")
    logger.info("Model:\n{}".format(model))
    model.to(cfg.train.device)
    
    # instantiate optimizer
    cfg.optimizer.params.model = model
    optim = instantiate(cfg.optimizer)

    # build training loader
    train_loader = instantiate(cfg.dataloader.train)
    
    # create ddp model
    model = create_ddp_model(model, **cfg.train.ddp)

    # build model ema
    ema.may_build_model_ema(cfg, model)

    trainer = Trainer(
        model=model,
        dataloader=train_loader,
        optimizer=optim,
        amp=cfg.train.amp.enabled,
        clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None,
    )
    
    checkpointer = DetectionCheckpointer(
        model,
        cfg.train.output_dir,
        trainer=trainer,
        # save model ema
        **ema.may_get_ema_checkpointer(cfg, model)
    )

    if comm.is_main_process():
        # writers = default_writers(cfg.train.output_dir, cfg.train.max_iter)
        output_dir = cfg.train.output_dir
        PathManager.mkdirs(output_dir)
        writers = [
            CommonMetricPrinter(cfg.train.max_iter),
            JSONWriter(os.path.join(output_dir, "metrics.json")),
            TensorboardXWriter(output_dir),
        ]
        if cfg.train.wandb.enabled:
            PathManager.mkdirs(cfg.train.wandb.params.dir)
            writers.append(WandbWriter(cfg))

    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            ema.EMAHook(cfg, model) if cfg.train.model_ema.enabled else None,
            hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer)
            if comm.is_main_process()
            else None,
            hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)),
            hooks.PeriodicWriter(
                writers,
                period=cfg.train.log_period,
            )
            if comm.is_main_process()
            else None,
        ]
    )
    

    # register the validation-loss hook
    val_loss = ValidationLoss(cfg)
    trainer.register_hooks([val_loss])
    # swap the order of PeriodicWriter and ValidationLoss so that the validation
    # scalars are already in EventStorage when the writer runs for that iteration
    trainer._hooks = trainer._hooks[:-2] + trainer._hooks[-2:][::-1]
    
    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
    if args.resume and checkpointer.has_checkpoint():
        # The checkpoint stores the training iteration that just finished, thus we start
        # at the next iteration
        start_iter = trainer.iter + 1
    else:
        start_iter = 0
    trainer.train(start_iter, cfg.train.max_iter)
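
A small simplification is possible here (a sketch, not part of the original code): since detectron2 runs hooks in registration order and recommends keeping writers last, the ValidationLoss hook can be placed directly in the register_hooks list just before the PeriodicWriter, which avoids patching trainer._hooks by hand. The list below is abbreviated; the EMA, checkpointer, and EvalHook entries from above stay where they are. The ValidationLoss hook itself is defined next.

trainer.register_hooks(
    [
        hooks.IterationTimer(),
        hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)),
        # ... EMA / checkpointer / EvalHook entries as above ...
        ValidationLoss(cfg),  # runs before the writer, so its scalars land in the same log step
        hooks.PeriodicWriter(writers, period=cfg.train.log_period)
        if comm.is_main_process()
        else None,
    ]
)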
import copy

import torch
from torch.cuda.amp import autocast

from detectron2.config import instantiate
from detectron2.engine import HookBase


class ValidationLoss(HookBase):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = copy.deepcopy(cfg)
        # Reuse the test dataloader config, but run its mapper in training mode so that
        # the batches contain ground-truth instances (needed to compute losses).
        # self.cfg.dataloader.test.keys(): dict_keys(['dataset', 'mapper', 'num_workers', '_target_'])
        self.cfg.dataloader.test.mapper.is_train = True
        self.data_loader = instantiate(self.cfg.dataloader.test)

    def after_step(self):
        # Evaluate four times per eval_period
        period = max(self.cfg.train.eval_period // 4, 1)
        if self.trainer.iter % period == 0:
            self._compute_validation_loss()

    def _compute_validation_loss(self):
        # The model must stay in training mode so that it returns a loss dict
        # instead of predictions.
        assert self.trainer.model.training, "[ValidationLoss] model was changed to eval mode!"
        if self.trainer.amp:
            assert torch.cuda.is_available(), "[ValidationLoss] CUDA is required for AMP!"

        total_loss = 0.0
        num_batches = 0
        loss_sums = {}  # accumulate each loss component over the validation set

        with torch.no_grad():
            for data in self.data_loader:
                # `data` contains instances because the mapper runs with is_train=True
                with autocast(enabled=self.trainer.amp):
                    loss_dict = self.trainer.model(data)
                    if isinstance(loss_dict, torch.Tensor):
                        losses = loss_dict
                        loss_dict = {"total_loss": loss_dict}
                    else:
                        losses = sum(loss_dict.values())

                # accumulate each loss component separately
                for key, loss in loss_dict.items():
                    loss_sums[key] = loss_sums.get(key, 0.0) + loss.item()

                total_loss += losses.item()
                num_batches += 1

        # dataset-wide average of each loss component
        avg_losses = {key: value / num_batches for key, value in loss_sums.items()} if num_batches > 0 else {}
        for key, avg_loss in avg_losses.items():
            self.trainer.storage.put_scalar(f"validation_{key}", avg_loss)

        avg_val_loss = total_loss / num_batches if num_batches > 0 else 0.0
        self.trainer.storage.put_scalar("validation_loss", avg_val_loss)

        print(f"Validation Loss Breakdown: {avg_losses}")  # optional logging

Please add validation loss evaluation as a feature.

Thank you
