26
26
import com .google .protobuf .InvalidProtocolBufferException ;
27
27
import java .util .Collections ;
28
28
import java .util .IdentityHashMap ;
29
+ import java .util .Locale ;
29
30
import java .util .Set ;
30
31
import java .util .stream .Collectors ;
31
32
import org .bytedeco .javacpp .PointerPointer ;
@@ -193,6 +194,9 @@ private static synchronized boolean hasGradient(String opType) {
193
194
* <p>Note that this only works with graph gradients, and will eventually be deprecated in favor
194
195
* of unified gradient support once it is fully supported by tensorflow core.
195
196
*
197
+ * <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
198
+ * href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
199
+ *
196
200
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
197
201
* field, i.e. {@link Add#OP_NAME}.
198
202
* @param gradient the gradient function to use
@@ -201,6 +205,10 @@ private static synchronized boolean hasGradient(String opType) {
201
205
*/
202
206
public static synchronized boolean registerCustomGradient (
203
207
String opType , RawCustomGradient gradient ) {
208
+ if (isWindowsOs ()) {
209
+ throw new UnsupportedOperationException (
210
+ "Custom gradient registration is not supported on Windows systems." );
211
+ }
204
212
if (hasGradient (opType )) {
205
213
return false ;
206
214
}
@@ -216,6 +224,9 @@ public static synchronized boolean registerCustomGradient(
216
224
* generated op classes or custom op classes with the correct annotations. To operate on the
217
225
* {@link org.tensorflow.GraphOperation} directly use {@link RawCustomGradient}.
218
226
*
227
+ * <p><i>Warning: Custom gradient registration is currently not supported on Windows, see <a
228
+ * href=https://github.com/tensorflow/java/issues/486>GitHub issue</a> for more info.</i>
229
+ *
219
230
* @param inputClass the inputs class of op to register the gradient for.
220
231
* @param gradient the gradient function to use
221
232
* @return {@code true} if the gradient was registered, {@code false} if there was already a
@@ -225,8 +236,11 @@ public static synchronized boolean registerCustomGradient(
225
236
*/
226
237
public static synchronized <T extends RawOpInputs <?>> boolean registerCustomGradient (
227
238
Class <T > inputClass , CustomGradient <T > gradient ) {
239
+ if (isWindowsOs ()) {
240
+ throw new UnsupportedOperationException (
241
+ "Custom gradient registration is not supported on Windows systems." );
242
+ }
228
243
OpInputsMetadata metadata = inputClass .getAnnotation (OpInputsMetadata .class );
229
-
230
244
if (metadata == null ) {
231
245
throw new IllegalArgumentException (
232
246
"Inputs Class "
@@ -253,4 +267,8 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
253
267
gradientFuncs .add (g );
254
268
return true ;
255
269
}
270
+
271
+ private static boolean isWindowsOs () {
272
+ return System .getProperty ("os.name" , "" ).toLowerCase (Locale .ENGLISH ).startsWith ("win" );
273
+ }
256
274
}
0 commit comments