Skip to content

Commit 7ed84ad

Browse files
Add ragged tensor input support to Keras via the 'ragged' arg in keras.Input.
PiperOrigin-RevId: 254256357
1 parent 4ab0dc8 commit 7ed84ad

16 files changed

+406
-329
lines changed

tensorflow/python/keras/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -1168,6 +1168,7 @@ tf_py_test(
11681168
"//tensorflow/python:sparse_ops",
11691169
"//tensorflow/python:sparse_tensor",
11701170
],
1171+
shard_count = 4,
11711172
)
11721173

11731174
tf_py_test(

tensorflow/python/keras/backend.py

+39-68
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from tensorflow.python.distribute import distribution_strategy_context
3939
from tensorflow.python.distribute import multi_worker_util
4040
from tensorflow.python.eager import context
41-
from tensorflow.python.framework import composite_tensor_utils
4241
from tensorflow.python.eager import function as eager_function
4342
from tensorflow.python.eager import lift_to_graph
4443
from tensorflow.python.framework import composite_tensor
@@ -70,6 +69,7 @@
7069
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
7170
from tensorflow.python.ops import tensor_array_ops
7271
from tensorflow.python.ops import variables as variables_module
72+
from tensorflow.python.ops.ragged import ragged_factory_ops
7373
from tensorflow.python.platform import tf_logging as logging
7474
from tensorflow.python.util import nest
7575
from tensorflow.python.util import tf_contextlib
@@ -958,7 +958,12 @@ def is_keras_tensor(x):
958958

959959

960960
@keras_export('keras.backend.placeholder')
961-
def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
961+
def placeholder(shape=None,
962+
ndim=None,
963+
dtype=None,
964+
sparse=False,
965+
name=None,
966+
ragged=False):
962967
"""Instantiates a placeholder tensor and returns it.
963968
964969
Arguments:
@@ -970,9 +975,14 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
970975
dtype: Placeholder type.
971976
sparse: Boolean, whether the placeholder should have a sparse type.
972977
name: Optional name string for the placeholder.
978+
ragged: Boolean, whether the placeholder should have a ragged type.
979+
In this case, values of 'None' in the 'shape' argument represent
980+
ragged dimensions. For more information about RaggedTensors, see this
981+
[guide](https://www.tensorflow.org/guide/ragged_tensors).
973982
974983
Raises:
975-
ValueError: If called with eager execution.
984+
ValueError: If called with eager execution
985+
ValueError: If called with sparse = True and ragged = True.
976986
977987
Returns:
978988
Tensor instance (with Keras metadata included).
@@ -985,6 +995,11 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
985995
<tf.Tensor 'Placeholder_4:0' shape=(2, 4, 5) dtype=float32>
986996
```
987997
"""
998+
if sparse and ragged:
999+
raise ValueError(
1000+
'Cannot set both sparse and ragged to True when creating a placeholder.'
1001+
)
1002+
9881003
if dtype is None:
9891004
dtype = floatx()
9901005
if not shape:
@@ -993,6 +1008,20 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
9931008
with get_graph().as_default():
9941009
if sparse:
9951010
x = array_ops.sparse_placeholder(dtype, shape=shape, name=name)
1011+
elif ragged:
1012+
ragged_rank = 0
1013+
for i in range(1, len(shape)):
1014+
if shape[i] is None:
1015+
ragged_rank += 1
1016+
else:
1017+
break
1018+
value_shape = shape[(ragged_rank + 1):]
1019+
1020+
x = ragged_factory_ops.placeholder(
1021+
dtype=dtype,
1022+
ragged_rank=ragged_rank,
1023+
value_shape=value_shape,
1024+
name=name)
9961025
else:
9971026
x = array_ops.placeholder(dtype, shape=shape, name=name)
9981027
return x
@@ -1008,7 +1037,11 @@ def is_placeholder(x):
10081037
Boolean.
10091038
"""
10101039
try:
1011-
return x.op.type == 'Placeholder'
1040+
if isinstance(x, composite_tensor.CompositeTensor):
1041+
flat_components = nest.flatten(x, expand_composites=True)
1042+
return py_any(is_placeholder(c) for c in flat_components)
1043+
else:
1044+
return x.op.type == 'Placeholder'
10121045
except AttributeError:
10131046
return False
10141047

@@ -3108,63 +3141,6 @@ def print_tensor(x, message=''):
31083141
logging_ops.print_v2(message, x, output_stream=sys.stdout)
31093142
return x
31103143

3111-
3112-
def is_tensor_or_composite_tensor(value):
3113-
"""Test if a passed value object is a tensor-like or composite tensor."""
3114-
return (tensor_util.is_tensor(value) or isinstance(value, np.ndarray) or
3115-
composite_tensor_utils.is_composite_or_composite_value(value))
3116-
3117-
3118-
def _try_process_scipy_sparse_input(value):
3119-
"""Converts 'value' to a SparseTensor if it is a scipy sparse matrix.
3120-
3121-
Arguments:
3122-
value: An object that may have the attributes of a scipy sparse matrix.
3123-
3124-
Returns:
3125-
Either a SparseTensor based off of 'value' or 'value' itself.
3126-
"""
3127-
try:
3128-
sparse_coo = value.tocoo()
3129-
row, col = sparse_coo.row, sparse_coo.col
3130-
data, shape = sparse_coo.data, sparse_coo.shape
3131-
except AttributeError:
3132-
# If we can't convert this object, it could be either a single data
3133-
# element (ie, a bool/int/float) which is OK to pass on, or something
3134-
# that we don't understand (which may or may not be OK). In either
3135-
# case, don't die here: the data standardization code will catch
3136-
# those issues.
3137-
return value
3138-
3139-
indices = np.concatenate((np.expand_dims(row, 1), np.expand_dims(col, 1)), 1)
3140-
return sparse_tensor.SparseTensor(indices, data, shape)
3141-
3142-
3143-
def try_convert_scipy_to_sparse(values):
3144-
"""Converts scipy sparse matrices in 'values' to SparseTensors, if possible.
3145-
3146-
Arguments:
3147-
values: An input or list of inputs to convert. These may be TensorLikes,
3148-
ndarrays, composite tensors, or scipy sparse values.
3149-
3150-
Returns:
3151-
An input or list of inputs where scipy sparse tensors have been converted
3152-
to tf.SparseTensors.
3153-
3154-
Raises:
3155-
ValueError: If input cannot be converted to a SparseTensor.
3156-
"""
3157-
# Convert scipy sparse data into sparse tensors.
3158-
value_structure = values
3159-
values = nest.flatten(values)
3160-
for idx, value in enumerate(values):
3161-
if not is_tensor_or_composite_tensor(value):
3162-
values[idx] = _try_process_scipy_sparse_input(value)
3163-
values = nest.pack_sequence_as(value_structure, values)
3164-
3165-
return values
3166-
3167-
31683144
# GRAPH MANIPULATION
31693145

31703146

@@ -3194,6 +3170,7 @@ def __init__(self, inputs, outputs, updates=None, name=None,
31943170
if not isinstance(updates, (list, tuple)):
31953171
raise TypeError('`updates` in a Keras backend function '
31963172
'should be a list or tuple.')
3173+
31973174
self._inputs_structure = inputs
31983175
self.inputs = nest.flatten(inputs, expand_composites=True)
31993176
self._outputs_structure = outputs
@@ -3311,10 +3288,6 @@ def _eval_if_composite(self, tensor):
33113288
return tensor
33123289

33133290
def __call__(self, inputs):
3314-
inputs = try_convert_scipy_to_sparse(inputs)
3315-
3316-
# Ensure that input value types match any expected composite tensor types.
3317-
# TODO(momernick): Once TensorSpecs are implemented for CTs, use that here.
33183291
inputs = nest.flatten(inputs, expand_composites=True)
33193292

33203293
session = get_session(inputs)
@@ -3488,10 +3461,8 @@ def __init__(self, inputs, outputs, updates=None, name=None):
34883461
x.op.inputs[0])
34893462

34903463
def __call__(self, inputs):
3491-
# Convert scipy sparse data into sparse tensors.
3492-
inputs = try_convert_scipy_to_sparse(inputs)
3493-
34943464
input_values = nest.flatten(inputs, expand_composites=True)
3465+
34953466
if self._freezable_vars_values:
34963467
input_values = input_values + self._freezable_vars_values
34973468
converted_inputs = []

tensorflow/python/keras/engine/input_layer.py

+33-18
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,27 @@ class InputLayer(base_layer.Layer):
3434
"""Layer to be used as an entry point into a Network (a graph of layers).
3535
3636
It can either wrap an existing tensor (pass an `input_tensor` argument)
37-
or create its a placeholder tensor (pass arguments `input_shape`, and
37+
or create a placeholder tensor (pass arguments `input_shape`, and
3838
optionally, `dtype`).
3939
4040
It is generally recommend to use the functional layer API via `Input`,
4141
(which creates an `InputLayer`) without directly using `InputLayer`.
4242
43+
This class can create placeholders for tf.Tensors, tf.SparseTensors, and
44+
tf.RaggedTensors by choosing 'sparse=True' or 'ragged=True'.
45+
4346
Arguments:
4447
input_shape: Shape tuple (not including the batch axis), or `TensorShape`
4548
instance (not including the batch axis).
4649
batch_size: Optional input batch size (integer or None).
4750
dtype: Datatype of the input.
4851
input_tensor: Optional tensor to use as layer input
4952
instead of creating a placeholder.
50-
sparse: Boolean, whether the placeholder created
51-
is meant to be sparse.
53+
sparse: Boolean, whether the placeholder created is meant to be sparse.
54+
ragged: Boolean, whether the placeholder created is meant to be ragged.
55+
In this case, values of 'None' in the 'shape' argument represent
56+
ragged dimensions. For more information about RaggedTensors, see
57+
https://www.tensorflow.org/guide/ragged_tensors.
5258
name: Name of the layer (string).
5359
"""
5460

@@ -59,6 +65,7 @@ def __init__(self,
5965
input_tensor=None,
6066
sparse=False,
6167
name=None,
68+
ragged=False,
6269
**kwargs):
6370
strategy = distribution_strategy_context.get_strategy()
6471
if strategy and batch_size is not None and \
@@ -110,18 +117,12 @@ def __init__(self,
110117
batch_input_shape = None
111118
graph = backend.get_graph()
112119
with graph.as_default():
113-
# In graph mode, create a graph placeholder to call the layer on.
114-
if sparse:
115-
input_tensor = backend.placeholder(
116-
shape=batch_input_shape,
117-
dtype=dtype,
118-
name=self.name,
119-
sparse=True)
120-
else:
121-
input_tensor = backend.placeholder(
122-
shape=batch_input_shape,
123-
dtype=dtype,
124-
name=self.name)
120+
input_tensor = backend.placeholder(
121+
shape=batch_input_shape,
122+
dtype=dtype,
123+
name=self.name,
124+
sparse=sparse,
125+
ragged=ragged)
125126

126127
self.is_placeholder = True
127128
self._batch_input_shape = batch_input_shape
@@ -164,6 +165,7 @@ def Input( # pylint: disable=invalid-name
164165
dtype=None,
165166
sparse=False,
166167
tensor=None,
168+
ragged=False,
167169
**kwargs):
168170
"""`Input()` is used to instantiate a Keras tensor.
169171
@@ -184,17 +186,24 @@ def Input( # pylint: disable=invalid-name
184186
Arguments:
185187
shape: A shape tuple (integers), not including the batch size.
186188
For instance, `shape=(32,)` indicates that the expected input
187-
will be batches of 32-dimensional vectors.
189+
will be batches of 32-dimensional vectors. Elements of this tuple
190+
can be None; 'None' elements represent dimensions where the shape is
191+
not known.
188192
batch_size: optional static batch size (integer).
189193
name: An optional name string for the layer.
190194
Should be unique in a model (do not reuse the same name twice).
191195
It will be autogenerated if it isn't provided.
192196
dtype: The data type expected by the input, as a string
193197
(`float32`, `float64`, `int32`...)
194-
sparse: A boolean specifying whether the placeholder
195-
to be created is sparse.
198+
sparse: A boolean specifying whether the placeholder to be created is
199+
sparse. Only one of 'ragged' and 'sparse' can be True.
196200
tensor: Optional existing tensor to wrap into the `Input` layer.
197201
If set, the layer will not create a placeholder tensor.
202+
ragged: A boolean specifying whether the placeholder to be created is
203+
ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
204+
values of 'None' in the 'shape' argument represent ragged dimensions.
205+
For more information about RaggedTensors, see
206+
https://www.tensorflow.org/guide/ragged_tensors.
198207
**kwargs: deprecated arguments support.
199208
200209
Returns:
@@ -222,6 +231,10 @@ def Input( # pylint: disable=invalid-name
222231
Raises:
223232
ValueError: in case of invalid arguments.
224233
"""
234+
if sparse and ragged:
235+
raise ValueError(
236+
'Cannot set both sparse and ragged to True in a Keras input.')
237+
225238
batch_shape = None
226239
if 'batch_shape' in kwargs:
227240
batch_shape = kwargs.pop('batch_shape')
@@ -246,6 +259,7 @@ def Input( # pylint: disable=invalid-name
246259
name=name,
247260
dtype=dtype,
248261
sparse=sparse,
262+
ragged=ragged,
249263
input_tensor=tensor)
250264
else:
251265
input_layer = InputLayer(
@@ -254,6 +268,7 @@ def Input( # pylint: disable=invalid-name
254268
name=name,
255269
dtype=dtype,
256270
sparse=sparse,
271+
ragged=ragged,
257272
input_tensor=tensor)
258273

259274
# Return tensor including `_keras_history`.

0 commit comments

Comments
 (0)