Skip to content

Commit 60e5473

Browse files
authored
Disable all custom gradient tests on Windows (#488)
1 parent a447b4b commit 60e5473

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,31 @@
3030
import org.tensorflow.proto.framework.DataType;
3131
import org.tensorflow.types.TFloat32;
3232

33+
// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
34+
// https://github.com/tensorflow/java/issues/486
3335
public class CustomGradientTest {
3436

37+
@EnabledOnOs(OS.WINDOWS)
38+
@Test
39+
public void customGradientRegistrationUnsupportedOnWindows() {
40+
assertThrows(
41+
UnsupportedOperationException.class,
42+
() ->
43+
TensorFlow.registerCustomGradient(
44+
NthElement.OP_NAME,
45+
(tf, op, gradInputs) ->
46+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
47+
48+
assertThrows(
49+
UnsupportedOperationException.class,
50+
() ->
51+
TensorFlow.registerCustomGradient(
52+
NthElement.Inputs.class,
53+
(tf, op, gradInputs) ->
54+
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
55+
}
56+
57+
@DisabledOnOs(OS.WINDOWS)
3558
@Test
3659
public void testAlreadyExisting() {
3760
assertFalse(
@@ -45,8 +68,6 @@ public void testAlreadyExisting() {
4568
}));
4669
}
4770

48-
// FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
49-
// https://github.com/tensorflow/java/issues/486
5071
@DisabledOnOs(OS.WINDOWS)
5172
@Test
5273
public void testCustomGradient() {
@@ -77,26 +98,6 @@ public void testCustomGradient() {
7798
}
7899
}
79100

80-
@EnabledOnOs(OS.WINDOWS)
81-
@Test
82-
public void testCustomGradientThrowsOnWindows() {
83-
assertThrows(
84-
UnsupportedOperationException.class,
85-
() ->
86-
TensorFlow.registerCustomGradient(
87-
NthElement.OP_NAME,
88-
(tf, op, gradInputs) ->
89-
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
90-
91-
assertThrows(
92-
UnsupportedOperationException.class,
93-
() ->
94-
TensorFlow.registerCustomGradient(
95-
NthElement.Inputs.class,
96-
(tf, op, gradInputs) ->
97-
Arrays.asList(tf.withName("inAGrad").constant(0f), tf.constant(0f))));
98-
}
99-
100101
private static Output<?>[] toArray(Output<?>... outputs) {
101102
return outputs;
102103
}

0 commit comments

Comments
 (0)