Skip to content

Commit

Permalink
added speedometer during evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
davsol committed Oct 7, 2017
1 parent 7b128f0 commit 0888d10
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
4 changes: 3 additions & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def parse_args():
help='use PASCAL VOC 07 metric')
parser.add_argument('--deploy', dest='deploy_net', help='Load network from model',
action='store_true', default=False)
parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
default=20, type=int)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -87,4 +89,4 @@ def parse_args():
path_imglist=args.list_path, nms_thresh=args.nms_thresh,
force_nms=args.force_nms, ovp_thresh=args.overlap_thresh,
use_difficult=args.use_difficult, class_names=class_names,
voc07_metric=args.use_voc07_metric)
voc07_metric=args.use_voc07_metric, frequent=args.frequent)
9 changes: 7 additions & 2 deletions evaluate/evaluate_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
model_prefix, epoch, ctx=mx.cpu(), batch_size=1,
path_imglist="", nms_thresh=0.45, force_nms=False,
ovp_thresh=0.5, use_difficult=False, class_names=None,
voc07_metric=False):
voc07_metric=False, frequent=20):
"""
evalute network given validation record file
Expand Down Expand Up @@ -51,6 +51,8 @@ def evaluate_net(net, path_imgrec, num_classes, mean_pixels, data_shape,
class names in string, must correspond to num_classes if set
voc07_metric : boolean
whether to use 11-point evluation as in VOC07 competition
frequent : int
frequency to print out validation status
"""
# set up logger
logging.basicConfig()
Expand Down Expand Up @@ -89,6 +91,9 @@ class names in string, must correspond to num_classes if set
metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
else:
metric = MApMetric(ovp_thresh, use_difficult, class_names)
results = mod.score(eval_iter, metric, num_batch=None)
results = mod.score(eval_iter, metric, num_batch=None,
batch_end_callback=mx.callback.Speedometer(batch_size,
frequent=frequent,
auto_reset=False))
for k, v in results:
print("{}: {}".format(k, v))

0 comments on commit 0888d10

Please sign in to comment.