 from __future__ import division
 from __future__ import print_function
 
+import time
+
 import numpy as np
 
 from tensorflow.contrib import framework
 from tensorflow.contrib.factorization.python.ops import gmm_ops
 from tensorflow.contrib.framework.python.framework import checkpoint_utils
 from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn import graph_actions
+from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
+from tensorflow.contrib.learn.python.learn.estimators import estimator as estimator_lib
+from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
 from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
 from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed as random_seed_lib
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops.control_flow_ops import with_dependencies
+from tensorflow.python.platform import tf_logging as logging
 
 
 def _streaming_sum(scalar_tensor):
@@ -44,7 +52,7 @@ def _streaming_sum(scalar_tensor):
   return sum_metric, sum_update
 
 
-class GMM(estimator.Estimator, TransformerMixin):
+class GMM(estimator_lib.Estimator, TransformerMixin):
   """GMM clustering."""
   SCORES = 'scores'
   ASSIGNMENTS = 'assignments'
@@ -116,7 +124,8 @@ def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
     self._data_feeder = data_feeder.setup_train_data_feeder(x, None,
                                                             self._num_clusters,
                                                             self.batch_size)
-    self._train_model(
+    _legacy_train_model(  # pylint: disable=protected-access
+        self,
         input_fn=self._data_feeder.input_builder,
         feed_fn=self._data_feeder.get_feed_dict_fn(),
         steps=steps or self.steps,
@@ -218,3 +227,90 @@ def _get_eval_ops(self, features, _, unused_metrics):
                                   self._covariance_type,
                                   self._params)
     return {GMM.SCORES: _streaming_sum(math_ops.reduce_sum(losses))}
+
+
+# TODO(xavigonzalvo): delete this after implementing model-fn based Estimator.
+def _legacy_train_model(estimator,
+                        input_fn,
+                        steps,
+                        feed_fn=None,
+                        init_op=None,
+                        init_feed_fn=None,
+                        init_fn=None,
+                        device_fn=None,
+                        monitors=None,
+                        log_every_steps=100,
+                        fail_on_nan_loss=True,
+                        max_steps=None):
+  """Legacy train function of Estimator."""
+  if hasattr(estimator.config, 'execution_mode'):
+    if estimator.config.execution_mode not in ('all', 'train'):
+      return
+
+  # Stagger startup of worker sessions based on task id.
+  sleep_secs = min(
+      estimator.config.training_worker_max_startup_secs,
+      estimator.config.task_id *
+      estimator.config.training_worker_session_startup_stagger_secs)
+  if sleep_secs:
+    logging.info('Waiting %d secs before starting task %d.', sleep_secs,
+                 estimator.config.task_id)
+    time.sleep(sleep_secs)
+
+  # Device allocation
+  device_fn = device_fn or estimator._device_fn  # pylint: disable=protected-access
+
+  with ops.Graph().as_default() as g, g.device(device_fn):
+    random_seed_lib.set_random_seed(estimator.config.tf_random_seed)
+    global_step = framework.create_global_step(g)
+    features, labels = input_fn()
+    estimator._check_inputs(features, labels)  # pylint: disable=protected-access
+
+    # The default return type of _get_train_ops is ModelFnOps. But there are
+    # some subclasses of tf.contrib.learn.Estimator which override this
+    # method and use the legacy signature, namely _get_train_ops returns a
+    # (train_op, loss) tuple. The following else-statement code covers these
+    # cases, but will soon be deleted after the subclasses are updated.
+    # TODO(b/32664904): Update subclasses and delete the else-statement.
+    train_ops = estimator._get_train_ops(features, labels)  # pylint: disable=protected-access
+    if isinstance(train_ops, model_fn_lib.ModelFnOps):  # Default signature
+      train_op = train_ops.train_op
+      loss_op = train_ops.loss
+      if estimator.config.is_chief:
+        hooks = train_ops.training_chief_hooks + train_ops.training_hooks
+      else:
+        hooks = train_ops.training_hooks
+    else:  # Legacy signature
+      if len(train_ops) != 2:
+        raise ValueError('Expected a tuple of train_op and loss, got {}'.format(
+            train_ops))
+      train_op = train_ops[0]
+      loss_op = train_ops[1]
+      hooks = []
+
+    hooks += monitor_lib.replace_monitors_with_hooks(monitors, estimator)
+
+    ops.add_to_collection(ops.GraphKeys.LOSSES, loss_op)
+    return graph_actions._monitored_train(  # pylint: disable=protected-access
+        graph=g,
+        output_dir=estimator.model_dir,
+        train_op=train_op,
+        loss_op=loss_op,
+        global_step_tensor=global_step,
+        init_op=init_op,
+        init_feed_dict=init_feed_fn() if init_feed_fn is not None else None,
+        init_fn=init_fn,
+        log_every_steps=log_every_steps,
+        supervisor_is_chief=estimator.config.is_chief,
+        supervisor_master=estimator.config.master,
+        supervisor_save_model_secs=estimator.config.save_checkpoints_secs,
+        supervisor_save_model_steps=estimator.config.save_checkpoints_steps,
+        supervisor_save_summaries_steps=estimator.config.save_summary_steps,
+        keep_checkpoint_max=estimator.config.keep_checkpoint_max,
+        keep_checkpoint_every_n_hours=(
+            estimator.config.keep_checkpoint_every_n_hours),
+        feed_fn=feed_fn,
+        steps=steps,
+        fail_on_nan_loss=fail_on_nan_loss,
+        hooks=hooks,
+        max_steps=max_steps)
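
For orientation, a minimal usage sketch of the code path this patch changes (not part of the diff). The GMM constructor arguments and module path are assumptions based on this version of tf.contrib and may differ across releases; the data shape and step count are illustrative only:

import numpy as np
from tensorflow.contrib.factorization.python.ops.gmm import GMM

# Toy input: 1000 two-dimensional points (shape and dtype are illustrative).
x = np.random.rand(1000, 2).astype(np.float32)

# num_clusters is the required argument; other settings keep their defaults
# (assumed signature).
gmm = GMM(num_clusters=3)

# After this patch, fit() sets up its data feeder and then delegates to the
# module-level _legacy_train_model(self, ...) shown above, which builds the
# graph, converts monitors to hooks, and runs graph_actions._monitored_train.
gmm.fit(x, steps=10)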