@@ -413,18 +413,24 @@ def log_grads_and_vars(grads_and_vars):
 def train():
     exception_box = ExceptionBox()
 
+    if FLAGS.horovod:
+        import horovod.tensorflow as hvd
+
     # Create training and validation datasets
+    split_dataset = FLAGS.horovod
+
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
                                epochs=FLAGS.epochs,
                                augmentations=Config.augmentations,
                                cache_path=FLAGS.feature_cache,
                                train_phase=True,
                                exception_box=exception_box,
-                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
                                reverse=FLAGS.reverse_train,
                                limit=FLAGS.limit_train,
-                               buffering=FLAGS.read_buffer)
+                               buffering=FLAGS.read_buffer,
+                               split_dataset=split_dataset)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
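Note: the new split_dataset=FLAGS.horovod argument presumably tells create_dataset to hand each Horovod worker its own shard of the input pipeline; that helper's implementation is outside this diff. A minimal sketch of how such per-worker sharding is commonly done with tf.data, assuming the stock hvd.rank()/hvd.size() API (the maybe_shard helper below is illustrative, not code from this patch):

import tensorflow.compat.v1 as tfv1
import horovod.tensorflow as hvd

def maybe_shard(dataset, split_dataset):
    # When split_dataset is set, give each worker every hvd.size()-th element,
    # offset by its rank, so workers train on disjoint slices of the data.
    if split_dataset:
        dataset = dataset.shard(num_shards=hvd.size(), index=hvd.rank())
    return dataset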
@@ -439,10 +445,11 @@ def train():
                                    batch_size=FLAGS.dev_batch_size,
                                    train_phase=False,
                                    exception_box=exception_box,
-                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                   process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                    reverse=FLAGS.reverse_dev,
                                    limit=FLAGS.limit_dev,
-                                   buffering=FLAGS.read_buffer) for source in dev_sources]
+                                   buffering=FLAGS.read_buffer,
+                                   split_dataset=split_dataset) for source in dev_sources]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     if FLAGS.metrics_files:
@@ -451,10 +458,11 @@ def train():
                                        batch_size=FLAGS.dev_batch_size,
                                        train_phase=False,
                                        exception_box=exception_box,
-                                       process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                       process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                        reverse=FLAGS.reverse_dev,
                                        limit=FLAGS.limit_dev,
-                                       buffering=FLAGS.read_buffer) for source in metrics_sources]
+                                       buffering=FLAGS.read_buffer,
+                                       split_dataset=split_dataset) for source in metrics_sources]
         metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
 
     # Dropout
@@ -474,22 +482,38 @@ def train():
     # Building the graph
     learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
     reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
-    optimizer = create_optimizer(learning_rate_var)
+    if FLAGS.horovod:
+        # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
+        optimizer = create_optimizer(learning_rate_var * hvd.size())
+        optimizer = hvd.DistributedOptimizer(optimizer)
+    else:
+        optimizer = create_optimizer(learning_rate_var)
 
     # Enable mixed precision training
     if FLAGS.automatic_mixed_precision:
         log_info('Enabling automatic mixed precision training.')
         optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
 
-    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
+    if FLAGS.horovod:
+        loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
+        gradients = optimizer.compute_gradients(loss)
+
+        tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
+        log_grads_and_vars(gradients)
+
+        # global_step is automagically incremented by the optimizer
+        global_step = tfv1.train.get_or_create_global_step()
+        apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
+    else:
+        gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
 
-    # Average tower gradients across GPUs
-    avg_tower_gradients = average_gradients(gradients)
-    log_grads_and_vars(avg_tower_gradients)
+        # Average tower gradients across GPUs
+        avg_tower_gradients = average_gradients(gradients)
+        log_grads_and_vars(avg_tower_gradients)
 
-    # global_step is automagically incremented by the optimizer
-    global_step = tfv1.train.get_or_create_global_step()
-    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
+        # global_step is automagically incremented by the optimizer
+        global_step = tfv1.train.get_or_create_global_step()
+        apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
 
     # Summaries
     step_summaries_op = tfv1.summary.merge_all('step_summaries')
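For reference, the two Horovod calls introduced above follow the library's standard data-parallel recipe: scale the learning rate by the worker count (the global batch grows with hvd.size()) and wrap the optimizer so per-worker gradients are allreduce-averaged before they are applied. A self-contained sketch of that pattern, independent of this patch (the AdamOptimizer and base learning rate are placeholders, not the project's actual create_optimizer):

import tensorflow.compat.v1 as tfv1
import horovod.tensorflow as hvd

hvd.init()                       # one training process per GPU
base_learning_rate = 0.001       # placeholder value

# The effective batch size is per-GPU batch * hvd.size(); scaling the
# learning rate by the number of workers compensates for that increase.
optimizer = tfv1.train.AdamOptimizer(base_learning_rate * hvd.size())

# DistributedOptimizer wraps compute_gradients() so every worker's gradients
# are averaged with allreduce before apply_gradients() runs.
optimizer = hvd.DistributedOptimizer(optimizer)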
@@ -506,18 +530,22 @@ def train():
     }
 
     # Checkpointing
-    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
+    if Config.is_master_process:
+        checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
+        checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
 
-    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
+        best_dev_saver = tfv1.train.Saver(max_to_keep=1)
+        best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
 
-    # Save flags next to checkpoints
-    if not is_remote_path(FLAGS.save_checkpoint_dir):
-        os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
-    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
-    with open_remote(flags_file, 'w') as fout:
-        fout.write(FLAGS.flags_into_string())
+        # Save flags next to checkpoints
+        if not is_remote_path(FLAGS.save_checkpoint_dir):
+            os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+        flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
+        with open_remote(flags_file, 'w') as fout:
+            fout.write(FLAGS.flags_into_string())
+
+    if FLAGS.horovod:
+        bcast = hvd.broadcast_global_variables(0)
 
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
@@ -527,6 +555,8 @@ def train():
 
         # Load checkpoint or initialize variables
         load_or_init_graph_for_training(session)
+        if FLAGS.horovod:
+            bcast.run()
 
         def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
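The bcast.run() call added above is the second half of the Horovod bootstrap: rank 0 loads or initializes the variables, and the broadcast op then copies those values to every other worker so all replicas start from identical weights. A minimal standalone illustration of the same sequence, not this file's actual graph:

import tensorflow.compat.v1 as tfv1
import horovod.tensorflow as hvd

hvd.init()
weights = tfv1.get_variable('weights', shape=[10], initializer=tfv1.random_normal_initializer())
bcast = hvd.broadcast_global_variables(0)   # op that pushes rank 0's variables to all ranks

with tfv1.Session() as session:
    session.run(tfv1.global_variables_initializer())  # or restore from a checkpoint
    # Without this broadcast each rank would keep its own random initialization
    # and the allreduce-averaged gradients would not correspond to one model.
    session.run(bcast)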
@@ -554,12 +584,13 @@ def __call__(self, progress, data, **kwargs):
                     data['mean_loss'] = total_loss / step_count if step_count else 0.0
                     return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
 
-            prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
-            widgets = [' | ', progressbar.widgets.Timer(),
-                       ' | Steps: ', progressbar.widgets.Counter(),
-                       ' | ', LossWidget()]
-            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
-            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
+            if Config.is_master_process:
+                prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
+                widgets = [' | ', progressbar.widgets.Timer(),
+                           ' | Steps: ', progressbar.widgets.Counter(),
+                           ' | ', LossWidget()]
+                suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+                pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
 
             # Initialize iterator to the appropriate dataset
             session.run(init_op)
@@ -583,15 +614,17 @@ def __call__(self, progress, data, **kwargs):
                 total_loss += batch_loss
                 step_count += 1
 
-                pbar.update(step_count)
+                if Config.is_master_process:
+                    pbar.update(step_count)
 
-                step_summary_writer.add_summary(step_summary, current_step)
+                    step_summary_writer.add_summary(step_summary, current_step)
 
-                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
-                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
-                    checkpoint_time = time.time()
+                    if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
+                        checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
+                        checkpoint_time = time.time()
 
-            pbar.finish()
+            if Config.is_master_process:
+                pbar.finish()
 
             mean_loss = total_loss / step_count if step_count > 0 else 0.0
             return mean_loss, step_count
@@ -603,21 +636,25 @@ def __call__(self, progress, data, **kwargs):
         try:
             for epoch in range(FLAGS.epochs):
                 # Training
-                log_progress('Training epoch %d...' % epoch)
+                if Config.is_master_process:
+                    log_progress('Training epoch %d...' % epoch)
                 train_loss, _ = run_set('train', epoch, train_init_op)
-                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
-                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
+                if Config.is_master_process:
+                    log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
                 if FLAGS.dev_files:
                     # Validation
                     dev_loss = 0.0
                     total_steps = 0
                     for source, init_op in zip(dev_sources, dev_init_ops):
-                        log_progress('Validating epoch %d on %s...' % (epoch, source))
+                        if Config.is_master_process:
+                            log_progress('Validating epoch %d on %s...' % (epoch, source))
                         set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                         dev_loss += set_loss * steps
                         total_steps += steps
-                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+                        if Config.is_master_process:
+                            log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
                     dev_loss = dev_loss / total_steps
                     dev_losses.append(dev_loss)
@@ -629,16 +666,19 @@ def __call__(self, progress, data, **kwargs):
                     else:
                         epochs_without_improvement = 0
 
-                    # Save new best model
-                    if dev_loss < best_dev_loss:
-                        best_dev_loss = dev_loss
-                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
-                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
+                    if Config.is_master_process:
+                        # Save new best model
+                        if dev_loss < best_dev_loss:
+                            best_dev_loss = dev_loss
+                            save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                            latest_filename='best_dev_checkpoint')
+                            log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
 
                     # Early stopping
                     if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
-                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
-                            epochs_without_improvement))
+                        if Config.is_master_process:
+                            log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
+                                epochs_without_improvement))
                         break
 
                     # Reduce learning rate on plateau
@@ -655,26 +695,31 @@ def __call__(self, progress, data, **kwargs):
                     # Reduce learning rate
                     session.run(reduce_learning_rate_op)
                     current_learning_rate = learning_rate_var.eval()
-                    log_info('Encountered a plateau, reducing learning rate to {}'.format(
-                        current_learning_rate))
+                    if Config.is_master_process:
+                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
+                            current_learning_rate))
 
-                    # Overwrite best checkpoint with new learning rate value
-                    save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
-                    log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
+                        # Overwrite best checkpoint with new learning rate value
+                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                        latest_filename='best_dev_checkpoint')
+                        log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
 
                 if FLAGS.metrics_files:
                     # Read only metrics, not affecting best validation loss tracking
                     for source, init_op in zip(metrics_sources, metrics_init_ops):
-                        log_progress('Metrics for epoch %d on %s...' % (epoch, source))
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s...' % (epoch, source))
                         set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
-                        log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
                 print('-' * 80)
 
 
         except KeyboardInterrupt:
             pass
-        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
+        if Config.is_master_process:
+            log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
         log_debug('Session closed.')
 
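All of the logging and checkpointing above is gated on Config.is_master_process, which is set up outside this diff. Under Horovod the conventional choice is rank 0, so exactly one process writes checkpoints, summaries, and progress output while the remaining workers only contribute gradients. A hypothetical sketch of that flag (the helper below is illustrative; the patch's real Config code is not shown here):

import horovod.tensorflow as hvd

def compute_is_master_process(use_horovod):
    # With Horovod, only rank 0 acts as the "master" that logs and saves;
    # in single-process training the one process is trivially the master.
    if use_horovod:
        return hvd.rank() == 0
    return True

With this layout, a multi-GPU run would typically be launched as `horovodrun -np <num_gpus> python <training script> --horovod ...`, where `--horovod` is the flag this patch reads as FLAGS.horovod (the exact entry-point script name is not shown in this diff).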
@@ -951,30 +996,32 @@ def main(_):
     if FLAGS.train_files:
         tfv1.reset_default_graph()
         tfv1.set_random_seed(FLAGS.random_seed)
+
         train()
 
-    if FLAGS.test_files:
-        tfv1.reset_default_graph()
-        test()
+    if Config.is_master_process:
+        if FLAGS.test_files:
+            tfv1.reset_default_graph()
+            test()
 
-    if FLAGS.export_dir and not FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        export()
+        if FLAGS.export_dir and not FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            export()
 
-    if FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        FLAGS.export_tflite = True
+        if FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            FLAGS.export_tflite = True
 
-        if listdir_remote(FLAGS.export_dir):
-            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
-            sys.exit(1)
+            if listdir_remote(FLAGS.export_dir):
+                log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
+                sys.exit(1)
 
-        export()
-        package_zip()
+            export()
+            package_zip()
 
-    if FLAGS.one_shot_infer:
-        tfv1.reset_default_graph()
-        do_single_file_inference(FLAGS.one_shot_infer)
+        if FLAGS.one_shot_infer:
+            tfv1.reset_default_graph()
+            do_single_file_inference(FLAGS.one_shot_infer)
 
 
 def run_script():