Skip to content

Commit

Permalink
Fix HardShrink activation (#505)
Browse files Browse the repository at this point in the history
Build a graph of operations in the apply function instead of incorrectly trying to perform these operations.
  • Loading branch information
juliabeliaeva authored Jan 2, 2023
1 parent b528133 commit d4987ea
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,12 @@ public class MishActivation : Activation {
}

/**
* ```
* hardshrink(x) = x if x < lower
* x if x > upper
* 0 otherwise
* ```
*
* @property [lower] lower bound for setting values to zeros
* @property [upper] upper bound for setting values to zeros
*
Expand All @@ -498,12 +504,9 @@ public class HardShrinkActivation(public val lower: Float = -0.5f, public val up
require(lower < upper) {
"The value of lower should not be higher than upper"
}
val maskLower = tf.math.minimum(features, tf.constant(lower)) != tf.constant(lower)
val maskUpper = tf.math.maximum(features, tf.constant(upper)) != tf.constant(upper)
return when (maskLower || maskUpper) {
false -> tf.constant(0) as Operand<Float>
true -> features
}
val maskLower = tf.math.less(features, tf.constant(lower))
val maskUpper = tf.math.greater(features, tf.constant(upper))
return tf.where3(tf.math.logicalOr(maskLower, maskUpper), features, tf.zerosLike(features))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,23 @@ import org.junit.jupiter.api.Test

class HardShrinkActivationTest : ActivationTest() {
@Test
fun apply() {
val input = floatArrayOf(1f, 0f, 1f)
val expected = floatArrayOf(1f, 0f, 1f)
val actual = floatArrayOf(0f, 0f, 0f)
assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, actual, expected)
fun mask() {
val input = floatArrayOf(1f, 0f, -1f)
val expected = floatArrayOf(1f, 0f, -1f)
assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, expected)
}

@Test
fun setToZero() {
val input = floatArrayOf(0.239f, 0.01f, -0.239f)
val expected = floatArrayOf(0f, 0f, 0f)
assertActivationFunction(HardShrinkActivation(-0.5f, 0.5f), input, expected)
}

@Test
fun mixed() {
val input = floatArrayOf(0.239f, -5f, -10f, 0.3f, -0.5f, 239f, 0.7f, -100f, -0.4f)
val expected = floatArrayOf(0f, -5f, -10f, 0f, -0.5f, 239f, 0f, -100f, -0f)
assertActivationFunction(HardShrinkActivation(-0.4f, 0.7f), input, expected)
}
}

0 comments on commit d4987ea

Please sign in to comment.