Skip to content

Commit a9d2911

Browse files
authored
Merge pull request #38 from zoyahav/master
Project import generated by Copybara.
2 parents 51a0c5f + 03f16fb commit a9d2911

File tree

4 files changed

+52
-57
lines changed

4 files changed

+52
-57
lines changed

tensorflow_transform/analyzers.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,28 +49,36 @@ class Analyzer(object):
4949
5050
Args:
5151
inputs: The inputs to the analyzer.
52-
output_tensors_and_is_asset: List of pairs of `(Tensor, bool)` for each
53-
output. The `Tensor`s are typically placeholders; they will be later
54-
be replaced with analysis results. The boolean value states whether this
55-
Tensor represents an asset filename or not.
52+
output_dtype_shape_and_is_asset: List of tuples of `(DType, Shape, bool)`
53+
for each output. A tf.placeholder with the given DType and Shape will be
54+
constructed to represent the output of the analyzer, and this placeholder
55+
will eventually be replaced by the actual value of the analyzer. The
56+
boolean value states whether this Tensor represents an asset filename or
57+
not.
5658
spec: A description of the computation to be done.
59+
name: Similar to a TF op name. Used to define a unique scope for this
60+
analyzer, which can be used for debugging info.
5761
5862
Raises:
5963
ValueError: If the inputs are not all `Tensor`s.
6064
"""
6165

62-
def __init__(self, inputs, output_tensors_and_is_asset, spec):
66+
def __init__(self, inputs, output_dtype_shape_and_is_asset, spec, name):
6367
for tensor in inputs:
6468
if not isinstance(tensor, tf.Tensor):
6569
raise ValueError('Analyzers can only accept `Tensor`s as inputs')
6670
self._inputs = inputs
67-
for output_tensor, is_asset in output_tensors_and_is_asset:
68-
if is_asset and output_tensor.dtype != tf.string:
69-
raise ValueError(('Tensor {} cannot represent an asset, because it is '
70-
'not a string.').format(output_tensor.name))
71-
self._outputs = [output_tensor
72-
for output_tensor, _ in output_tensors_and_is_asset]
73-
self._output_is_asset_map = dict(output_tensors_and_is_asset)
71+
self._outputs = []
72+
self._output_is_asset_map = {}
73+
with tf.name_scope(name) as scope:
74+
self._name = scope
75+
for dtype, shape, is_asset in output_dtype_shape_and_is_asset:
76+
output_tensor = tf.placeholder(dtype, shape)
77+
if is_asset and output_tensor.dtype != tf.string:
78+
raise ValueError(('Tensor {} cannot represent an asset, because it '
79+
'is not a string.').format(output_tensor.name))
80+
self._outputs.append(output_tensor)
81+
self._output_is_asset_map[output_tensor] = is_asset
7482
self._spec = spec
7583
tf.add_to_collection(ANALYZER_COLLECTION, self)
7684

@@ -86,6 +94,10 @@ def outputs(self):
8694
def spec(self):
8795
return self._spec
8896

97+
@property
98+
def name(self):
99+
return self._name
100+
89101
def output_is_asset(self, output_tensor):
90102
return self._output_is_asset_map[output_tensor]
91103

@@ -131,11 +143,9 @@ def _numeric_combine(x, combiner_type, reduce_instance_dims=True):
131143
# If reducing over batch dimensions, with unknown shape, the result will
132144
# also have unknown shape.
133145
shape = None
134-
with tf.name_scope(combiner_type):
135-
spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
136-
return Analyzer([x],
137-
[(tf.placeholder(x.dtype, shape), False)],
138-
spec).outputs[0]
146+
spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
147+
return Analyzer(
148+
[x], [(x.dtype, shape, False)], spec, combiner_type).outputs[0]
139149

140150

141151
def min(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin
@@ -381,9 +391,7 @@ def uniques(x, top_k=None, frequency_threshold=None,
381391

382392
spec = UniquesSpec(tf.string, top_k, frequency_threshold,
383393
vocab_filename, store_frequency)
384-
return Analyzer([x],
385-
[(tf.placeholder(tf.string, []), True)],
386-
spec).outputs[0]
394+
return Analyzer([x], [(tf.string, [], True)], spec, 'uniques').outputs[0]
387395

388396

389397
class QuantilesSpec(object):

tensorflow_transform/beam/impl.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -513,9 +513,8 @@ def __init__(self, analyzers, base_temp_dir):
513513
def expand(self, analyzer_input_values):
514514
# For each analyzer output, look up its input values (by tensor name)
515515
# and run the analyzer on these values.
516-
#
517516
result = {}
518-
for idx, analyzer in enumerate(self._analyzers):
517+
for analyzer in self._analyzers:
519518
temp_assets_dir = _make_unique_temp_dir(self._base_temp_dir)
520519
tf.gfile.MkDir(temp_assets_dir)
521520
analyzer_impl = analyzer_impls._impl_for_analyzer(
@@ -525,10 +524,10 @@ def expand(self, analyzer_input_values):
525524
assert len(analyzer.inputs) == 1
526525
output_pcolls = (
527526
analyzer_input_values
528-
| 'Extract_%d' % idx >> beam.Map(
527+
| 'ExtractInput[%s]' % analyzer.name >> beam.Map(
529528
lambda batch, key: batch[key],
530529
key=analyzer.inputs[0].name)
531-
| 'Analyze_%d' % idx >> analyzer_impl)
530+
| 'Analyze[%s]' % analyzer.name >> analyzer_impl)
532531
assert len(analyzer.outputs) == len(output_pcolls), (
533532
'Analyzer outputs don\'t match the expected outputs from the '
534533
'Analyzer definition: %d != %d' %
@@ -537,7 +536,7 @@ def expand(self, analyzer_input_values):
537536
for collection_idx, (tensor, pcoll) in enumerate(
538537
zip(analyzer.outputs, output_pcolls)):
539538
is_asset = analyzer.output_is_asset(tensor)
540-
pcoll |= ('WrapAsTensorValue_%d_%d' % (idx, collection_idx)
539+
pcoll |= ('WrapAsTensorValue[%s][%d]' % (analyzer.name, collection_idx)
541540
>> beam.Map(_TensorValue, is_asset))
542541
result[tensor.name] = pcoll
543542
return result
@@ -711,15 +710,15 @@ def expand(self, dataset):
711710
graph, inputs, analyzer_inputs, unbound_saved_model_dir)
712711
saved_model_dir = (
713712
tensor_pcoll_mapping
714-
| 'CreateSavedModelForAnaylzerInputs_%d' % level
713+
| 'CreateSavedModelForAnaylzerInputs[%d]' % level
715714
>> _ReplaceTensorsWithConstants(
716715
unbound_saved_model_dir, base_temp_dir, input_values.pipeline))
717716

718717
# Run this saved model on the input dataset to obtain the inputs to the
719718
# analyzers.
720719
analyzer_input_values = (
721720
input_values
722-
| 'ComputeAnalyzerInputs_%d' % level >> beam.ParDo(
721+
| 'ComputeAnalyzerInputs[%d]' % level >> beam.ParDo(
723722
_RunMetaGraphDoFn(
724723
input_schema,
725724
analyzer_inputs_schema,

tensorflow_transform/coders/example_proto_coder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,21 @@ class _FixedLenFeatureHandler(object):
122122
def __init__(self, name, feature_spec):
123123
self._name = name
124124
self._np_dtype = feature_spec.dtype.as_numpy_dtype
125+
self._default_value = feature_spec.default_value
125126
self._value_fn = _make_feature_value_fn(feature_spec.dtype)
126127
self._shape = feature_spec.shape
127128
self._rank = len(feature_spec.shape)
129+
if self._rank > 0 and self._default_value:
130+
raise ValueError('FixedLenFeature %r got default value for rank > 0, '
131+
'only scalar default values are supported'
132+
% (self._name,))
133+
if isinstance(self._default_value, list):
134+
raise ValueError('FixedLenFeature %r got non-scalar default value, '
135+
'only scalar default values are supported' %
136+
(self._name,))
128137
self._size = 1
129138
for dim in feature_spec.shape:
130139
self._size *= dim
131-
self._default_value = feature_spec.default_value
132-
if self._default_value:
133-
if list(np.asarray(self._default_value).shape) != self._shape:
134-
raise ValueError(
135-
'FixedLenFeature %r got default value with incorrect shape' %
136-
(self._name,))
137-
self._default_value = np.asarray(self._default_value).reshape(-1).tolist()
138140

139141
@property
140142
def name(self):
@@ -152,7 +154,7 @@ def parse_value(self, feature_map):
152154
feature = feature_map[self._name]
153155
values = self._value_fn(feature)
154156
elif self._default_value is not None:
155-
values = self._default_value
157+
values = [self._default_value]
156158
else:
157159
values = []
158160

tensorflow_transform/coders/example_proto_coder_test.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -168,16 +168,8 @@ def test_example_proto_coder(self):
168168

169169
def test_example_proto_coder_default_value(self):
170170
input_schema = dataset_schema.from_feature_spec({
171-
'scalar_feature_3':
172-
tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
173-
'1d_vector_feature':
174-
tf.FixedLenFeature(
175-
shape=[1], dtype=tf.float32, default_value=[2.0]),
176-
'2d_vector_feature':
177-
tf.FixedLenFeature(
178-
shape=[2, 2],
179-
dtype=tf.float32,
180-
default_value=[[1.0, 2.0], [3.0, 4.0]])
171+
'scalar_feature_3': tf.FixedLenFeature(shape=[], dtype=tf.float32,
172+
default_value=1.0),
181173
})
182174
coder = example_proto_coder.ExampleProtoCoder(input_schema)
183175

@@ -193,31 +185,25 @@ def test_example_proto_coder_default_value(self):
193185
# Assert the data is decoded into the expected format.
194186
expected_decoded = {
195187
'scalar_feature_3': 1.0,
196-
'1d_vector_feature': [2.0],
197-
'2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]]
198188
}
199189
decoded = coder.decode(data)
200190
np.testing.assert_equal(expected_decoded, decoded)
201191

202192
def test_example_proto_coder_bad_default_value(self):
203193
input_schema = dataset_schema.from_feature_spec({
204-
'1d_vector_feature':
205-
tf.FixedLenFeature(
206-
shape=[2], dtype=tf.float32, default_value=[1.0]),
194+
'scalar_feature_2': tf.FixedLenFeature(shape=[2], dtype=tf.float32,
195+
default_value=[1.0, 2.0]),
207196
})
208197
with self.assertRaisesRegexp(ValueError,
209-
'got default value with incorrect shape'):
198+
'only scalar default values are supported'):
210199
example_proto_coder.ExampleProtoCoder(input_schema)
211200

212201
input_schema = dataset_schema.from_feature_spec({
213-
'2d_vector_feature':
214-
tf.FixedLenFeature(
215-
shape=[2, 3],
216-
dtype=tf.float32,
217-
default_value=[[1.0, 1.0], [1.0]]),
202+
'scalar_feature_2': tf.FixedLenFeature(shape=[], dtype=tf.float32,
203+
default_value=[1.0]),
218204
})
219205
with self.assertRaisesRegexp(ValueError,
220-
'got default value with incorrect shape'):
206+
'only scalar default values are supported'):
221207
example_proto_coder.ExampleProtoCoder(input_schema)
222208

223209
def test_example_proto_coder_picklable(self):

0 commit comments

Comments
 (0)