Skip to content

Commit 83252f9

Browse files
hertschuhcopybara-github
authored andcommitted
Explicitly import estimator from tensorflow as a separate import instead of
accessing it via tf.estimator and depend on the tensorflow estimator target. PiperOrigin-RevId: 436553352
1 parent 31a2eaf commit 83252f9

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

trax/data/tf_inputs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import numpy as np
3131
import scipy
3232
import tensorflow as tf
33+
from tensorflow import estimator as tf_estimator
3334
import tensorflow_datasets as tfds
3435
import tensorflow_text as tf_text
3536
from trax import data
@@ -380,13 +381,13 @@ def _train_and_eval_dataset_v1(problem_name, data_dir, train_shuffle_files,
380381
hparams = problem.get_hparams()
381382
bair_robot_pushing_hparams(hparams)
382383
train_dataset = problem.dataset(
383-
tf.estimator.ModeKeys.TRAIN,
384+
tf_estimator.ModeKeys.TRAIN,
384385
data_dir,
385386
shuffle_files=train_shuffle_files,
386387
hparams=hparams)
387388
train_dataset = train_dataset.map(_select_features)
388389
eval_dataset = problem.dataset(
389-
tf.estimator.ModeKeys.EVAL,
390+
tf_estimator.ModeKeys.EVAL,
390391
data_dir,
391392
shuffle_files=eval_shuffle_files,
392393
hparams=hparams)

0 commit comments

Comments
 (0)