@@ -39,8 +39,8 @@
 import org.tensorflow.op.CustomGradient;
 import org.tensorflow.op.RawCustomGradient;
 import org.tensorflow.op.RawOpInputs;
-import org.tensorflow.op.annotation.GeneratedOpInputsMetadata;
-import org.tensorflow.op.annotation.GeneratedOpMetadata;
+import org.tensorflow.op.annotation.OpInputsMetadata;
+import org.tensorflow.op.annotation.OpMetadata;
 import org.tensorflow.op.math.Add;
 import org.tensorflow.proto.framework.OpList;

@@ -187,35 +187,34 @@ public static synchronized boolean registerCustomGradient(
   }

   /**
-   * Register a custom gradient function for ops of {@code opClass} type. The actual op type is
-   * detected from the class's {@link GeneratedOpMetadata} annotation. As such, it only works on
-   * generated op classes.
+   * Register a custom gradient function for ops of {@code inputClass}'s op type. The actual op type
+   * is detected from the class's {@link OpInputsMetadata} annotation. As such, it only works on
+   * generated op classes or custom op classes with the correct annotations.
    *
-   * @param opClass the class of op to register the gradient for.
+   * @param inputClass the inputs class of op to register the gradient for.
    * @param gradient the gradient function to use
    * @return {@code true} if the gradient was registered, {@code false} if there was already a
    *     gradient registered for this op
-   * @throws IllegalArgumentException if {@code opClass} does not have a {@link GeneratedOpMetadata}
-   *     field.
+   * @throws IllegalArgumentException if {@code inputClass} is not annotated with {@link
+   *     OpInputsMetadata} or the op class is not annotated with {@link OpMetadata}.
    */
   public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(
-      Class<T> opClass, CustomGradient<T> gradient) {
-    GeneratedOpInputsMetadata metadata = opClass.getAnnotation(GeneratedOpInputsMetadata.class);
+      Class<T> inputClass, CustomGradient<T> gradient) {
+    OpInputsMetadata metadata = inputClass.getAnnotation(OpInputsMetadata.class);

     if (metadata == null) {
       throw new IllegalArgumentException(
           "Inputs Class "
-              + opClass
-              + " does not have a GeneratedOpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
+              + inputClass
+              + " does not have a OpInputsMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
     }
-    GeneratedOpMetadata outputMetadata =
-        metadata.outputsClass().getAnnotation(GeneratedOpMetadata.class);
+    OpMetadata outputMetadata = metadata.outputsClass().getAnnotation(OpMetadata.class);

     if (outputMetadata == null) {
       throw new IllegalArgumentException(
           "Op Class "
               + metadata.outputsClass()
-              + " does not have a GeneratedOpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
+              + " does not have a OpMetadata annotation. Was it generated by tensorflow/java? If it was, this is a bug.");
     }

     String opType = outputMetadata.opType();
@@ -224,7 +223,7 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
       return false;
     }

-    GradFunc g = CustomGradient.adapter(gradient, opClass);
+    GradFunc g = CustomGradient.adapter(gradient, inputClass);
     GradOpRegistry.Global().Register(opType, g);
     gradientFuncs.add(g);
     return true;
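
For context, a minimal usage sketch of the new inputs-class overload. Only the registerCustomGradient(Class<T>, CustomGradient<T>) signature and the annotation-based op-type lookup are confirmed by this diff; the enclosing TensorFlow class, the generated Add.Inputs class, and the (tf, op, gradInputs) parameter list of the CustomGradient callback are assumptions made for illustration.

import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.CustomGradient;
import org.tensorflow.op.math.Add;

public class AddGradientRegistration {
  public static void main(String[] args) {
    // Assumed callback shape: an ops builder, the typed inputs of the op being
    // differentiated, and the upstream gradients; not confirmed by this diff.
    CustomGradient<Add.Inputs> addGradient =
        (tf, op, gradInputs) -> {
          // For z = x + y (ignoring broadcasting), dL/dx = dL/dy = dL/dz,
          // so the upstream gradient is passed through to both inputs.
          List<Operand<?>> grads = List.of(gradInputs.get(0), gradInputs.get(0));
          return grads;
        };

    // The op type is resolved from the @OpInputsMetadata annotation on Add.Inputs and the
    // @OpMetadata annotation on Add, exactly as the method body in the diff above shows.
    boolean registered = TensorFlow.registerCustomGradient(Add.Inputs.class, addGradient);
    System.out.println("Gradient registered: " + registered);
  }
}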