Commit 3014617

tf-transform-team authored and zoyahav committed
Project import generated by Copybara.
PiperOrigin-RevId: 200715809
1 parent 91e2f84 commit 3014617

7 files changed, +142 -169 lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -28,6 +28,9 @@
 * tft.mean now supports SparseTensor when reduce_instance_dimensions=True.
   In this case it returns a scalar mean computed over the non-missing values of
   the SparseTensor.
+* tft.mean now supports SparseTensor when reduce_instance_dimensions=False.
+  In this case it returns a vector mean computed over the non-missing values of
+  the SparseTensor.
 * Update examples to use "core" TensorFlow estimator API (`tf.estimator`).
 
 ## Breaking changes
```
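To see what the new reduce_instance_dimensions=False behavior computes, here is a minimal NumPy sketch (the data and the helper loop are illustrative, not part of this commit): each position of the output vector is the mean of the values present at that position across instances, and a position that never receives a value comes out as NaN.

```python
import numpy as np

# Two instances of a sparse feature with dense shape [4], written as
# (indices, values) pairs -- the same encoding used by the tests below.
instances = [([0, 1], [0., 1.]),
             ([1, 3], [2., 3.])]

sums = np.zeros(4)
counts = np.zeros(4)
for indices, values in instances:
  for i, v in zip(indices, values):
    sums[i] += v
    counts[i] += 1

# A position with count 0 divides 0/0 and yields NaN.
with np.errstate(invalid='ignore'):
  vector_mean = sums / counts
print(vector_mean)  # [0.  1.5 nan 3. ]
```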

tensorflow_transform/analyzers.py

Lines changed: 15 additions & 8 deletions

```diff
@@ -470,16 +470,23 @@ def mean(x, reduce_instance_dims=True, name=None, output_dtype=None):
   if output_dtype is None:
     raise TypeError('Tensor type %r is not supported' % x.dtype)
   sum_dtype, sum_fn = _sum_combine_fn_and_dtype(x.dtype)
-  if isinstance(x, tf.SparseTensor):
-    if not reduce_instance_dims:
-      raise TypeError(
-          'SparseTensor is only supported when reduce_instance_dims=True')
-    x = x.values
   with tf.name_scope(name, 'mean'):
-    # For now _numeric_combine will return a tuple with as many elements as the
-    # input tuple.
+    if isinstance(x, tf.SparseTensor):
+      if reduce_instance_dims:
+        ones_values, x_values = tf.ones_like(x.values), x.values
+      else:
+        sparse_ones = tf.SparseTensor(
+            indices=x.indices,
+            values=tf.ones_like(x.values),
+            dense_shape=x.dense_shape)
+        ones_values = tf.sparse_reduce_sum(sparse_ones, axis=0, keep_dims=True)
+        x = tf.cast(x, output_dtype)
+        ones_values = tf.cast(ones_values, output_dtype)
+        x_values = tf.sparse_reduce_sum(x, axis=0, keep_dims=True)
+    else:
+      ones_values, x_values = tf.ones_like(x), x
     x_count, x_sum = _numeric_combine(  # pylint: disable=unbalanced-tuple-unpacking
-        [tf.ones_like(x), x],
+        [ones_values, x_values],
        sum_fn,
        reduce_instance_dims,
        output_dtypes=[sum_dtype, sum_dtype])
```
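With the rejection removed, the sparse reduce_instance_dims=False path derives a per-position count and a per-position sum, which _numeric_combine then accumulates across batches; the mean is their ratio. A minimal sketch of that count/sum trick, assuming a TF 1.x session (the example tensor is made up): reducing a same-shaped SparseTensor of ones over the batch axis is what lets missing positions contribute a count of zero instead of being imputed as zero values.

```python
import tensorflow as tf  # assumes a TF 1.x environment, like this codebase

# An illustrative batch: dense shape [2, 4], four populated positions.
x = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 1], [1, 3]],
                    values=[0., 1., 2., 3.],
                    dense_shape=[2, 4])

# Per-position counts: sum a same-shaped SparseTensor of ones over the
# batch axis, so missing positions contribute nothing.
sparse_ones = tf.SparseTensor(indices=x.indices,
                              values=tf.ones_like(x.values),
                              dense_shape=x.dense_shape)
counts = tf.sparse_reduce_sum(sparse_ones, axis=0, keep_dims=True)

# Per-position sums of the actual values.
sums = tf.sparse_reduce_sum(x, axis=0, keep_dims=True)

with tf.Session() as sess:
  count_vals, sum_vals = sess.run([counts, sums])
print(sum_vals / count_vals)  # [[0.  1.5 nan 3. ]] -- NaN where count is 0
```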

tensorflow_transform/api.py

Lines changed: 19 additions & 26 deletions

```diff
@@ -48,8 +48,6 @@ def preprocessing_fn(inputs):
   Beam implementation.
 """
 
-import collections
-
 import tensorflow as tf
 from tensorflow_transform import analyzers
 
@@ -168,37 +166,32 @@ def _convert_label(x):
   return FunctionApplication(fn, args).user_output
 
 
-# min_value and max_value are tensor names.
-_SchemaOverride = collections.namedtuple(
-    'SchemaOverride', ['min_value', 'max_value'])
-
-
-_TF_METADATA_TENSORS_COLLECTION = 'tft_metadata_tensors'
-_TF_METADATA_SCHEMA_OVERRIDES_COLLECTION = 'tft_metadata_schema_overrides'
+# Names of collections, which should all be the same length and contain tensors.
+# Each tensor in the first collection should have its min/max described by the
+# tensors in the other two collections.
+_TF_METADATA_TENSOR_COLLECTION = 'tft_schema_override_tensor'
+_TF_METADATA_TENSOR_MIN_COLLECTION = 'tft_schema_override_min'
+_TF_METADATA_TENSOR_MAX_COLLECTION = 'tft_schema_override_max'
 
 
 def set_tensor_schema_overrides(tensor, min_value, max_value):
-  """Override parts of the schema of a `Tensor` or `SparseTensor`."""
-  if not (isinstance(tensor, tf.Tensor) or isinstance(tensor, tf.SparseTensor)):
-    raise ValueError(
-        'tensor {} was not a Tensor or SparseTensor'.format(tensor))
+  """Override parts of the schema of a `Tensor`."""
+  if not isinstance(tensor, tf.Tensor):
+    raise ValueError('tensor {} was not a Tensor'.format(tensor))
   if not isinstance(min_value, tf.Tensor):
     raise ValueError('min_vaue {} was not a Tensor'.format(min_value))
   if not isinstance(max_value, tf.Tensor):
     raise ValueError('max_vaue {} was not a Tensor'.format(min_value))
-
-  tf.add_to_collection(_TF_METADATA_TENSORS_COLLECTION, tensor)
-
-  # Construct a _SchemaOverride using the tensor names of min_value and
-  # max_value.
-  tf.add_to_collection(_TF_METADATA_SCHEMA_OVERRIDES_COLLECTION,
-                       _SchemaOverride(min_value.name, max_value.name))
+  tf.add_to_collection(_TF_METADATA_TENSOR_COLLECTION, tensor)
+  tf.add_to_collection(_TF_METADATA_TENSOR_MIN_COLLECTION, min_value)
+  tf.add_to_collection(_TF_METADATA_TENSOR_MAX_COLLECTION, max_value)
 
 
 def get_tensor_schema_overrides():
-  """Gets a dict from `Tensor` or `SparseTensor`s to `_SchemaOverride`s."""
-  tensors = tf.get_collection(_TF_METADATA_TENSORS_COLLECTION)
-  schema_overrides = tf.get_collection(_TF_METADATA_SCHEMA_OVERRIDES_COLLECTION)
-  assert len(tensors) == len(schema_overrides), '{} != {}'.format(
-      tensors, schema_overrides)
-  return dict(zip(tensors, schema_overrides))
+  """Gets a dict from `Tensor`s to pairs of `Tensor`s containing min/max."""
+  tensors = tf.get_collection(_TF_METADATA_TENSOR_COLLECTION)
+  min_values = tf.get_collection(_TF_METADATA_TENSOR_MIN_COLLECTION)
+  max_values = tf.get_collection(_TF_METADATA_TENSOR_MAX_COLLECTION)
+  assert len(tensors) == len(min_values), '{} != {}'.format(tensors, min_values)
+  assert len(tensors) == len(max_values), '{} != {}'.format(tensors, max_values)
+  return dict(zip(tensors, zip(min_values, max_values)))
```
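The namedtuple of tensor *names* is replaced by three parallel collections holding the tensors themselves, so the min/max can later be fetched with session.run instead of being looked up by name. A minimal sketch of the round-trip, assuming a TF 1.x graph (the placeholder and constants are hypothetical):

```python
import tensorflow as tf  # assumes a TF 1.x environment

with tf.Graph().as_default():
  ids = tf.placeholder(tf.int64, [None], name='ids')
  min_value = tf.constant(0, tf.int64)
  max_value = tf.constant(9, tf.int64)

  # What set_tensor_schema_overrides does: append to parallel collections.
  tf.add_to_collection('tft_schema_override_tensor', ids)
  tf.add_to_collection('tft_schema_override_min', min_value)
  tf.add_to_collection('tft_schema_override_max', max_value)

  # What get_tensor_schema_overrides does: zip them back together.
  tensors = tf.get_collection('tft_schema_override_tensor')
  min_values = tf.get_collection('tft_schema_override_min')
  max_values = tf.get_collection('tft_schema_override_max')
  overrides = dict(zip(tensors, zip(min_values, max_values)))
  assert overrides[ids] == (min_value, max_value)
```

Storing live tensors rather than names is what lets _augment_metadata in impl.py below run them directly in a fresh session.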

tensorflow_transform/beam/impl.py

Lines changed: 54 additions & 116 deletions

```diff
@@ -408,8 +408,6 @@ def _write_saved_transform(graph, inputs, outputs, saved_model_dir):
   # warnings.
   # pylint: disable=protected-access
   collections_blacklist = [
-      tft_api._TF_METADATA_TENSORS_COLLECTION,
-      tft_api._TF_METADATA_SCHEMA_OVERRIDES_COLLECTION,
       tft_api.FUNCTION_APPLICATION_COLLECTION,
       tft_analyzers.ANALYZER_COLLECTION
   ]
@@ -629,112 +627,60 @@ def expand(self, inputs):
     return result
 
 
-class _ComputeDeferredMetadata(beam.PTransform):
-  """Extracts values of tensors from a transform function.
+def _augment_metadata(saved_model_dir, metadata):
+  """Augments the metadata with min/max values stored in the SavedModel.
 
-  This transform takes the path to a SavedModel in its constructor, and in its
-  expand() method accepts a mapping from tensors to PCollections. When run, it
-  replaces the tensors corresponding to the keys of this mapping, with the
-  values wrapped in the PCollections. It then extracts the values of some
-  tensors in the new graph. This allows us to compute values that depend on
-  values in the tensor-PCollection mapping in arbitrary ways, where the values
-  are represented by tensors in the graph that depend on the tensor-PCollection
-  mapping (but not on the inputs to the graph).
+  Takes the min/max values of tensors stored in the SavedModel, and uses these
+  to augment the metadata. For each feature in the metadata, the min/max of
+  the corresponding `Tensor` are used to augment the schema. For a feature
+  represented by a `SparseTensor` we use the min/max for the `values` field of
+  the `SparseTensor`.
 
   Args:
+    saved_model_dir: Location of a SavedModel
     metadata: A `DatasetMetadata`
-    column_schema_overrides: A dict from column names to `api._SchemaOverride`s
-    saved_model_dir: The model to extract the constants from.
-    pipeline: The beam Pipeline.
-  """
-
-  def __init__(self, metadata, column_schema_overrides, saved_model_dir,
-               pipeline):
-    self._metadata = metadata
-    self._column_schema_overrides = column_schema_overrides
-    self._saved_model_dir = saved_model_dir
-    # Generally the pipeline is inferred from its inputs, however we need
-    # to know the pipeline for beam.Create.
-    self.pipeline = pipeline
 
-  def expand(self, tensor_pcoll_mapping):
-    """Converts a dict of statistics to a transform function.
-
-    Args:
-      tensor_pcoll_mapping: A dictionary mapping `Tensor`s to a singleton
-        PCollection containing a _TensorValue.
-
-    Returns:
-      A dict from tensor names to singleton `PCollection`s.
-    """
-    # Convert tensor_value_mapping into a DictPCollectionView so it can be
-    # passed as a side input to the beam Map below.
-    tensor_value_pairs = []
-    for name, pcoll in six.iteritems(tensor_pcoll_mapping):
-      tensor_value_pairs.append(
-          pcoll
-          | 'AddName[%s]' % name >> beam.Map(lambda x, name=name: (name, x)))
-    tensor_value_mapping = beam.pvalue.AsDict(
-        tensor_value_pairs
-        | 'MergeTensorValuePairs' >> beam.Flatten(pipeline=self.pipeline))
-
-    def compute_deferred_metadata(metadata, column_schema_overrides,
-                                  saved_model_dir, tensor_value_mapping):
-      """Extracts constant values from graph."""
-      tensor_names = {
-          tensor_name
-          for override in six.itervalues(column_schema_overrides)
-          for tensor_name in [override.min_value, override.max_value]}
-
-      graph = tf.Graph()
-      with graph.as_default():
-        tensor_replacement_map = {}
-        for orig_tensor_name, (value,
-                               is_asset) in six.iteritems(tensor_value_mapping):
-          new_tensor = tf.constant(value)
-          if is_asset:
-            # Any newly frozen constant tensors containing filenames must be
-            # added to the ASSET_FILENAMES collection.
-            graph.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, new_tensor)
-          tensor_replacement_map[orig_tensor_name] = new_tensor
-
-        with tf.Session(graph=graph) as session:
-          tensors_by_name = (
-              saved_transform_io.fetch_tensor_values(
-                  saved_model_dir, tensor_replacement_map, tensor_names))
-          session.run(tf.global_variables_initializer())
-          session.run(tf.tables_initializer())
-          tensor_values_by_name = session.run(tensors_by_name)
-
-      new_column_schemas = {}
-      for key, column_schema in six.iteritems(metadata.schema.column_schemas):
-        if key in column_schema_overrides:
-          override = column_schema_overrides[key]
-          min_value = tensor_values_by_name[override.min_value]
-          max_value = tensor_values_by_name[override.max_value]
-          assert column_schema.domain.dtype == tf.int64
-          assert isinstance(column_schema.domain, dataset_schema.IntDomain)
-          # Create a new column schema. An override always results in a
-          # categorical column.
-          new_column_schemas[key] = dataset_schema.ColumnSchema(
-              dataset_schema.IntDomain(tf.int64, min_value, max_value,
-                                       is_categorical=True),
-              column_schema.axes,
-              column_schema.representation)
-        else:
-          new_column_schemas[key] = column_schema
-
-      return dataset_metadata.DatasetMetadata(dataset_schema.Schema(
-          new_column_schemas))
-
-    return (
-        self.pipeline
-        | 'CreateMetadata' >> beam.Create([self._metadata])
-        | 'ExtractScalarConstants' >> beam.Map(
-            compute_deferred_metadata,
-            column_schema_overrides=self._column_schema_overrides,
-            saved_model_dir=self._saved_model_dir,
-            tensor_value_mapping=tensor_value_mapping))
+  Returns:
+    An augmented DatasetMetadata. The original DatasetMetadata is unchanged.
+  """
+  with tf.Graph().as_default() as graph:
+    with tf.Session(graph=graph) as session:
+      _, output_tensor_by_name = (
+          saved_transform_io.partially_apply_saved_transform_internal(
+              saved_model_dir, {}))
+
+      # Get overrides for the min/max of tensors from the graph, and use these
+      # determine overrides for the min/max of the outputs of the graph.
+      tensor_schema_overrides = tft_api.get_tensor_schema_overrides()
+      column_schema_overrides = {}
+      for name, tensor in six.iteritems(output_tensor_by_name):
+        if isinstance(tensor, tf.SparseTensor):
+          tensor = tensor.values
+        if tensor in tensor_schema_overrides:
+          column_schema_overrides[name] = tensor_schema_overrides[tensor]
+
+      session.run(tf.global_variables_initializer())
+      session.run(tf.tables_initializer())
+      column_schema_override_values = session.run(column_schema_overrides)
+
+  new_column_schemas = {}
+  for key, column_schema in six.iteritems(metadata.schema.column_schemas):
+    if key in column_schema_override_values:
+      min_value, max_value = column_schema_override_values[key]
+      assert column_schema.domain.dtype == tf.int64
+      assert isinstance(column_schema.domain, dataset_schema.IntDomain)
+      # Create a new column schema. An override always results in a
+      # categorical column.
+      new_column_schemas[key] = dataset_schema.ColumnSchema(
+          dataset_schema.IntDomain(tf.int64, min_value, max_value,
+                                   is_categorical=True),
+          column_schema.axes,
+          column_schema.representation)
+    else:
+      new_column_schemas[key] = column_schema
+
+  return dataset_metadata.DatasetMetadata(dataset_schema.Schema(
+      new_column_schemas))
 
 
 class AnalyzeDataset(beam.PTransform):
@@ -860,23 +806,15 @@ def expand(self, dataset):
     # refer to values of tensors in the graph. The override tensors must
     # be "constant" in that they don't depend on input data. The tensors can
     # depend on analyzer outputs though. This allows us to set metadata that
-    # depends on analyzer outputs. _ComputeDeferredMetadata will use
-    # tensor_pcoll_mapping to compute the metadata in a deferred manner, once
-    # the analyzer outputs are known.
+    # depends on analyzer outputs. _augment_metadata will use the analyzer
+    # outputs stored in `transform_fn` to compute the metadata in a
+    # deferred manner, once the analyzer outputs are known.
     metadata = dataset_metadata.DatasetMetadata(
         schema=impl_helper.infer_feature_schema(outputs))
 
-    tensor_schema_overrides = tft_api.get_tensor_schema_overrides()
-    column_schema_overrides = {
-        key: tensor_schema_overrides[tensor]
-        for key, tensor in six.iteritems(outputs)
-        if tensor in tensor_schema_overrides}
-
     deferred_metadata = (
-        tensor_pcoll_mapping
-        | 'ComputeDeferredMetadata' >>
-        _ComputeDeferredMetadata(metadata, column_schema_overrides,
-                                 saved_model_dir, input_values.pipeline))
+        transform_fn
+        | 'ComputeDeferredMetadata' >> beam.Map(_augment_metadata, metadata))
 
     full_metadata = beam_metadata_io.BeamDatasetMetadata(
         metadata, deferred_metadata)
```
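The schema rewrite at the heart of _augment_metadata can be shown on a single column: an int64 column whose output tensor has a registered override becomes a categorical column with a concrete [min, max] IntDomain. A sketch under the tf.Transform 0.x dataset_schema API used in this diff (the column and the override values are hypothetical):

```python
import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_schema

# A column as infer_feature_schema might produce it: int64, no known range.
column_schema = dataset_schema.ColumnSchema(
    tf.int64, [], dataset_schema.FixedColumnRepresentation())

# Override values as extracted by session.run in _augment_metadata.
min_value, max_value = 0, 9

# The same rewrite the loop applies: an override always results in a
# categorical column with an explicit IntDomain.
new_column_schema = dataset_schema.ColumnSchema(
    dataset_schema.IntDomain(tf.int64, min_value, max_value,
                             is_categorical=True),
    column_schema.axes,
    column_schema.representation)
```

On the pipeline side, the bespoke _ComputeDeferredMetadata PTransform and its side-input plumbing collapse into a single beam.Map over the transform_fn PCollection, since the SavedModel now carries everything the metadata computation needs.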

tensorflow_transform/beam/impl_test.py

Lines changed: 47 additions & 10 deletions

```diff
@@ -1282,7 +1282,7 @@ def analyzer_fn(inputs):
     self.assertAnalyzerOutputs(
         input_data, input_metadata, analyzer_fn, expected_outputs)
 
-  def testNumericMeanWithSparseTensor(self):
+  def testNumericMeanWithSparseTensorReduceTrue(self):
 
     def analyzer_fn(inputs):
       return {'mean': tft.mean(inputs['a'])}
@@ -1295,6 +1295,52 @@ def analyzer_fn(inputs):
     self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
                                expected_outputs)
 
+  def testNumericMeanWithSparseTensorReduceFalse(self):
+
+    def analyzer_fn(inputs):
+      return {'mean': tft.mean(inputs['sparse'], False)}
+
+    input_data = [{
+        'sparse': ([0, 1], [0., 1.])
+    }, {
+        'sparse': ([1, 3], [2., 3.])
+    }]
+    input_metadata = dataset_metadata.DatasetMetadata({
+        'sparse':
+            sch.ColumnSchema(
+                tf.float32, [4],
+                sch.SparseColumnRepresentation(
+                    'val', [sch.SparseIndexField('idx', False)]))
+    })
+    expected_outputs = {
+        'mean': np.array([0., 1.5, float('nan'), 3.], np.float32)
+    }
+    self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
+                               expected_outputs)
+
+  def testNumericMeanWithSparseTensorReduceFalseOverflow(self):
+
+    def analyzer_fn(inputs):
+      return {'mean': tft.mean(inputs['sparse'], False)}
+
+    input_data = [{
+        'sparse': ([0, 1], [1, 1])
+    }, {
+        'sparse': ([1, 3], [2147483647, 3])
+    }]
+    input_metadata = dataset_metadata.DatasetMetadata({
+        'sparse':
+            sch.ColumnSchema(
+                tf.int32, [4],
+                sch.SparseColumnRepresentation(
+                    'val', [sch.SparseIndexField('idx', False)]))
+    })
+    expected_outputs = {
+        'mean': np.array([1., 1073741824., float('nan'), 3.], np.float32)
+    }
+    self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn,
+                               expected_outputs)
+
   def testNumericAnalyzersWithSparseInputs(self):
     def repeat(in_tensor, value):
       batch_size = tf.shape(in_tensor)[0]
@@ -1327,15 +1373,6 @@ def size_fn(inputs):
         return {'size': repeat(inputs['a'], tft.size(inputs['a']))}
       _ = input_dataset | beam_impl.AnalyzeDataset(size_fn)
 
-    with self.assertRaises(TypeError):
-      def mean_fn(inputs):
-        return {
-            'mean':
-                repeat(inputs['a'],
-                       tft.mean(inputs['a'], reduce_instance_dims=False))
-        }
-      _ = input_dataset | beam_impl.AnalyzeDataset(mean_fn)
-
     with self.assertRaises(TypeError):
       def var_fn(inputs):
         return {'var': repeat(inputs['a'], tft.var(inputs['a']))}
```
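The ReduceFalseOverflow test pins down why sums are accumulated in the wider sum_dtype chosen by _sum_combine_fn_and_dtype rather than the input's own int32: position 1 receives 1 + 2147483647, which wraps in int32 but is exact in a 64-bit accumulator, and the expected output 1073741824 confirms no wrap-around. A quick check in plain NumPy:

```python
import numpy as np

vals = np.array([1, 2147483647], dtype=np.int32)  # values at position 1

# Accumulating in the input dtype wraps around: int32 cannot hold 2**31.
print(vals.sum(dtype=np.int32))       # -2147483648

# Accumulating in int64 is exact and gives the mean the test expects.
print(vals.sum(dtype=np.int64) / 2)   # 1073741824.0
```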
