@@ -241,11 +241,13 @@ class _GraphState(object):
241241 def __init__ (self , saved_model_dir , input_schema , exclude_outputs ,
242242 tf_config ):
243243 self .saved_model_dir = saved_model_dir
244- self .session = tf .Session (graph = tf .Graph (), config = tf_config )
245- with self .session .graph .as_default ():
246- with tf .Session (config = tf_config ):
244+ graph = tf .Graph ()
245+ self .session = tf .Session (graph = graph , config = tf_config )
246+ with graph .as_default ():
247+ with self .session .as_default ():
247248 inputs , outputs = saved_transform_io .partially_apply_saved_transform (
248249 saved_model_dir , {})
250+ self .session .run (tf .global_variables_initializer ())
249251 self .session .run (tf .tables_initializer ())
250252
251253 input_schema_keys = input_schema .column_schemas .keys ()
@@ -342,9 +344,9 @@ def process(self, batch, saved_model_dir):
342344def _assert_tensorflow_version ():
343345 # Fail with a clear error in case we are not using a compatible TF version.
344346 major , minor , _ = tf .__version__ .split ('.' )
345- if int (major ) != 1 or int (minor ) < 4 :
347+ if int (major ) != 1 or int (minor ) < 5 :
346348 raise RuntimeError (
347- 'Tensorflow version >= 1.4 , < 2 is required. Found (%s). Please '
349+ 'TensorFlow version >= 1.5 , < 2 is required. Found (%s). Please '
348350 'install the latest 1.x version from '
349351 'https://github.com/tensorflow/tensorflow. ' % tf .__version__ )
350352
@@ -372,6 +374,8 @@ def _write_saved_transform(graph, inputs, outputs, saved_model_dir):
372374 removed_collections .append ((collection_name ,
373375 graph .get_collection (collection_name )))
374376 graph .clear_collection (collection_name )
377+ # Initialize all variables so they can be saved.
378+ session .run (tf .global_variables_initializer ())
375379 saved_transform_io .write_saved_transform_from_session (
376380 session , inputs , outputs , saved_model_dir )
377381 for collection_name , collection in removed_collections :
@@ -478,6 +482,7 @@ def replace_tensors_with_constant_values(saved_model_dir,
478482 input_tensors , output_tensors = (
479483 saved_transform_io .partially_apply_saved_transform (
480484 saved_model_dir , {}, tensor_replacement_map ))
485+ session .run (tf .global_variables_initializer ())
481486 saved_transform_io .write_saved_transform_from_session (
482487 session , input_tensors , output_tensors , temp_dir )
483488 return temp_dir
@@ -602,6 +607,7 @@ def extract_scalar_constants(tensor_names, saved_model_dir,
602607 tensor_output_map = (
603608 saved_transform_io .fetch_tensor_values (
604609 saved_model_dir , tensor_replacement_map , tensor_names ))
610+ session .run (tf .global_variables_initializer ())
605611 session .run (tf .tables_initializer ())
606612 return session .run (tensor_output_map )
607613
@@ -650,21 +656,43 @@ def expand(self, dataset):
650656
651657 Returns:
652658 A TransformFn containing the deferred transform function.
653- """
654659
660+ Raises:
661+ ValueError: If preprocessing_fn has no outputs.
662+ """
655663 input_values , input_metadata = dataset
656664 input_schema = input_metadata .schema
657665
658666 base_temp_dir = Context .create_base_temp_dir ()
659667
660668 graph = tf .Graph ()
661669 with graph .as_default ():
670+
671+ with tf .name_scope ('inputs' ):
672+ inputs = input_schema .as_batched_placeholders ()
673+ # In order to avoid a bug where import_graph_def fails when the input_map
674+ # and return_elements of an imported graph are the same (b/34288791), we
675+ # avoid using the placeholder of an input column as an output of a graph.
676+ # We do this by applying tf.identity to all inputs of the
677+ # preprocessing_fn. Note this applies at the level of raw tensors.
678+ outputs = self ._preprocessing_fn (impl_helper .copy_tensors (inputs ))
679+
680+ # At this point we check that the preprocessing_fn has at least one
681+ # output. This is because if we allowed the output of preprocessing_fn to
682+ # be empty, we wouldn't be able to determine how many instances to
683+ # "unbatch" the output into.
684+ if not outputs :
685+ raise ValueError ('The preprocessing function returned an empty dict' )
686+
687+ if graph .get_collection (tf .GraphKeys .TRAINABLE_VARIABLES ):
688+ raise ValueError (
689+ 'The preprocessing function contained trainable variables '
690+ '{}' .format (
691+ graph .get_collection_ref (tf .GraphKeys .TRAINABLE_VARIABLES )))
692+
662693 # NOTE: it's important that create_phases is called directly after
663- # run_preprocessing_fn, because we later mutate the graph's
664- # TABLE_INITIALIZERS collection which would break the logic in
665- # create_phases.
666- inputs , outputs = impl_helper .run_preprocessing_fn (
667- self ._preprocessing_fn , input_schema )
694+ # preprocessing_fn, because we later mutate the graph's TABLE_INITIALIZERS
695+ # collection which would break the logic in create_phases.
668696 phases = impl_helper .create_phases ()
669697
670698 # Iterate through levels. tensor_pcoll_mapping is a mapping from tensor
0 commit comments