@@ -112,9 +112,8 @@ def _make_placeholder(tensor_spec):
112
112
return tf .compat .v1 .sparse_placeholder (
113
113
shape = tensor_spec .shape , dtype = tensor_spec .dtype )
114
114
if isinstance (tensor_spec , tf .RaggedTensorSpec ):
115
- # TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
116
115
return tf .compat .v1 .ragged .placeholder (
117
- tensor_spec ._dtype , tensor_spec ._ragged_rank , value_shape = ()) # pylint: disable=protected-access
116
+ tensor_spec .dtype , tensor_spec .ragged_rank , value_shape = ())
118
117
else :
119
118
return tf .compat .v1 .placeholder (
120
119
shape = tensor_spec .shape , dtype = tensor_spec .dtype )
@@ -164,8 +163,7 @@ def _wrap_as_constant(value, tensor_spec):
164
163
values = tf .constant (value .values , dtype = tensor_spec .dtype ),
165
164
dense_shape = tf .constant (value .dense_shape , dtype = tf .int64 ))
166
165
elif isinstance (tensor_spec , tf .RaggedTensorSpec ):
167
- # TODO(b/160294509): Switch to public APIs once TF 1 support is dropped.
168
- result = _ragged_value_as_constant (value , tensor_spec ._dtype ) # pylint: disable=protected-access
166
+ result = _ragged_value_as_constant (value , tensor_spec .dtype )
169
167
else :
170
168
result = tf .constant (value , dtype = tensor_spec .dtype )
171
169
result .shape .assert_is_compatible_with (tensor_spec .shape )
@@ -299,10 +297,10 @@ def _assertValuesCloseOrEqual(self, a_value, b_value, msg=None):
299
297
if (isinstance (a_value , (bytes , str )) or isinstance (a_value , list ) and
300
298
a_value and isinstance (a_value [0 ], (bytes , str )) or
301
299
isinstance (a_value , np .ndarray ) and a_value .dtype == object ):
302
- self .assertAllEqual (a_value , b_value )
300
+ self .assertAllEqual (a_value , b_value , msg = msg )
303
301
else :
304
302
# TODO(varshaan): Change atol only for tests for which 1e-6 is too strict.
305
- self .assertAllClose (a_value , b_value , atol = 1e-5 )
303
+ self .assertAllClose (a_value , b_value , atol = 1e-5 , msg = msg )
306
304
307
305
def AssertVocabularyContents (self , vocab_file_path , file_contents ):
308
306
if vocab_file_path .endswith ('.tfrecord.gz' ):
0 commit comments