@@ -408,8 +408,6 @@ def _write_saved_transform(graph, inputs, outputs, saved_model_dir):
     # warnings.
     # pylint: disable=protected-access
     collections_blacklist = [
-        tft_api._TF_METADATA_TENSORS_COLLECTION,
-        tft_api._TF_METADATA_SCHEMA_OVERRIDES_COLLECTION,
         tft_api.FUNCTION_APPLICATION_COLLECTION,
         tft_analyzers.ANALYZER_COLLECTION
     ]
@@ -629,112 +627,60 @@ def expand(self, inputs):
     return result
 
 
-class _ComputeDeferredMetadata(beam.PTransform):
-  """Extracts values of tensors from a transform function.
+def _augment_metadata(saved_model_dir, metadata):
+  """Augments the metadata with min/max values stored in the SavedModel.
 
-  This transform takes the path to a SavedModel in its constructor, and in its
-  expand() method accepts a mapping from tensors to PCollections. When run, it
-  replaces the tensors corresponding to the keys of this mapping, with the
-  values wrapped in the PCollections. It then extracts the values of some
-  tensors in the new graph. This allows us to compute values that depend on
-  values in the tensor-PCollection mapping in arbitrary ways, where the values
-  are represented by tensors in the graph that depend on the tensor-PCollection
-  mapping (but not on the inputs to the graph).
+  Takes the min/max values of tensors stored in the SavedModel, and uses these
+  to augment the metadata. For each feature in the metadata, the min/max of
+  the corresponding `Tensor` are used to augment the schema. For a feature
+  represented by a `SparseTensor` we use the min/max for the `values` field of
+  the `SparseTensor`.
 
   Args:
+    saved_model_dir: Location of a SavedModel
     metadata: A `DatasetMetadata`
-    column_schema_overrides: A dict from column names to `api._SchemaOverride`s
-    saved_model_dir: The model to extract the constants from.
-    pipeline: The beam Pipeline.
-  """
-
-  def __init__(self, metadata, column_schema_overrides, saved_model_dir,
-               pipeline):
-    self._metadata = metadata
-    self._column_schema_overrides = column_schema_overrides
-    self._saved_model_dir = saved_model_dir
-    # Generally the pipeline is inferred from its inputs, however we need
-    # to know the pipeline for beam.Create.
-    self.pipeline = pipeline
 
-  def expand(self, tensor_pcoll_mapping):
-    """Converts a dict of statistics to a transform function.
-
-    Args:
-      tensor_pcoll_mapping: A dictionary mapping `Tensor`s to a singleton
-        PCollection containing a _TensorValue.
-
-    Returns:
-      A dict from tensor names to singleton `PCollection`s.
-    """
-    # Convert tensor_value_mapping into a DictPCollectionView so it can be
-    # passed as a side input to the beam Map below.
-    tensor_value_pairs = []
-    for name, pcoll in six.iteritems(tensor_pcoll_mapping):
-      tensor_value_pairs.append(
-          pcoll
-          | 'AddName[%s]' % name >> beam.Map(lambda x, name=name: (name, x)))
-    tensor_value_mapping = beam.pvalue.AsDict(
-        tensor_value_pairs
-        | 'MergeTensorValuePairs' >> beam.Flatten(pipeline=self.pipeline))
-
-    def compute_deferred_metadata(metadata, column_schema_overrides,
-                                  saved_model_dir, tensor_value_mapping):
-      """Extracts constant values from graph."""
-      tensor_names = {
-          tensor_name
-          for override in six.itervalues(column_schema_overrides)
-          for tensor_name in [override.min_value, override.max_value]}
-
-      graph = tf.Graph()
-      with graph.as_default():
-        tensor_replacement_map = {}
-        for orig_tensor_name, (value,
-                               is_asset) in six.iteritems(tensor_value_mapping):
-          new_tensor = tf.constant(value)
-          if is_asset:
-            # Any newly frozen constant tensors containing filenames must be
-            # added to the ASSET_FILENAMES collection.
-            graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, new_tensor)
-          tensor_replacement_map[orig_tensor_name] = new_tensor
-
-        with tf.Session(graph=graph) as session:
-          tensors_by_name = (
-              saved_transform_io.fetch_tensor_values(
-                  saved_model_dir, tensor_replacement_map, tensor_names))
-          session.run(tf.global_variables_initializer())
-          session.run(tf.tables_initializer())
-          tensor_values_by_name = session.run(tensors_by_name)
-
-      new_column_schemas = {}
-      for key, column_schema in six.iteritems(metadata.schema.column_schemas):
-        if key in column_schema_overrides:
-          override = column_schema_overrides[key]
-          min_value = tensor_values_by_name[override.min_value]
-          max_value = tensor_values_by_name[override.max_value]
-          assert column_schema.domain.dtype == tf.int64
-          assert isinstance(column_schema.domain, dataset_schema.IntDomain)
-          # Create a new column schema. An override always results in a
-          # categorical column.
-          new_column_schemas[key] = dataset_schema.ColumnSchema(
-              dataset_schema.IntDomain(tf.int64, min_value, max_value,
                                       is_categorical=True),
-              column_schema.axes,
-              column_schema.representation)
-        else:
-          new_column_schemas[key] = column_schema
-
-      return dataset_metadata.DatasetMetadata(dataset_schema.Schema(
-          new_column_schemas))
-
-    return (
-        self.pipeline
-        | 'CreateMetadata' >> beam.Create([self._metadata])
-        | 'ExtractScalarConstants' >> beam.Map(
-            compute_deferred_metadata,
-            column_schema_overrides=self._column_schema_overrides,
-            saved_model_dir=self._saved_model_dir,
-            tensor_value_mapping=tensor_value_mapping))
+  Returns:
+    An augmented DatasetMetadata. The original DatasetMetadata is unchanged.
+  """
+  with tf.Graph().as_default() as graph:
+    with tf.Session(graph=graph) as session:
+      _, output_tensor_by_name = (
+          saved_transform_io.partially_apply_saved_transform_internal(
+              saved_model_dir, {}))
+
+      # Get overrides for the min/max of tensors from the graph, and use these
+      # to determine overrides for the min/max of the outputs of the graph.
+      tensor_schema_overrides = tft_api.get_tensor_schema_overrides()
+      column_schema_overrides = {}
+      for name, tensor in six.iteritems(output_tensor_by_name):
+        if isinstance(tensor, tf.SparseTensor):
+          tensor = tensor.values
+        if tensor in tensor_schema_overrides:
+          column_schema_overrides[name] = tensor_schema_overrides[tensor]
+
+      session.run(tf.global_variables_initializer())
+      session.run(tf.tables_initializer())
+      column_schema_override_values = session.run(column_schema_overrides)
+
+  new_column_schemas = {}
+  for key, column_schema in six.iteritems(metadata.schema.column_schemas):
+    if key in column_schema_override_values:
+      min_value, max_value = column_schema_override_values[key]
+      assert column_schema.domain.dtype == tf.int64
+      assert isinstance(column_schema.domain, dataset_schema.IntDomain)
+      # Create a new column schema. An override always results in a
+      # categorical column.
+      new_column_schemas[key] = dataset_schema.ColumnSchema(
+          dataset_schema.IntDomain(tf.int64, min_value, max_value,
                                   is_categorical=True),
+          column_schema.axes,
+          column_schema.representation)
+    else:
+      new_column_schemas[key] = column_schema
+
+  return dataset_metadata.DatasetMetadata(dataset_schema.Schema(
+      new_column_schemas))
 
 
 class AnalyzeDataset(beam.PTransform):
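
The new `_augment_metadata` above hinges on one pattern: the schema overrides are a dict of (min, max) tensor pairs, and a single `session.run` over that dict, after the initializers have run, turns them into plain values that can be folded into an `IntDomain`. A minimal standalone sketch of that pattern, with made-up constants standing in for the tensors returned by `tft_api.get_tensor_schema_overrides()` (TF 1.x style, not part of this commit):

import tensorflow as tf

with tf.Graph().as_default() as graph:
  # Hypothetical stand-in for the overrides dict; in the real code the values
  # are tensors that depend on analyzer outputs stored in the SavedModel.
  overrides = {'x_integerized': (tf.constant(0, tf.int64),
                                 tf.constant(9, tf.int64))}
  with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    # session.run on a nested structure returns the same structure of plain
    # values, e.g. {'x_integerized': (0, 9)}, ready to build an IntDomain from.
    override_values = session.run(overrides)
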
@@ -860,23 +806,15 @@ def expand(self, dataset):
     # refer to values of tensors in the graph. The override tensors must
     # be "constant" in that they don't depend on input data. The tensors can
     # depend on analyzer outputs though. This allows us to set metadata that
-    # depends on analyzer outputs. _ComputeDeferredMetadata will use
-    # tensor_pcoll_mapping to compute the metadata in a deferred manner, once
-    # the analyzer outputs are known.
+    # depends on analyzer outputs. _augment_metadata will use the analyzer
+    # outputs stored in `transform_fn` to compute the metadata in a
+    # deferred manner, once the analyzer outputs are known.
     metadata = dataset_metadata.DatasetMetadata(
         schema=impl_helper.infer_feature_schema(outputs))
 
-    tensor_schema_overrides = tft_api.get_tensor_schema_overrides()
-    column_schema_overrides = {
-        key: tensor_schema_overrides[tensor]
-        for key, tensor in six.iteritems(outputs)
-        if tensor in tensor_schema_overrides}
-
     deferred_metadata = (
-        tensor_pcoll_mapping
-        | 'ComputeDeferredMetadata' >>
-        _ComputeDeferredMetadata(metadata, column_schema_overrides,
                                 saved_model_dir, input_values.pipeline))
+        transform_fn
+        | 'ComputeDeferredMetadata' >> beam.Map(_augment_metadata, metadata))
 
     full_metadata = beam_metadata_io.BeamDatasetMetadata(
         metadata, deferred_metadata)
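
The deferred computation above relies only on standard Beam semantics: `transform_fn` is a singleton PCollection holding the path of the transform SavedModel, so mapping `_augment_metadata` over it (with `metadata` passed as an extra argument) yields a singleton PCollection containing the augmented metadata once the analyzers have run. A minimal sketch of that shape, using a hypothetical path and a stub in place of the real `_augment_metadata` (not part of this commit):

import apache_beam as beam

def _augment_metadata_stub(saved_model_dir, metadata):
  # Stand-in for the real _augment_metadata defined earlier in this diff.
  return metadata

with beam.Pipeline() as p:
  metadata = 'metadata-placeholder'  # stand-in for a DatasetMetadata
  transform_fn = p | beam.Create(['/tmp/transform_fn'])  # hypothetical path
  deferred_metadata = (
      transform_fn
      | 'ComputeDeferredMetadata' >> beam.Map(_augment_metadata_stub, metadata))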