
Commit bde9a1d

Merge pull request #3533 from tud-zih-tools/horovod_pr
Implement distributed training using horovod
2 parents 1d7a554 + 329bf87 commit bde9a1d

6 files changed: +197 -87 lines changed

doc/TRAINING.rst

+21
@@ -196,6 +196,27 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil
 
 On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%.
 
+Distributed training using Horovod
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you have a capable compute architecture, it is possible to distribute the training across multiple machines and GPUs using `Horovod <https://github.com/horovod/horovod>`_. A fast network between the machines is recommended.
+Horovod can use MPI and NVIDIA's NCCL for highly optimized inter-process communication.
+It also offers `Gloo <https://github.com/facebookincubator/gloo>`_ as an easy-to-set-up communication backend.
+
+For more information about setting up or tuning Horovod, please visit `Horovod's documentation <https://horovod.readthedocs.io/en/stable/summary_include.html>`_.
+
+Horovod is expected to run on heterogeneous systems (e.g. a different number and model of GPUs per machine).
+However, this can cause unpredictable problems and would require user interaction in the training code.
+Therefore, we only support homogeneous systems, meaning the same hardware and the same software configuration (OS, drivers, MPI, NCCL, TensorFlow, ...) on every machine.
+The only exception is a different number of GPUs per machine, since this can be controlled via ``horovodrun -H``.
+
+Detailed documentation on how to run Horovod is provided `here <https://horovod.readthedocs.io/en/stable/running.html>`_.
+The short command to train on 4 machines using 4 GPUs each:
+
+.. code-block:: bash
+
+    horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod
+
 Checkpointing
 ^^^^^^^^^^^^^
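The documentation above mentions Gloo as an easy-to-set-up backend; if MPI is not available on the machines, at the time of this commit the same ``horovodrun`` command can be launched with the ``--gloo`` flag to select the Gloo controller instead (see the running documentation linked above).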

setup.py

+10
@@ -76,6 +76,10 @@ def main():
         'tensorflow == 1.15.4'
     ]
 
+    horovod_pypi_dep = [
+        'horovod[tensorflow] == 0.21.3'
+    ]
+
     # Due to pip craziness environment variables are the only consistent way to
     # get options into this script when doing `pip install`.
     tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '')
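With this change, the ``horovod[tensorflow]`` requirement is only pulled in when ``DS_WITH_HOROVOD`` is set to a non-empty value in the environment at install time, following the same environment-variable pattern the script already uses for the decoder and TensorFlow dependencies; an assumed invocation would be ``DS_WITH_HOROVOD=y pip3 install --upgrade -e .`` from the repository root.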

training/deepspeech_training/train.py

+120 -73
@@ -413,18 +413,24 @@ def log_grads_and_vars(grads_and_vars):
 def train():
     exception_box = ExceptionBox()
 
+    if FLAGS.horovod:
+        import horovod.tensorflow as hvd
+
     # Create training and validation datasets
+    split_dataset = FLAGS.horovod
+
     train_set = create_dataset(FLAGS.train_files.split(','),
                                batch_size=FLAGS.train_batch_size,
                                epochs=FLAGS.epochs,
                                augmentations=Config.augmentations,
                                cache_path=FLAGS.feature_cache,
                                train_phase=True,
                                exception_box=exception_box,
-                               process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
+                               process_ahead=Config.num_devices * FLAGS.train_batch_size * 2,
                                reverse=FLAGS.reverse_train,
                                limit=FLAGS.limit_train,
-                               buffering=FLAGS.read_buffer)
+                               buffering=FLAGS.read_buffer,
+                               split_dataset=split_dataset)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -439,10 +445,11 @@ def train():
                                    batch_size=FLAGS.dev_batch_size,
                                    train_phase=False,
                                    exception_box=exception_box,
-                                   process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                   process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                    reverse=FLAGS.reverse_dev,
                                    limit=FLAGS.limit_dev,
-                                   buffering=FLAGS.read_buffer) for source in dev_sources]
+                                   buffering=FLAGS.read_buffer,
+                                   split_dataset=split_dataset) for source in dev_sources]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     if FLAGS.metrics_files:
@@ -451,10 +458,11 @@ def train():
                                        batch_size=FLAGS.dev_batch_size,
                                        train_phase=False,
                                        exception_box=exception_box,
-                                       process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
+                                       process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2,
                                        reverse=FLAGS.reverse_dev,
                                        limit=FLAGS.limit_dev,
-                                       buffering=FLAGS.read_buffer) for source in metrics_sources]
+                                       buffering=FLAGS.read_buffer,
+                                       split_dataset=split_dataset) for source in metrics_sources]
         metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets]
 
     # Dropout
@@ -474,22 +482,38 @@ def train():
     # Building the graph
     learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False)
     reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction))
-    optimizer = create_optimizer(learning_rate_var)
+    if FLAGS.horovod:
+        # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size.
+        optimizer = create_optimizer(learning_rate_var * hvd.size())
+        optimizer = hvd.DistributedOptimizer(optimizer)
+    else:
+        optimizer = create_optimizer(learning_rate_var)
 
     # Enable mixed precision training
     if FLAGS.automatic_mixed_precision:
         log_info('Enabling automatic mixed precision training.')
         optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
 
-    gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
+    if FLAGS.horovod:
+        loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False)
+        gradients = optimizer.compute_gradients(loss)
+
+        tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries'])
+        log_grads_and_vars(gradients)
+
+        # global_step is automagically incremented by the optimizer
+        global_step = tfv1.train.get_or_create_global_step()
+        apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step)
+    else:
+        gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates)
 
-    # Average tower gradients across GPUs
-    avg_tower_gradients = average_gradients(gradients)
-    log_grads_and_vars(avg_tower_gradients)
+        # Average tower gradients across GPUs
+        avg_tower_gradients = average_gradients(gradients)
+        log_grads_and_vars(avg_tower_gradients)
 
-    # global_step is automagically incremented by the optimizer
-    global_step = tfv1.train.get_or_create_global_step()
-    apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
+        # global_step is automagically incremented by the optimizer
+        global_step = tfv1.train.get_or_create_global_step()
+        apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
 
     # Summaries
     step_summaries_op = tfv1.summary.merge_all('step_summaries')
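To make the learning-rate comment above concrete: with the 16-process example from the documentation change (``-np 16``), each worker still feeds ``--train_batch_size`` samples per step, so the effective batch is 16 times larger and the optimizer is created with ``learning_rate_var * 16``.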
@@ -506,18 +530,22 @@ def train():
     }
 
     # Checkpointing
-    checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
+    if Config.is_master_process:
+        checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
+        checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
 
-    best_dev_saver = tfv1.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
+        best_dev_saver = tfv1.train.Saver(max_to_keep=1)
+        best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
 
-    # Save flags next to checkpoints
-    if not is_remote_path(FLAGS.save_checkpoint_dir):
-        os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
-    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
-    with open_remote(flags_file, 'w') as fout:
-        fout.write(FLAGS.flags_into_string())
+        # Save flags next to checkpoints
+        if not is_remote_path(FLAGS.save_checkpoint_dir):
+            os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+        flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
+        with open_remote(flags_file, 'w') as fout:
+            fout.write(FLAGS.flags_into_string())
+
+    if FLAGS.horovod:
+        bcast = hvd.broadcast_global_variables(0)
 
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
@@ -527,6 +555,8 @@ def train():
 
         # Load checkpoint or initialize variables
         load_or_init_graph_for_training(session)
+        if FLAGS.horovod:
+            bcast.run()
 
         def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
@@ -554,12 +584,13 @@ def __call__(self, progress, data, **kwargs):
                     data['mean_loss'] = total_loss / step_count if step_count else 0.0
                     return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs)
 
-            prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
-            widgets = [' | ', progressbar.widgets.Timer(),
-                       ' | Steps: ', progressbar.widgets.Counter(),
-                       ' | ', LossWidget()]
-            suffix = ' | Dataset: {}'.format(dataset) if dataset else None
-            pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
+            if Config.is_master_process:
+                prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name])
+                widgets = [' | ', progressbar.widgets.Timer(),
+                           ' | Steps: ', progressbar.widgets.Counter(),
+                           ' | ', LossWidget()]
+                suffix = ' | Dataset: {}'.format(dataset) if dataset else None
+                pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start()
 
             # Initialize iterator to the appropriate dataset
             session.run(init_op)
@@ -583,15 +614,17 @@ def __call__(self, progress, data, **kwargs):
                 total_loss += batch_loss
                 step_count += 1
 
-                pbar.update(step_count)
+                if Config.is_master_process:
+                    pbar.update(step_count)
 
-                step_summary_writer.add_summary(step_summary, current_step)
+                    step_summary_writer.add_summary(step_summary, current_step)
 
-                if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
-                    checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
-                    checkpoint_time = time.time()
+                    if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
+                        checkpoint_saver.save(session, checkpoint_path, global_step=current_step)
+                        checkpoint_time = time.time()
 
-            pbar.finish()
+            if Config.is_master_process:
+                pbar.finish()
             mean_loss = total_loss / step_count if step_count > 0 else 0.0
             return mean_loss, step_count

@@ -603,21 +636,25 @@ def __call__(self, progress, data, **kwargs):
         try:
             for epoch in range(FLAGS.epochs):
                 # Training
-                log_progress('Training epoch %d...' % epoch)
+                if Config.is_master_process:
+                    log_progress('Training epoch %d...' % epoch)
                 train_loss, _ = run_set('train', epoch, train_init_op)
-                log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
-                checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
+                if Config.is_master_process:
+                    log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
+                    checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
                 if FLAGS.dev_files:
                     # Validation
                     dev_loss = 0.0
                     total_steps = 0
                     for source, init_op in zip(dev_sources, dev_init_ops):
-                        log_progress('Validating epoch %d on %s...' % (epoch, source))
+                        if Config.is_master_process:
+                            log_progress('Validating epoch %d on %s...' % (epoch, source))
                         set_loss, steps = run_set('dev', epoch, init_op, dataset=source)
                         dev_loss += set_loss * steps
                         total_steps += steps
-                        log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+                        if Config.is_master_process:
+                            log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
                     dev_loss = dev_loss / total_steps
                     dev_losses.append(dev_loss)
@@ -629,16 +666,19 @@ def __call__(self, progress, data, **kwargs):
                     else:
                         epochs_without_improvement = 0
 
-                    # Save new best model
-                    if dev_loss < best_dev_loss:
-                        best_dev_loss = dev_loss
-                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
-                        log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
+                    if Config.is_master_process:
+                        # Save new best model
+                        if dev_loss < best_dev_loss:
+                            best_dev_loss = dev_loss
+                            save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                            latest_filename='best_dev_checkpoint')
+                            log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path))
 
                     # Early stopping
                     if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs:
-                        log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
-                            epochs_without_improvement))
+                        if Config.is_master_process:
+                            log_info('Early stop triggered as the loss did not improve the last {} epochs'.format(
+                                epochs_without_improvement))
                         break
 
                     # Reduce learning rate on plateau
@@ -655,26 +695,31 @@ def __call__(self, progress, data, **kwargs):
                         # Reduce learning rate
                         session.run(reduce_learning_rate_op)
                         current_learning_rate = learning_rate_var.eval()
-                        log_info('Encountered a plateau, reducing learning rate to {}'.format(
-                            current_learning_rate))
+                        if Config.is_master_process:
+                            log_info('Encountered a plateau, reducing learning rate to {}'.format(
+                                current_learning_rate))
 
-                        # Overwrite best checkpoint with new learning rate value
-                        save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint')
-                        log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
+                            # Overwrite best checkpoint with new learning rate value
+                            save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step,
+                                                            latest_filename='best_dev_checkpoint')
+                            log_info("Saved best validating model with reduced learning rate to: %s" % (save_path))
 
                 if FLAGS.metrics_files:
                     # Read only metrics, not affecting best validation loss tracking
                     for source, init_op in zip(metrics_sources, metrics_init_ops):
-                        log_progress('Metrics for epoch %d on %s...' % (epoch, source))
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s...' % (epoch, source))
                         set_loss, _ = run_set('metrics', epoch, init_op, dataset=source)
-                        log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
+                        if Config.is_master_process:
+                            log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss))
 
                 print('-' * 80)
 
 
         except KeyboardInterrupt:
             pass
-        log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
+        if Config.is_master_process:
+            log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
     log_debug('Session closed.')

@@ -951,30 +996,32 @@ def main(_):
     if FLAGS.train_files:
         tfv1.reset_default_graph()
         tfv1.set_random_seed(FLAGS.random_seed)
+
         train()
 
-    if FLAGS.test_files:
-        tfv1.reset_default_graph()
-        test()
+    if Config.is_master_process:
+        if FLAGS.test_files:
+            tfv1.reset_default_graph()
+            test()
 
-    if FLAGS.export_dir and not FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        export()
+        if FLAGS.export_dir and not FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            export()
 
-    if FLAGS.export_zip:
-        tfv1.reset_default_graph()
-        FLAGS.export_tflite = True
+        if FLAGS.export_zip:
+            tfv1.reset_default_graph()
+            FLAGS.export_tflite = True
 
-        if listdir_remote(FLAGS.export_dir):
-            log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
-            sys.exit(1)
+            if listdir_remote(FLAGS.export_dir):
+                log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir))
+                sys.exit(1)
 
-        export()
-        package_zip()
+            export()
+            package_zip()
 
-    if FLAGS.one_shot_infer:
-        tfv1.reset_default_graph()
-        do_single_file_inference(FLAGS.one_shot_infer)
+        if FLAGS.one_shot_infer:
+            tfv1.reset_default_graph()
+            do_single_file_inference(FLAGS.one_shot_infer)
 
 
 def run_script():
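Taken together, the train.py changes follow Horovod's standard TensorFlow 1.x recipe: scale the learning rate by the worker count, wrap the optimizer in ``hvd.DistributedOptimizer``, broadcast the initial variables from rank 0, and restrict logging, checkpointing, testing and export to the master process. The sketch below is a minimal, self-contained illustration of that recipe, not DeepSpeech code: the toy graph, variable names and hyperparameters are assumptions, and the explicit ``hvd.init()`` and rank checks stand in for the ``Config`` changes of this PR that are not shown on this page.

.. code-block:: python

    import horovod.tensorflow as hvd
    import tensorflow.compat.v1 as tfv1

    hvd.init()  # one process per GPU, launched by horovodrun

    # Pin each worker process to its own local GPU.
    session_config = tfv1.ConfigProto()
    session_config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Toy graph standing in for the acoustic model (illustrative only).
    features = tfv1.random.normal([16, 10])
    weights = tfv1.get_variable('weights', shape=[10, 1])
    loss = tfv1.reduce_mean(tfv1.square(tfv1.matmul(features, weights)))

    # Scale the learning rate by the worker count (the effective batch grows with
    # hvd.size()), then wrap the optimizer so gradients are allreduced across workers.
    optimizer = tfv1.train.AdamOptimizer(0.0001 * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)

    global_step = tfv1.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)

    # Broadcast rank 0's initial variables so every worker starts identically,
    # matching the bcast.run() call added after checkpoint loading above.
    bcast = hvd.broadcast_global_variables(0)

    with tfv1.Session(config=session_config) as session:
        session.run(tfv1.global_variables_initializer())
        session.run(bcast)
        for _ in range(10):
            _, step_loss = session.run([train_op, loss])
        # Only the master process reports, mirroring Config.is_master_process.
        if hvd.rank() == 0:
            print('final loss:', step_loss)

Launched with, for example, ``horovodrun -np 2 -H localhost:2 python3 sketch.py`` (an assumed file name), each process trains its own replica while gradients are averaged every step; rank 0 alone handles reporting, which is the role ``Config.is_master_process`` plays in the diff above.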
