Skip to content

Commit a447b4b

Browse files
authored
Throw exception on custom gradient registration on Windows (#487)
1 parent dc25607 commit a447b4b

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import com.google.protobuf.InvalidProtocolBufferException;
2727
import java.util.Collections;
2828
import java.util.IdentityHashMap;
29+
import java.util.Locale;
2930
import java.util.Set;
3031
import java.util.stream.Collectors;
3132
import org.bytedeco.javacpp.PointerPointer;
@@ -193,6 +194,9 @@ private static synchronized boolean hasGradient(String opType) {
193194
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
194195
* of unified gradient support once it is fully supported by tensorflow core.
195196
*
197+
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
198+
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
199+
*
196200
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
197201
* field, i.e. {@link Add#OP_NAME}.
198202
* @param gradient the gradient function to use
@@ -201,6 +205,10 @@ private static synchronized boolean hasGradient(String opType) {
201205
*/
202206
public static synchronized boolean registerCustomGradient(
203207
String opType, RawCustomGradient gradient) {
208+
if (isWindowsOs()) {
209+
throw new UnsupportedOperationException(
210+
"Custom gradient registration is not supported on Windows systems.");
211+
}
204212
if (hasGradient(opType)) {
205213
return false;
206214
}
@@ -216,6 +224,9 @@ public static synchronized boolean registerCustomGradient(
216224
* generated op classes or custom op classes with the correct annotations. To operate on the
217225
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
218226
*
227+
* <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
228+
* href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
229+
*
219230
* @param inputClass the inputs class of op to register the gradient for.
220231
* @param gradient the gradient function to use
221232
* @return {@code true} if the gradient was registered, {@code false} if there was already a
@@ -225,8 +236,11 @@ public static synchronized boolean registerCustomGradient(
225236
*/
226237
public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
227238
Class<T> inputClass, CustomGradient<T> gradient) {
239+
if (isWindowsOs()) {
240+
throw new UnsupportedOperationException(
241+
"Custom gradient registration is not supported on Windows systems.");
242+
}
228243
OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);
229-
230244
if (metadata == null) {
231245
throw new IllegalArgumentException(
232246
"Inputs Class "
@@ -253,4 +267,8 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
253267
gradientFuncs.add(g);
254268
return true;
255269
}
270+
271+
private static boolean isWindowsOs() {
272+
return System.getProperty("os.name", "").toLowerCase(Locale.ENGLISH).startsWith("win");
273+
}
256274
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Arrays;
2222
import org.junit.jupiter.api.Test;
2323
import org.junit.jupiter.api.condition.DisabledOnOs;
24+
import org.junit.jupiter.api.condition.EnabledOnOs;
2425
import org.junit.jupiter.api.condition.OS;
2526
import org.tensorflow.ndarray.index.Indices;
2627
import org.tensorflow.op.Ops;
@@ -44,8 +45,8 @@ public void testAlreadyExisting() {
4445
}));
4546
}
4647

47-
// FIXME: Since TF 2.10.1, this test is failing on Windows, because the whole JVM crashes when
48-
// calling the JavaCPP generated binding `NameMap.erase`. Disable it until we find a fix.
48+
// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
49+
// https://github.com/tensorflow/java/issues/486
4950
@DisabledOnOs(OS.WINDOWS)
5051
@Test
5152
public void testCustomGradient() {
@@ -76,6 +77,26 @@ public void testCustomGradient() {
7677
}
7778
}
7879

80+
@EnabledOnOs(OS.WINDOWS)
81+
@Test
82+
public void testCustomGradientThrowsOnWindows() {
83+
assertThrows(
84+
UnsupportedOperationException.class,
85+
() ->
86+
TensorFlow.registerCustomGradient(
87+
NthElement.OP_NAME,
88+
(tf, op, gradInputs) ->
89+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
90+
91+
assertThrows(
92+
UnsupportedOperationException.class,
93+
() ->
94+
TensorFlow.registerCustomGradient(
95+
NthElement.Inputs.class,
96+
(tf, op, gradInputs) ->
97+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
98+
}
99+
79100
private static Output<?>[] toArray(Output<?>... outputs) {
80101
return outputs;
81102
}

0 commit comments

Comments
 (0)