Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 28e0e4e

Browse files
authored
Merge pull request #141 from ReDeiPirati/yellowfin_optimizer
Integrate YellowFin Optimizer(with test) in T2T
2 parents 7566c4d + 7d3c10b commit 28e0e4e

File tree

4 files changed

+805
-1
lines changed

4 files changed

+805
-1
lines changed

tensor2tensor/models/common_hparams.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,9 @@ def basic_range1(ranged_hparams):
202202
rhp.set_float("optimizer_adam_beta1", 0.8, 0.9)
203203
rhp.set_float("optimizer_adam_beta2", 0.995, 0.999)
204204
rhp.set_categorical("optimizer",
205-
["Adam", "Adagrad", "Momentum", "RMSProp", "SGD"])
205+
["Adam",
206+
"Adagrad",
207+
"Momentum",
208+
"RMSProp",
209+
"SGD",
210+
"YellowFin"])

tensor2tensor/utils/trainer_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import tensorflow as tf
4444
from tensorflow.contrib.learn.python.learn import learn_runner
4545
from tensorflow.python.ops import init_ops
46+
from tensor2tensor.utils.yellowfin import YellowFinOptimizer
4647

4748
# Number of samples to draw for an image input (in such cases as captioning)
4849
IMAGE_DECODE_LENGTH = 100
@@ -1141,6 +1142,10 @@ def __init__(self, optimizer_name, lr, hparams):
11411142
elif optimizer_name == "Momentum":
11421143
self._opt = tf.train.MomentumOptimizer(
11431144
lr, momentum=hparams.optimizer_momentum_momentum)
1145+
elif optimizer_name == "YellowFin":
1146+
tf.logging.info("Init YellowFin Optimizer.")
1147+
self._opt = YellowFinOptimizer(
1148+
learning_rate=lr, momentum=hparams.optimizer_momentum_momentum)
11441149
else:
11451150
self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr)
11461151

0 commit comments

Comments
 (0)