30
30
import org .tensorflow .proto .framework .DataType ;
31
31
import org .tensorflow .types .TFloat32 ;
32
32
33
+ // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
34
+ // https://github.com/tensorflow/java/issues/486
33
35
public class CustomGradientTest {
34
36
37
+ @ EnabledOnOs (OS .WINDOWS )
38
+ @ Test
39
+ public void customGradientRegistrationUnsupportedOnWindows () {
40
+ assertThrows (
41
+ UnsupportedOperationException .class ,
42
+ () ->
43
+ TensorFlow .registerCustomGradient (
44
+ NthElement .OP_NAME ,
45
+ (tf , op , gradInputs ) ->
46
+ Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
47
+
48
+ assertThrows (
49
+ UnsupportedOperationException .class ,
50
+ () ->
51
+ TensorFlow .registerCustomGradient (
52
+ NthElement .Inputs .class ,
53
+ (tf , op , gradInputs ) ->
54
+ Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
55
+ }
56
+
57
+ @ DisabledOnOs (OS .WINDOWS )
35
58
@ Test
36
59
public void testAlreadyExisting () {
37
60
assertFalse (
@@ -45,8 +68,6 @@ public void testAlreadyExisting() {
45
68
}));
46
69
}
47
70
48
- // FIXME: Since TF 2.10.1, custom gradient registration is failing on Windows, see
49
- // https://github.com/tensorflow/java/issues/486
50
71
@ DisabledOnOs (OS .WINDOWS )
51
72
@ Test
52
73
public void testCustomGradient () {
@@ -77,26 +98,6 @@ public void testCustomGradient() {
77
98
}
78
99
}
79
100
80
- @ EnabledOnOs (OS .WINDOWS )
81
- @ Test
82
- public void testCustomGradientThrowsOnWindows () {
83
- assertThrows (
84
- UnsupportedOperationException .class ,
85
- () ->
86
- TensorFlow .registerCustomGradient (
87
- NthElement .OP_NAME ,
88
- (tf , op , gradInputs ) ->
89
- Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
90
-
91
- assertThrows (
92
- UnsupportedOperationException .class ,
93
- () ->
94
- TensorFlow .registerCustomGradient (
95
- NthElement .Inputs .class ,
96
- (tf , op , gradInputs ) ->
97
- Arrays .asList (tf .withName ("inAGrad" ).constant (0f ), tf .constant (0f ))));
98
- }
99
-
100
101
private static Output <?>[] toArray (Output <?>... outputs ) {
101
102
return outputs ;
102
103
}
0 commit comments