From 004f5577f4ceb2c7b6c1005cda2974d66a1a3ce4 Mon Sep 17 00:00:00 2001 From: Kameron Rodrigues Date: Mon, 15 Jan 2024 14:17:25 -0800 Subject: [PATCH] =?UTF-8?q?Critical=20typo:=20squeeze=20the=20y=5Fpred=20t?= =?UTF-8?q?ensor=20even=20when=20it=E2=80=99s=20the=20same=20rank=20and=20?= =?UTF-8?q?shape=20as=20the=20y=5Fpred=20tensor=20should=20be=20=3D=3D=20i?= =?UTF-8?q?nstead=20of=20!=3D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Current behavior will squeeze the y_pred tensor even when it’s the same rank and shape as the y_pred tensor. The issue is that the code is written wrong. Read the code comments and what it’s supposed to do (only squeeze when the ranks differ by exactly 1 and not squeeze for situations of equal rank). Yet the code’s if statements will squeeze whenever the ranks are equal, which is wrong from what I can tell. the line: if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1: should be: if (y_pred_rank - y_true_rank == 1) or y_pred_shape[-1] == 1: --- tf_keras/utils/losses_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_keras/utils/losses_utils.py b/tf_keras/utils/losses_utils.py index 49d5e1cc7..1a7d100f2 100644 --- a/tf_keras/utils/losses_utils.py +++ b/tf_keras/utils/losses_utils.py @@ -195,7 +195,7 @@ def squeeze_or_expand_dimensions(y_pred, y_true=None, sample_weight=None): y_true_rank = y_true_shape.ndims if (y_true_rank is not None) and (y_pred_rank is not None): # Use static rank for `y_true` and `y_pred`. - if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1: + if (y_pred_rank - y_true_rank == 1) or y_pred_shape[-1] == 1: y_true, y_pred = remove_squeezable_dimensions(y_true, y_pred) else: # Use dynamic rank.