Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 104 additions & 7 deletions dygraph/se_resnet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,56 @@
import math
import argparse
import ast
import time
import os
import six

parser = argparse.ArgumentParser("Training for Se-ResNeXt.")
parser.add_argument("-e", "--epoch", default=200, type=int, help="set epoch")
parser.add_argument("-b", "--batch_size", default=64, type=int, help="set batch")
parser.add_argument("--ce", action="store_true", help="run ce")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='default use gpu.')
parser.add_argument(
"--max_iter",
default=0,
type=int,
help="the max iters to train, used in benchmark")


def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("------------- Configuration Arguments -------------")
for arg, value in sorted(six.iteritems(vars(args))):
print("%25s : %s" % (arg, value))
print("----------------------------------------------------")


args = parser.parse_args()
batch_size = 64
batch_size = args.batch_size

if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
print_arguments(args)


train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
Expand All @@ -59,6 +97,24 @@
l2_decay = 1.2e-4


class TimeCostAverage(object):
def __init__(self):
self.reset()

def reset(self):
self.cnt = 0
self.total_time = 0

def record(self, usetime):
self.cnt += 1
self.total_time += usetime

def get_average(self):
if self.cnt == 0:
return 0
return self.total_time / self.cnt


def optimizer_setting(params, parameter_list):
ls = params["learning_strategy"]
if "total_images" not in params:
Expand Down Expand Up @@ -365,19 +421,27 @@ def train():
batch_size = train_parameters["batch_size"]

trainer_count = fluid.dygraph.parallel.Env().nranks
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
if not args.use_gpu:
place = fluid.CPUPlace()
elif not args.use_data_parallel:
place = fluid.CUDAPlace(0)
else:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)

with fluid.dygraph.guard(place):
if args.ce:
print("ce mode")
seed = 90
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed

if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()

se_resnext = SeResNeXt()
optimizer = optimizer_setting(train_parameters, se_resnext.parameters())

if args.use_data_parallel:
se_resnext = fluid.dygraph.parallel.DataParallel(se_resnext,
strategy)
Expand All @@ -398,12 +462,31 @@ def train():
test_loader = fluid.io.DataLoader.from_generator(capacity=10)
test_loader.set_sample_list_generator(test_reader, places=place)

#NOTE: used in benchmark
total_batch_num = 0

for epoch_id in range(epoch_num):
epoch_start = time.time()
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0

train_batch_cost_avg = TimeCostAverage()
train_reader_cost_avg = TimeCostAverage()
batch_start = time.time()

localTime = time.localtime(epoch_start)
strTime = time.strftime("%Y-%m-%d %H:%M:%S", localTime)
print("[Epoch %d, start %s]" % (epoch_id, strTime))

for batch_id, data in enumerate(train_loader()):
#NOTE: used in benchmark
if args.max_iter and total_batch_num == args.max_iter:
return

train_reader_cost = time.time() - batch_start

img, label = data
label.stop_gradient = True

Expand Down Expand Up @@ -434,18 +517,32 @@ def train():
total_acc1 += acc_top1.numpy()
total_acc5 += acc_top5.numpy()
total_sample += 1

train_batch_cost = time.time() - batch_start
train_batch_cost_avg.record(train_batch_cost)
train_reader_cost_avg.record(train_reader_cost)

total_batch_num = total_batch_num + 1 #this is for benchmark
if batch_id % 10 == 0:
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f lr %0.5f" % \
ips = float(
args.batch_size) / train_batch_cost_avg.get_average()
print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f lr %0.5f, batch_cost: %.5f sec, reader_cost: %.5f sec, ips: %.5f images/sec" % \
( epoch_id, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample, lr))
total_acc1 / total_sample, total_acc5 / total_sample, lr,train_batch_cost_avg.get_average(),
train_reader_cost_avg.get_average(), ips))
train_batch_cost_avg.reset()
train_reader_cost_avg.reset()
batch_start = time.time()

if args.ce:
print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample))
print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \

train_epoch_cost = time.time() - epoch_start
print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f, epoch_cost: %.5f s" % \
(epoch_id, batch_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
total_acc1 / total_sample, total_acc5 / total_sample,train_epoch_cost))
se_resnext.eval()
eval(se_resnext, test_loader)
se_resnext.train()
Expand Down