config to use cpu for TriviaQA
ibeltagy committed Apr 29, 2020
1 parent 74523f7 commit 89e3980
Showing 1 changed file with 7 additions and 6 deletions.
scripts/triviaqa.py: 13 changes (7 additions, 6 deletions)
@@ -612,7 +612,8 @@ def add_model_specific_args(parser, root_dir):
 parser.add_argument("--train_dataset", type=str, required=False, help="Path to the training squad-format")
 parser.add_argument("--dev_dataset", type=str, required=True, help="Path to the dev squad-format")
 parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
-parser.add_argument("--gpus", type=str, default='0', help="Comma separated list of gpus")
+parser.add_argument("--gpus", type=str, default='0',
+                    help="Comma separated list of gpus. Default is gpu 0. To use CPU, use --gpus "" ")
 parser.add_argument("--warmup", type=int, default=200, help="Number of warmup steps")
 parser.add_argument("--lr", type=float, default=0.0001, help="Maximum learning rate")
 parser.add_argument("--val_every", type=float, default=0.2, help="Number of training steps between validations")
@@ -673,14 +674,14 @@ def main(args):
     prefix=''
 )

-args.gpus = [int(x) for x in args.gpus.split(',')]
+args.gpus = [int(x) for x in args.gpus.split(',')] if args.gpus is not "" else None  # use CPU if no gpu provided
 print(args)
 train_set_size = 110648  # hardcode dataset size. Needed to compute number of steps for the lr scheduler
-args.steps = args.epochs * train_set_size / (args.batch_size * len(args.gpus))
-print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * len(args.gpus)} <<<<<<<')
+num_devices = 1 or len(args.gpus)
+args.steps = args.epochs * train_set_size / (args.batch_size * num_devices)
+print(f'>>>>>>> #steps: {args.steps}, #epochs: {args.epochs}, batch_size: {args.batch_size * num_devices} <<<<<<<')

-trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if len(args.gpus) > 1 else None,
+trainer = pl.Trainer(gpus=args.gpus, distributed_backend='ddp' if args.gpus and (len(args.gpus) > 1) else None,
 track_grad_norm=-1, max_nb_epochs=args.epochs, early_stop_callback=None,
 accumulate_grad_batches=args.batch_size,
 val_check_interval=args.val_every,
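
For context, the sketch below reproduces the device-selection behaviour the diff aims for: an empty --gpus string maps to None (so PyTorch Lightning falls back to CPU), while a comma-separated list selects GPU ids and enables 'ddp' only when more than one id is given. It is a standalone illustration, not code from scripts/triviaqa.py; the helper name device_config is made up, and the empty-string check is written with != rather than the is not comparison used in the commit.

    # Minimal sketch; mirrors the intent of the diff above, not copied from scripts/triviaqa.py.
    import argparse


    def device_config(gpus_arg):
        """Map the --gpus string to (device list or None, distributed backend or None)."""
        gpus = [int(x) for x in gpus_arg.split(',')] if gpus_arg != "" else None  # empty string -> CPU
        backend = 'ddp' if gpus and len(gpus) > 1 else None  # ddp only for multi-GPU runs
        return gpus, backend


    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", type=str, default='0',
                        help="Comma separated list of gpus. Pass an empty string to run on CPU.")

    for argv in ([], ["--gpus", ""], ["--gpus", "0,1"]):
        args = parser.parse_args(argv)
        print(argv, '->', device_config(args.gpus))
    # []                -> ([0], None)       default: GPU 0, no distributed backend
    # ['--gpus', '']    -> (None, None)      CPU
    # ['--gpus', '0,1'] -> ([0, 1], 'ddp')   two GPUs with ddp

In terms of invocation, this corresponds to running the script with --gpus "" for a CPU run, or with a list such as --gpus 0,1 for two GPUs and the ddp backend.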
