Skip to content

Commit a9a43ef

Browse files
authored
Fix fail when a variable receives zero gradient #482 (#483)
1 parent db3f00e commit a9a43ef

File tree

2 files changed

+12
-1
lines changed
  • tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow
  • tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers

2 files changed

+12
-1
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,15 @@ Pointer getUnsafeNativeHandle() {
147147
return operation.getUnsafeNativeHandle(index);
148148
}
149149

150+
/**
151+
* Returns whether the underlying operation has no valid handle. Makes the opposite check as
152+
* GraphOperation.requireHandle *
153+
*/
154+
public boolean isClosed() {
155+
Pointer handle = operation.getUnsafeNativeHandle(index);
156+
return handle == null || handle.isNull();
157+
}
158+
150159
private final AbstractOperation operation;
151160
private final int index;
152161
}

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String
174174
List<Op> updateOps = new ArrayList<>();
175175
prepOp.ifPresent(updateOps::add);
176176
for (GradAndVar<? extends TType> pair : gradsAndVars) {
177-
updateOps.add(applyDense(pair));
177+
if (!pair.gradient.isClosed()) {
178+
updateOps.add(applyDense(pair));
179+
}
178180
}
179181

180182
return finish(updateOps, name);

0 commit comments

Comments
 (0)