diff --git a/main.py b/main.py index 52f45e0..2febe3a 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import numpy as np import torch import os +import pathlib import re import json import argparse @@ -266,7 +267,10 @@ def compute_metrics_rougel(eval_preds): ) if args.evaluate_dir is None: - trainer.train() + if list(pathlib.Path(save_dir).glob("checkpoint-*")) : + trainer.train(resume_from_checkpoint = True) + else : + trainer.train() trainer.save_model(save_dir) metrics = trainer.evaluate(eval_dataset = test_set, max_length=args.output_len)