Description
TensorFlow version
2.18.0
You can obtain the TensorFlow version with:
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"
Describe the problem.
The API documentation for tf.keras.losses.SparseCategoricalCrossentropy says that the reduction parameter can be None, but the implementation does not accept None. It only accepts the string 'none'.
(directed here from tensorflow/tensorflow#89246)
Describe the current behavior.
A ValueError is raised when reduction=None is passed; only the string 'none' is accepted.
Describe the expected behavior.
Passing reduction=None should not raise this error, since the API documentation lists None as an allowed value.
- Do you want to contribute a PR? (yes/no): No
- If yes, please read this page for instructions
- Briefly describe your candidate solution(if contributing):
Standalone code to reproduce the issue.
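import numpy as np
import tensorflow as tf

def test_reduction_none():
    # Test with reduction set to None: per-sample losses are expected.
    y_true = np.array([0, 1, 2])
    y_pred = np.array([[0.9, 0.05, 0.05],
                       [0.05, 0.9, 0.05],
                       [0.05, 0.05, 0.9]])
    # The ValueError is raised on this line, before any loss is computed.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=None)
    loss = loss_fn(y_true, y_pred).numpy()
    # Each sample assigns probability 0.9 to its true class.
    expected_loss = -np.log([0.9, 0.9, 0.9])
    np.testing.assert_almost_equal(loss, expected_loss, decimal=5)

test_reduction_none()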
Source code / logs.
Traceback (most recent call last):
  File "/home/user/projects/api_guided_testgen/out/bug_detect_gpt4o/exec/basic_rag_apidoc/tf/tf.keras.losses.SparseCategoricalCrossentropy.py", line 49, in test_reduction_none
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=None)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 1026, in __init__
    super().__init__(
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 262, in __init__
    super().__init__(reduction=reduction, name=name)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/losses.py", line 93, in __init__
    losses_utils.ReductionV2.validate(reduction)
  File "/home/user/anaconda3/lib/python3.8/site-packages/keras/utils/losses_utils.py", line 88, in validate
    raise ValueError(
ValueError: Invalid Reduction Key: None. Expected keys are "('auto', 'none', 'sum', 'sum_over_batch_size')"
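For illustration, here is a minimal sketch of how a validation step could treat Python None as equivalent to the string 'none'. This is not the Keras source: the function name normalize_reduction and the constant VALID_REDUCTIONS are hypothetical, and the key set is copied from the error message above. It only shows the kind of check that would make the documented behavior and the implementation agree.

```python
# Hypothetical sketch, not the Keras implementation.
# The allowed keys are copied from the ValueError message in the traceback.
VALID_REDUCTIONS = ("auto", "none", "sum", "sum_over_batch_size")


def normalize_reduction(reduction):
    """Return a validated reduction key, mapping Python None to 'none'."""
    if reduction is None:
        # Assumption: None means "no reduction", as the API doc suggests.
        return "none"
    if reduction not in VALID_REDUCTIONS:
        raise ValueError(
            f"Invalid Reduction Key: {reduction}. "
            f'Expected keys are "{VALID_REDUCTIONS}"'
        )
    return reduction


print(normalize_reduction(None))   # none
print(normalize_reduction("sum"))  # sum
```

With a check like this, reduction=None and reduction='none' would take the same path, while unknown keys would still be rejected with the same style of error.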