Skip to content

Commit 61a6797

Browse files
ispirmustafatensorflower-gardener
authored andcommitted
Simplified estimator logic by MonitoredSession.
Removed graph_action usage. Change: 144126485
1 parent 3e59f05 commit 61a6797

File tree

2 files changed

+319
-348
lines changed

2 files changed

+319
-348
lines changed

tensorflow/contrib/factorization/python/ops/gmm.py

+99-3
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,28 @@
2121
from __future__ import division
2222
from __future__ import print_function
2323

24+
import time
25+
2426
import numpy as np
2527

2628
from tensorflow.contrib import framework
2729
from tensorflow.contrib.factorization.python.ops import gmm_ops
2830
from tensorflow.contrib.framework.python.framework import checkpoint_utils
2931
from tensorflow.contrib.framework.python.ops import variables
30-
from tensorflow.contrib.learn.python.learn.estimators import estimator
32+
from tensorflow.contrib.learn.python.learn import graph_actions
33+
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
34+
from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
35+
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
3136
from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
3237
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
3338
from tensorflow.python.framework import constant_op
39+
from tensorflow.python.framework import ops
40+
from tensorflow.python.framework import random_seed as random_seed_lib
3441
from tensorflow.python.ops import array_ops
3542
from tensorflow.python.ops import math_ops
3643
from tensorflow.python.ops import state_ops
3744
from tensorflow.python.ops.control_flow_ops import with_dependencies
45+
from tensorflow.python.platform import tf_logging as logging
3846

3947

4048
def _streaming_sum(scalar_tensor):
@@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
4452
return sum_metric, sum_update
4553

4654

47-
class GMM(estimator.Estimator, TransformerMixin):
55+
class GMM(estimator_lib.Estimator, TransformerMixin):
4856
"""GMM clustering."""
4957
SCORES = 'scores'
5058
ASSIGNMENTS = 'assignments'
@@ -116,7 +124,8 @@ def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
116124
self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
117125
self._num_clusters,
118126
self.batch_size)
119-
self._train_model(
127+
_legacy_train_model( # pylint: disable=protected-access
128+
self,
120129
input_fn=self._data_feeder.input_builder,
121130
feed_fn=self._data_feeder.get_feed_dict_fn(),
122131
steps=steps or self.steps,
@@ -218,3 +227,90 @@ def _get_eval_ops(self, features, _, unused_metrics):
218227
self._covariance_type,
219228
self._params)
220229
return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
230+
231+
232+
# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
233+
def _legacy_train_model(estimator,
234+
input_fn,
235+
steps,
236+
feed_fn=None,
237+
init_op=None,
238+
init_feed_fn=None,
239+
init_fn=None,
240+
device_fn=None,
241+
monitors=None,
242+
log_every_steps=100,
243+
fail_on_nan_loss=True,
244+
max_steps=None):
245+
"""Legacy train function of Estimator."""
246+
if hasattr(estimator.config, 'execution_mode'):
247+
if estimator.config.execution_mode not in ('all', 'train'):
248+
return
249+
250+
# Stagger startup of worker sessions based on task id.
251+
sleep_secs = min(
252+
estimator.config.training_worker_max_startup_secs,
253+
estimator.config.task_id *
254+
estimator.config.training_worker_session_startup_stagger_secs)
255+
if sleep_secs:
256+
logging.info('Waiting %d secs before starting task %d.', sleep_secs,
257+
estimator.config.task_id)
258+
time.sleep(sleep_secs)
259+
260+
# Device allocation
261+
device_fn = device_fn or estimator._device_fn # pylint: disable=protected-access
262+
263+
with ops.Graph().as_default() as g, g.device(device_fn):
264+
random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
265+
global_step = framework.create_global_step(g)
266+
features, labels = input_fn()
267+
estimator._check_inputs(features, labels) # pylint: disable=protected-access
268+
269+
# The default return type of _get_train_ops is ModelFnOps. But there are
270+
# some subclasses of tf.contrib.learn.Estimator which override this
271+
# method and use the legacy signature, namely _get_train_ops returns a
272+
# (train_op, loss) tuple. The following else-statement code covers these
273+
# cases, but will soon be deleted after the subclasses are updated.
274+
# TODO(b/32664904): Update subclasses and delete the else-statement.
275+
train_ops = estimator._get_train_ops(features, labels) # pylint: disable=protected-access
276+
if isinstance(train_ops, model_fn_lib.ModelFnOps): # Default signature
277+
train_op = train_ops.train_op
278+
loss_op = train_ops.loss
279+
if estimator.config.is_chief:
280+
hooks = train_ops.training_chief_hooks + train_ops.training_hooks
281+
else:
282+
hooks = train_ops.training_hooks
283+
else: # Legacy signature
284+
if len(train_ops) != 2:
285+
raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
286+
train_ops))
287+
train_op = train_ops[0]
288+
loss_op = train_ops[1]
289+
hooks = []
290+
291+
hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
292+
293+
ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
294+
return graph_actions._monitored_train( # pylint: disable=protected-access
295+
graph=g,
296+
output_dir=estimator.model_dir,
297+
train_op=train_op,
298+
loss_op=loss_op,
299+
global_step_tensor=global_step,
300+
init_op=init_op,
301+
init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
302+
init_fn=init_fn,
303+
log_every_steps=log_every_steps,
304+
supervisor_is_chief=estimator.config.is_chief,
305+
supervisor_master=estimator.config.master,
306+
supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
307+
supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
308+
supervisor_save_summaries_steps=estimator.config.save_summary_steps,
309+
keep_checkpoint_max=estimator.config.keep_checkpoint_max,
310+
keep_checkpoint_every_n_hours=(
311+
estimator.config.keep_checkpoint_every_n_hours),
312+
feed_fn=feed_fn,
313+
steps=steps,
314+
fail_on_nan_loss=fail_on_nan_loss,
315+
hooks=hooks,
316+
max_steps=max_steps)

0 commit comments

Comments
 (0)