From d4987ead2af186cf1ab11a18a563da51c97a8cb4 Mon Sep 17 00:00:00 2001 From: Julia Beliaeva Date: Mon, 2 Jan 2023 15:31:39 +0100 Subject: [PATCH] Fix HardShrink activation (#505) Build a graph of operations in the apply function instead of incorrectly trying to perform these operations. --- .../dl/api/core/activation/Activations.kt | 15 +++++++----- .../activation/HardShrinkActivationTest.kt | 23 +++++++++++++++---- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt index 9fdd3ac11..d434a6664 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/Activations.kt @@ -488,6 +488,12 @@ public class MishActivation : Activation { } /** + * ``` + * hardshrink(x) = x if x < lower + * x if x > upper + * 0 otherwise + * ``` + * * @property [lower] lower bound for setting values to zeros * @property [upper] upper bound for setting values to zeros * @@ -498,12 +504,9 @@ public class HardShrinkActivation(public val lower: Float = -0.5f, public val up require(lower < upper) { "The value of lower should not be higher than upper" } - val maskLower = tf.math.minimum(features, tf.constant(lower)) != tf.constant(lower) - val maskUpper = tf.math.maximum(features, tf.constant(upper)) != tf.constant(upper) - return when (maskLower || maskUpper) { - false -> tf.constant(0) as Operand - true -> features - } + val maskLower = tf.math.less(features, tf.constant(lower)) + val maskUpper = tf.math.greater(features, tf.constant(upper)) + return tf.where3(tf.math.logicalOr(maskLower, maskUpper), features, tf.zerosLike(features)) } } diff --git a/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/HardShrinkActivationTest.kt b/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/HardShrinkActivationTest.kt index 75d1140db..27c1a50cb 100644 --- a/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/HardShrinkActivationTest.kt +++ b/tensorflow/src/test/kotlin/org/jetbrains/kotlinx/dl/api/core/activation/HardShrinkActivationTest.kt @@ -4,10 +4,23 @@ import org.junit.jupiter.api.Test class HardShrinkActivationTest : ActivationTest() { @Test - fun apply() { - val input = floatArrayOf(1f, 0f, 1f) - val expected = floatArrayOf(1f, 0f, 1f) - val actual = floatArrayOf(0f, 0f, 0f) - assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, actual, expected) + fun mask() { + val input = floatArrayOf(1f, 0f, -1f) + val expected = floatArrayOf(1f, 0f, -1f) + assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, expected) + } + + @Test + fun setToZero() { + val input = floatArrayOf(0.239f, 0.01f, -0.239f) + val expected = floatArrayOf(0f, 0f, 0f) + assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, expected) + } + + @Test + fun mixed() { + val input = floatArrayOf(0.239f, -5f, -10f, 0.3f, -0.5f, 239f, 0.7f, -100f, -0.4f) + val expected = floatArrayOf(0f, -5f, -10f, 0f, -0.5f, 239f, 0f, -100f, -0f) + assertActivationFunction(HardShrinkActivation(-0.4f, 0.7f), input, expected) } }