Commit

add kwargs
lufficc committed Jun 24, 2019
1 parent 711f3e3 commit 436331b
Showing 5 changed files with 26 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ssd/data/datasets/evaluation/__init__.py
@@ -3,7 +3,7 @@
 from .voc import voc_evaluation


-def evaluate(dataset, predictions, output_dir):
+def evaluate(dataset, predictions, output_dir, **kwargs):
     """evaluate dataset using different methods based on dataset type.
     Args:
         dataset: Dataset object
@@ -14,7 +14,7 @@ def evaluate(dataset, predictions, output_dir):
         evaluation result
     """
     args = dict(
-        dataset=dataset, predictions=predictions, output_dir=output_dir
+        dataset=dataset, predictions=predictions, output_dir=output_dir, **kwargs,
     )
     if isinstance(dataset, VOCDataset):
         return voc_evaluation(**args)
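
Note: the dispatcher above only collects the extra keyword arguments and forwards them unchanged to the dataset-specific evaluator. A minimal, self-contained sketch of that pattern follows; the stub names are hypothetical and not the repository's classes.

# Illustrative only; VOCDatasetStub and the evaluator stubs are made up for this sketch.
class VOCDatasetStub:
    pass

def voc_eval_stub(dataset, predictions, output_dir, iteration=None):
    # `iteration` reaches this point only because evaluate_stub forwards **kwargs.
    return 'voc -> {} (iteration={})'.format(output_dir, iteration)

def coco_eval_stub(dataset, predictions, output_dir, iteration=None):
    return 'coco -> {} (iteration={})'.format(output_dir, iteration)

def evaluate_stub(dataset, predictions, output_dir, **kwargs):
    # Collect everything, including forwarded keyword-only options such as `iteration`.
    args = dict(dataset=dataset, predictions=predictions, output_dir=output_dir, **kwargs)
    if isinstance(dataset, VOCDatasetStub):
        return voc_eval_stub(**args)
    return coco_eval_stub(**args)

print(evaluate_stub(VOCDatasetStub(), [], 'outputs/eval', iteration=120000))
# voc -> outputs/eval (iteration=120000)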
13 changes: 12 additions & 1 deletion ssd/data/datasets/evaluation/coco/__init__.py
@@ -1,9 +1,10 @@
 import json
 import logging
 import os
+from datetime import datetime


-def coco_evaluation(dataset, predictions, output_dir):
+def coco_evaluation(dataset, predictions, output_dir, iteration=None):
     coco_results = []
     for i, prediction in enumerate(predictions):
         img_info = dataset.get_img_info(i)
@@ -43,9 +44,19 @@ def coco_evaluation(dataset, predictions, output_dir):
     coco_eval.accumulate()
     coco_eval.summarize()

+    result_strings = []
     keys = ["AP", "AP50", "AP75", "APs", "APm", "APl"]
     metrics = {}
     for i, key in enumerate(keys):
         metrics[key] = coco_eval.stats[i]
         logger.info('{:<10}: {}'.format(key, round(coco_eval.stats[i], 3)))
+        result_strings.append('{:<10}: {}'.format(key, round(coco_eval.stats[i], 3)))
+
+    if iteration is not None:
+        result_path = os.path.join(output_dir, 'result_{:07d}.txt'.format(iteration))
+    else:
+        result_path = os.path.join(output_dir, 'result_{}.txt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
+    with open(result_path, "w") as f:
+        f.write('\n'.join(result_strings))
+
     return dict(metrics=metrics)
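
The two added branches decide how the summary file is named: a zero-padded iteration number when the evaluation was triggered from training, a timestamp otherwise. A quick sketch of just that naming rule; result_file_path is a hypothetical helper, not part of the codebase.

import os
from datetime import datetime

def result_file_path(output_dir, iteration=None):
    # Mirrors the two branches in the diff: iteration tag if given, else a timestamp.
    if iteration is not None:
        return os.path.join(output_dir, 'result_{:07d}.txt'.format(iteration))
    return os.path.join(output_dir, 'result_{}.txt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))

print(result_file_path('outputs/coco', iteration=50000))  # outputs/coco/result_0050000.txt
print(result_file_path('outputs/coco'))                   # e.g. outputs/coco/result_2019-06-24_18-30-05.txt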
9 changes: 7 additions & 2 deletions ssd/data/datasets/evaluation/voc/__init__.py
@@ -7,7 +7,7 @@
 from .eval_detection_voc import eval_detection_voc


-def voc_evaluation(dataset, predictions, output_dir):
+def voc_evaluation(dataset, predictions, output_dir, iteration=None):
     class_names = dataset.class_names

     pred_boxes_list = []
@@ -49,7 +49,12 @@ def voc_evaluation(dataset, predictions, output_dir):
         metrics[class_names[i]] = ap
         result_str += "{:<16}: {:.4f}\n".format(class_names[i], ap)
     logger.info(result_str)
-    result_path = os.path.join(output_dir, "result_{}.txt".format(datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
+
+    if iteration is not None:
+        result_path = os.path.join(output_dir, 'result_{:07d}.txt'.format(iteration))
+    else:
+        result_path = os.path.join(output_dir, 'result_{}.txt'.format(datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))
     with open(result_path, "w") as f:
         f.write(result_str)

     return dict(metrics=metrics)
8 changes: 4 additions & 4 deletions ssd/engine/inference.py
@@ -49,7 +49,7 @@ def compute_on_dataset(model, data_loader, device):
     return results_dict


-def inference(model, data_loader, dataset_name, device, output_folder=None, use_cached=False):
+def inference(model, data_loader, dataset_name, device, output_folder=None, use_cached=False, **kwargs):
     dataset = data_loader.dataset
     logger = logging.getLogger("SSD.inference")
     logger.info("Evaluating {} dataset({} images):".format(dataset_name, len(dataset)))
@@ -64,11 +64,11 @@ def inference(model, data_loader, dataset_name, device, output_folder=None, use_cached=False):
         return
     if output_folder:
         torch.save(predictions, predictions_path)
-    return evaluate(dataset=dataset, predictions=predictions, output_dir=output_folder)
+    return evaluate(dataset=dataset, predictions=predictions, output_dir=output_folder, **kwargs)


 @torch.no_grad()
-def do_evaluation(cfg, model, distributed):
+def do_evaluation(cfg, model, distributed, **kwargs):
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         model = model.module
     model.eval()
@@ -79,6 +79,6 @@ def do_evaluation(cfg, model, distributed):
         output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
         if not os.path.exists(output_folder):
             mkdir(output_folder)
-        eval_result = inference(model, data_loader, dataset_name, device, output_folder)
+        eval_result = inference(model, data_loader, dataset_name, device, output_folder, **kwargs)
         eval_results.append(eval_result)
     return eval_results
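
Note that inference() and do_evaluation() never inspect the forwarded options themselves; they only widen their signatures with **kwargs and pass everything through, so further evaluator options beyond iteration would not require touching these intermediate layers. A small stand-alone sketch of that pass-through idea, with made-up function names:

def evaluator(output_dir, iteration=None):
    # Only the innermost layer actually consumes `iteration`.
    return (output_dir, iteration)

def run_inference(output_dir, **kwargs):
    # Intermediate layer: forwards keyword arguments without naming them.
    return evaluator(output_dir, **kwargs)

def run_evaluation(output_dir, **kwargs):
    return run_inference(output_dir, **kwargs)

print(run_evaluation('outputs', iteration=2500))  # ('outputs', 2500)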
2 changes: 1 addition & 1 deletion ssd/engine/trainer.py
@@ -122,7 +122,7 @@ def do_train(cfg, model,
             checkpointer.save("model_{:06d}".format(iteration), **arguments)

         if args.eval_step > 0 and iteration % args.eval_step == 0 and not iteration == max_iter:
-            eval_results = do_evaluation(cfg, model, distributed=args.distributed)
+            eval_results = do_evaluation(cfg, model, distributed=args.distributed, iteration=iteration)
             if dist_util.get_rank() == 0 and summary_writer:
                 for eval_result, dataset in zip(eval_results, cfg.DATASETS.TEST):
                     write_metric(eval_result['metrics'], 'metrics/' + dataset, summary_writer, iteration)
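
Net effect: evaluations triggered every args.eval_step iterations now write result files tagged with the training step, while a stand-alone evaluation (no iteration passed) still falls back to a timestamped name. Purely for illustration, assuming an OUTPUT_DIR of 'outputs', a dataset named 'voc_2007_test', and eval_step = 2500, the periodic result files would be named as printed below.

# Illustration only; dataset name, OUTPUT_DIR, and eval_step value are assumed.
for step in (2500, 5000, 7500):
    print('outputs/inference/voc_2007_test/result_{:07d}.txt'.format(step))
# outputs/inference/voc_2007_test/result_0002500.txt
# outputs/inference/voc_2007_test/result_0005000.txt
# outputs/inference/voc_2007_test/result_0007500.txt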
