Skip to content

Commit 13c1d84

Browse files
zoyahavtfx-copybara
authored andcommitted
Workaround in test for an issue with beam 2.40 not allowing a built-in function to be passed to transformations.
This change also adds the TFT label to the raised error when constructing the ptransform fails. PiperOrigin-RevId: 457773563
1 parent 3943431 commit 13c1d84

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

tensorflow_transform/beam/cached_impl_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,10 +1104,13 @@ def test_cached_ptransform_analyzer(self, use_tf_compat_v1):
11041104
class _AnalyzerMakeAccumulators(beam.PTransform):
11051105

11061106
def expand(self, pcoll):
1107+
# TODO(b/237367328): Use sum directly when beam>=2.40 allows it.
1108+
def _sum(x):
1109+
return sum(x)
11071110
input_sum = pcoll | beam.FlatMap(
1108-
sum) | 'ReduceSum' >> beam.CombineGlobally(sum)
1111+
_sum) | 'ReduceSum' >> beam.CombineGlobally(_sum)
11091112
size = pcoll | beam.Map(
1110-
np.size) | 'ReduceCount' >> beam.CombineGlobally(sum)
1113+
np.size) | 'ReduceCount' >> beam.CombineGlobally(_sum)
11111114

11121115
return (pcoll.pipeline
11131116
| beam.Create([None])

tensorflow_transform/beam/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,11 @@ def visit(self, operation, inputs):
190190
tagged_label = operation.label
191191
else:
192192
tagged_label = '{label}[{tag}]'.format(label=operation.label, tag=tag)
193-
outputs = ((inputs or beam.pvalue.PBegin(self._extra_args.pipeline))
194-
| tagged_label >> ptransform(operation, self._extra_args))
193+
try:
194+
outputs = ((inputs or beam.pvalue.PBegin(self._extra_args.pipeline))
195+
| tagged_label >> ptransform(operation, self._extra_args))
196+
except Exception as e:
197+
raise RuntimeError('Failed to apply: {}'.format(tagged_label)) from e
195198

196199
if isinstance(outputs, beam.pvalue.PCollection):
197200
return (outputs,)

0 commit comments

Comments
 (0)