Description
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Debian Stable
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): TF 2.8
- Python version: 3.7
Describe the problem.
When a loss (`tf.losses.SparseCategoricalCrossentropy`, `tf.losses.CategoricalCrossentropy`, `tf.losses.BinaryCrossentropy`, or `tf.losses.MeanSquaredError`) is used on ragged tensors, the gradient computation on a GPU crashes with:
```
Node: 'Adam/gradients/zeros_like_2'
2 root error(s) found.
  (0) INTERNAL: No unary variant unary_op function found for op ZEROS_LIKE Variant type_name: RaggedTensorVariant for device type: GPU
         [[{{node Adam/gradients/zeros_like_2}}]]
         [[binary_crossentropy/map/while/loop_body_control/_124/_67]]
  (1) INTERNAL: No unary variant unary_op function found for op ZEROS_LIKE Variant type_name: RaggedTensorVariant for device type: GPU
         [[{{node Adam/gradients/zeros_like_2}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_16690]
```
Describe the current behavior.
The code crashes on a GPU. It does not crash on a CPU, and it does not crash when `tf.function`s are executed eagerly.
Describe the expected behavior.
The code should not crash.
Standalone code to reproduce the issue.
A simple Colab reproducing the error is here: https://colab.research.google.com/drive/1OELAhvpQHhaz3sOYabf4SdBqKlQCjNjs?usp=sharing
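For convenience, a minimal sketch of the kind of code in the Colab (the layer sizes, vocabulary size, and data below are illustrative, not taken from it). On a GPU, the `fit()` call crashes with the `ZEROS_LIKE` error above; on a CPU it completes:

```python
import tensorflow as tf

# Illustrative model: ragged integer sequences -> per-token sigmoid scores.
# (Shapes and sizes are made up; any of the losses listed above triggers
# the crash on a GPU.)
inputs = tf.keras.layers.Input(shape=[None], dtype=tf.int32, ragged=True)
hidden = tf.keras.layers.Embedding(input_dim=100, output_dim=8)(inputs)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=tf.losses.BinaryCrossentropy())

x = tf.ragged.constant([[1, 2, 3], [4, 5]])
y = tf.ragged.constant([[[0.], [1.], [0.]], [[1.], [0.]]], ragged_rank=1)

# Crashes in the gradient computation on a GPU; runs fine on a CPU.
model.fit(x, y, epochs=1, verbose=0)
```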
Source code / logs.
The problem seems to be connected to the use of the ragged map here: https://github.com/keras-team/keras/blob/2db5acf3e3c5904b014cb409d3c514bef44f9640/keras/losses.py#L1408 . My guess is that a `TensorArray` of ragged tensors is created and some GPU kernel for manipulating it is missing.
- When the ragged map is avoided, for example by using `loss=lambda yt, yp: tf.losses.BinaryCrossentropy()(yt.values, yp.values)`, the problem does not appear and the computation works.
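As a sketch, the same idea outside `compile()` (the tensors below are made up; `.values` suffices here because the tensors have a single ragged dimension):

```python
import tensorflow as tf

# Hypothetical workaround sketch: compute the loss on the underlying flat
# values instead of the ragged tensors themselves, bypassing the ragged map.
def flat_bce(y_true, y_pred):
    # .values strips one ragged dimension; for ragged_rank=1 tensors this
    # yields plain dense tensors, for which GPU gradient kernels exist.
    return tf.losses.BinaryCrossentropy()(y_true.values, y_pred.values)

y_true = tf.ragged.constant([[0., 1.], [1.]])
y_pred = tf.ragged.constant([[0.1, 0.9], [0.8]])
loss = flat_bce(y_true, y_pred)  # scalar, ~0.145 for these values
```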
Note that metrics with ragged tensors work fine; but they take a different approach, and instead of a ragged map they use `flat_values`, see https://github.com/keras-team/keras/blob/2db5acf3e3c5904b014cb409d3c514bef44f9640/keras/utils/metrics_utils.py#L800 .
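For reference, the difference between the two accessors on a made-up nested ragged tensor: `.values` removes one ragged dimension, while `.flat_values` removes all of them and returns a dense tensor.

```python
import tensorflow as tf

rt = tf.ragged.constant([[[1, 2], [3]], [[4]]])  # ragged_rank=2

# .values peels off the outermost ragged dimension -> still ragged.
print(rt.values)       # [[1, 2], [3], [4]]

# .flat_values flattens all ragged dimensions -> a dense tf.Tensor.
print(rt.flat_values)  # [1 2 3 4]
```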
Possible courses of action
1. the ragged map might be fixed on the TensorFlow side
2. we might avoid the ragged map and use `.flat_values` instead, similarly to what the metrics do

Personally I prefer option 2, because it fixes the problem at hand with a "simple" solution.