
tf.keras.losses on RaggedTensors crash during gradient computation on a GPU #638

Open

Description

@foxik

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Debian Stable
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version: TF 2.8
  • Python version: 3.7

Describe the problem.

When a loss (tf.losses.SparseCategoricalCrossentropy, tf.losses.CategoricalCrossentropy, tf.losses.BinaryCrossentropy, or tf.losses.MeanSquaredError) is used on ragged tensors, the gradient computation on a GPU crashes with

Node: 'Adam/gradients/zeros_like_2'
2 root error(s) found.
  (0) INTERNAL:  No unary variant unary_op function found for op ZEROS_LIKE Variant type_name: RaggedTensorVariant for device type: GPU
	 [[{{node Adam/gradients/zeros_like_2}}]]
	 [[binary_crossentropy/map/while/loop_body_control/_124/_67]]
  (1) INTERNAL:  No unary variant unary_op function found for op ZEROS_LIKE Variant type_name: RaggedTensorVariant for device type: GPU
	 [[{{node Adam/gradients/zeros_like_2}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_16690]

Describe the current behavior.

The code crashes on a GPU. It does not crash on a CPU, and it does not crash when tf.functions are executed eagerly.

Describe the expected behavior.

The code should not crash.

Standalone code to reproduce the issue.

A simple Colab reproducing the error is here: https://colab.research.google.com/drive/1OELAhvpQHhaz3sOYabf4SdBqKlQCjNjs?usp=sharing

Source code / logs.

The problem seems connected to the use of the ragged map here: https://github.com/keras-team/keras/blob/2db5acf3e3c5904b014cb409d3c514bef44f9640/keras/losses.py#L1408 . My guess is that a TensorArray of ragged tensors is created, and the GPU kernel for some operation manipulating it (the ZEROS_LIKE on RaggedTensorVariant from the error above) is missing.
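For illustration, the ragged-map pattern can be sketched roughly like this (a hedged approximation, not the exact Keras internals): the dense loss function is mapped over the rows of the ragged inputs, and stacking the per-row results goes through a variant-dtype TensorArray.

```python
import tensorflow as tf

# Rough sketch of the ragged-map pattern (illustrative, not the exact
# keras/losses.py internals): map the dense loss function over the rows
# of the ragged inputs. Stacking the per-row results uses a variant-dtype
# TensorArray, which is where the GPU ZEROS_LIKE kernel is missing.
y_true = tf.ragged.constant([[0., 1.], [1., 0., 1.]])
y_pred = tf.ragged.constant([[0.2, 0.8], [0.7, 0.1, 0.9]])

per_row = tf.map_fn(
    lambda args: tf.keras.losses.binary_crossentropy(args[0], args[1]),
    (y_true, y_pred),
    fn_output_signature=tf.TensorSpec([], tf.float32),
)
loss = tf.reduce_mean(per_row)  # one scalar per row, then averaged
```

Each row of a two-level ragged tensor is a dense 1-D tensor, so the mapped function itself only ever sees dense inputs; the variant handling happens in the map machinery.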

  • When the ragged map is avoided, for example by using loss=lambda yt, yp: tf.losses.BinaryCrossentropy()(yt.values, yp.values), the problem does not appear and the computation works.
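Wrapped as a named function, the workaround looks like this (a sketch assuming both y_true and y_pred arrive as RaggedTensors with matching row splits, which the ragged map requires anyway):

```python
import tensorflow as tf

# Workaround sketch: bypass the ragged map entirely by comparing the flat
# value tensors. Assumes y_true and y_pred are RaggedTensors with matching
# row splits.
def ragged_bce(y_true, y_pred):
    return tf.keras.losses.BinaryCrossentropy()(y_true.values, y_pred.values)

y_true = tf.ragged.constant([[0., 1., 0.], [1., 0.]])
y_pred = tf.ragged.constant([[0.1, 0.9, 0.2], [0.8, 0.3]])
loss = ragged_bce(y_true, y_pred)  # one scalar over all 5 flat values
```

Passed as model.compile(loss=ragged_bce, ...), this sidesteps the ragged map; note it weights every value equally, rather than averaging per row first as the ragged map does.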

Note that metrics with ragged tensors work fine; they take a different approach and, instead of a ragged map, use flat_values, see https://github.com/keras-team/keras/blob/2db5acf3e3c5904b014cb409d3c514bef44f9640/keras/utils/metrics_utils.py#L800 .
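The flat_values route can be sketched as follows (example values are hypothetical; the point is that only dense tensors ever reach the metric's update logic):

```python
import tensorflow as tf

# Sketch of the metrics-style flat_values approach: flatten both ragged
# tensors to dense value tensors before any stateful computation, so no
# ragged map (and no variant TensorArray) is involved.
y_true = tf.ragged.constant([[0., 1., 0.], [1., 0.]])
y_pred = tf.ragged.constant([[0.1, 0.9, 0.2], [0.8, 0.3]])

m = tf.keras.metrics.BinaryAccuracy()
m.update_state(y_true.flat_values, y_pred.flat_values)
acc = m.result()  # every flat prediction rounds to its label
```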

Possible courses of action

  1. the ragged map might be fixed on the TensorFlow side
  2. we might avoid the ragged map and use .flat_values instead, similarly to what the metrics do

Personally I prefer 2., because it fixes the problem at hand with a relatively simple change.
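A sketch of what option 2 could look like (illustrative names, not the actual Keras internals; it assumes, as the ragged map already does, that y_true and y_pred share row splits):

```python
import tensorflow as tf

# Illustrative sketch of option 2: route ragged inputs through flat_values
# instead of a ragged map. `call_with_flat_values` is a hypothetical helper,
# not an actual Keras function.
def call_with_flat_values(loss_fn, y_true, y_pred):
    if isinstance(y_true, tf.RaggedTensor):
        y_true = y_true.flat_values
    if isinstance(y_pred, tf.RaggedTensor):
        y_pred = y_pred.flat_values
    return loss_fn(y_true, y_pred)

loss = call_with_flat_values(
    tf.keras.losses.MeanSquaredError(),
    tf.ragged.constant([[1., 2.], [3.]]),
    tf.ragged.constant([[1.5, 2.], [3.5]]),
)  # MSE over the 3 flat values: (0.25 + 0 + 0.25) / 3
```

Since only dense tensors reach the loss function, no variant-dtype TensorArray is created and the GPU kernel gap is never hit.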

Labels: duplicate