Skip to content

Commit c1417d2

Browse files
committed
Fixes for Scope and native adapter
Signed-off-by: Ryan Nett <[email protected]>
1 parent 1ed0df8 commit c1417d2

File tree

9 files changed

+208
-162
lines changed

9 files changed

+208
-162
lines changed

tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradFunc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ public class GradFunc extends FunctionPointer {
2020
protected GradFunc() { allocate(); }
2121
private native void allocate();
2222
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op,
23-
NativeOutputVector grad_inputs,
23+
@Const @ByRef NativeOutputVector grad_inputs,
2424
NativeOutputVector grad_outputs);
2525
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ public void map(InfoMap infoMap) {
304304
.annotations("@StdString").valueTypes("BytePointer", "String")
305305
.pointerTypes("BytePointer"))
306306
.put(new Info("absl::Span", "tensorflow::gtl::ArraySlice").annotations("@Span"))
307-
.put(new Info("std::vector<tensorflow::Output>").valueTypes("@StdMove NativeOutputVector")
308-
.pointerTypes("NativeOutputVector").define())
307+
.put(
308+
new Info("std::vector<tensorflow::Output>").pointerTypes("NativeOutputVector").define())
309309
.put(new Info("tensorflow::Output").javaNames("NativeOutput"))
310310
.put(new Info("tensorflow::Operation").javaNames("NativeOperation"))
311311
.put(new Info("tensorflow::Status").javaNames("NativeStatus").purify())

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@
2626
import org.tensorflow.types.TInt32;
2727
import org.tensorflow.types.family.TType;
2828

29-
/** Unit tests for {@link org.tensorflow.op.Scope}. */
29+
/**
30+
* Unit tests for {@link org.tensorflow.op.Scope}.
31+
*/
3032
public class ScopeTest {
3133

3234
@Test
3335
public void basicNames() {
3436
try (Graph g = new Graph()) {
35-
Scope root = new Scope(g);
37+
Scope root = new JavaScope(g);
3638
assertEquals("add", root.makeOpName("add"));
3739
assertEquals("add_1", root.makeOpName("add"));
3840
assertEquals("add_2", root.makeOpName("add"));
@@ -43,7 +45,7 @@ public void basicNames() {
4345
@Test
4446
public void hierarchicalNames() {
4547
try (Graph g = new Graph()) {
46-
Scope root = new Scope(g);
48+
Scope root = new JavaScope(g);
4749
Scope child = root.withSubScope("child");
4850
assertEquals("child/add", child.makeOpName("add"));
4951
assertEquals("child/add_1", child.makeOpName("add"));
@@ -69,7 +71,7 @@ public void hierarchicalNames() {
6971
@Test
7072
public void scopeAndOpNames() {
7173
try (Graph g = new Graph()) {
72-
Scope root = new Scope(g);
74+
Scope root = new JavaScope(g);
7375

7476
Scope child = root.withSubScope("child");
7577

@@ -82,12 +84,12 @@ public void scopeAndOpNames() {
8284
@Test
8385
public void validateNames() {
8486
try (Graph g = new Graph()) {
85-
Scope root = new Scope(g);
87+
Scope root = new JavaScope(g);
8688

8789
final String[] invalid_names = {
88-
"_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.]
89-
null, "", "a$", // Invalid characters
90-
"a/b", // slashes not allowed
90+
"_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.]
91+
null, "", "a$", // Invalid characters
92+
"a/b", // slashes not allowed
9193
};
9294

9395
for (String name : invalid_names) {
@@ -119,7 +121,7 @@ public void validateNames() {
119121
@Test
120122
public void basic() {
121123
try (Graph g = new Graph()) {
122-
Scope s = new Scope(g);
124+
Scope s = new JavaScope(g);
123125
Const<TInt32> c1 = Const.create(s, 42);
124126
assertEquals("Const", c1.output().op().name());
125127
Const<TInt32> c2 = Const.create(s, 7);
@@ -134,7 +136,7 @@ public void basic() {
134136
@Test
135137
public void hierarchy() {
136138
try (Graph g = new Graph()) {
137-
Scope root = new Scope(g);
139+
Scope root = new JavaScope(g);
138140
Scope child = root.withSubScope("child");
139141
assertEquals("child/Const", Const.create(child, 42).output().op().name());
140142
assertEquals("child/four", Const.create(child.withName("four"), 4).output().op().name());
@@ -145,9 +147,9 @@ public void hierarchy() {
145147
public void composite() {
146148
try (Graph g = new Graph();
147149
Session sess = new Session(g)) {
148-
Scope s = new Scope(g);
150+
Scope s = new JavaScope(g);
149151
Output<TInt32> data =
150-
Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
152+
Const.create(s.withName("data"), new int[]{600, 470, 170, 430, 300}).output();
151153

152154
// Create a composite op with a customized name
153155
Variance<TInt32> var1 = Variance.create(s.withName("example"), data);
@@ -168,15 +170,16 @@ public void composite() {
168170
// assertNotNull(g.operation("variance/zero"));
169171

170172
// Verify correct results as well.
171-
TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0);
173+
TInt32 result = (TInt32) sess.runner().fetch(var1.output()).run().get(0);
172174
assertEquals(21704, result.getInt());
173-
result = (TInt32)sess.runner().fetch(var2.output()).run().get(0);
175+
result = (TInt32) sess.runner().fetch(var2.output()).run().get(0);
174176
assertEquals(21704, result.getInt());
175177
}
176178
}
177179

178180
// "handwritten" sample operator classes
179181
private static final class Const<T extends TType> {
182+
180183
private final Output<T> output;
181184

182185
static Const<TInt32> create(Scope s, int v) {
@@ -207,6 +210,7 @@ Output<T> output() {
207210
}
208211

209212
private static final class Mean<T extends TType> {
213+
210214
private final Output<T> output;
211215

212216
static <T extends TType> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) {
@@ -229,6 +233,7 @@ Output<T> output() {
229233
}
230234

231235
private static final class SquaredDifference<T extends TType> {
236+
232237
private final Output<T> output;
233238

234239
static <T extends TType> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) {
@@ -251,14 +256,15 @@ Output<T> output() {
251256
}
252257

253258
private static final class Variance<T extends TType> {
259+
254260
private final Output<T> output;
255261

256262
static Variance<TInt32> create(Scope base, Output<TInt32> x) {
257263
Scope s = base.withSubScope("variance");
258264
Output<TInt32> zero = Const.create(base, TInt32.scalarOf(0)).output();
259265
Output<TInt32> sqdiff =
260266
SquaredDifference.create(
261-
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
267+
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
262268
.output();
263269

264270
return new Variance<>(Mean.create(s.withName("variance"), sqdiff, zero).output());

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,25 @@
2323
import org.tensorflow.Operand;
2424
import org.tensorflow.Session;
2525
import org.tensorflow.ndarray.Shape;
26+
import org.tensorflow.op.JavaScope;
2627
import org.tensorflow.op.Scope;
2728
import org.tensorflow.types.TBool;
2829
import org.tensorflow.types.TFloat32;
2930
import org.tensorflow.types.TInt32;
3031

3132
public class BooleanMaskTest {
33+
3234
@Test
33-
public void testBooleanMask(){
35+
public void testBooleanMask() {
3436
try (Graph g = new Graph();
3537
Session sess = new Session(g)) {
36-
Scope scope = new Scope(g);
38+
Scope scope = new JavaScope(g);
3739

3840
Operand<TInt32> input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
3941
Operand<TInt32> input2 = ExpandDims.create(scope, input, Constant.scalarOf(scope, 0));
4042

41-
Operand<TBool> mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
43+
Operand<TBool> mask = Constant
44+
.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
4245

4346
Operand<TInt32> output1 = BooleanMask.create(scope, input, mask);
4447
Operand<TInt32> output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1));

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.tensorflow.Session;
2626
import org.tensorflow.Tensor;
2727
import org.tensorflow.ndarray.Shape;
28+
import org.tensorflow.op.JavaScope;
2829
import org.tensorflow.op.Scope;
2930
import org.tensorflow.types.TBool;
3031
import org.tensorflow.types.TInt32;
@@ -35,17 +36,19 @@ public class BooleanMaskUpdateTest {
3536
public void testBooleanMaskUpdateSlice() {
3637
try (Graph g = new Graph();
3738
Session sess = new Session(g)) {
38-
Scope scope = new Scope(g);
39+
Scope scope = new JavaScope(g);
3940

40-
Operand<TInt32> input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
41+
Operand<TInt32> input = Constant
42+
.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
4143

4244
Operand<TBool> mask = Constant.arrayOf(scope, true, false, false);
4345

4446
Operand<TInt32> value = Constant.tensorOf(scope, new int[][]{{-1, -1, -1}});
4547

4648
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, mask, value);
4749

48-
Operand<TInt32> bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
50+
Operand<TInt32> bcastOutput = BooleanMaskUpdate
51+
.create(scope, input, mask, Constant.scalarOf(scope, -1));
4952

5053
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
5154
try (TInt32 result = (TInt32) results.get(0);
@@ -72,17 +75,19 @@ public void testBooleanMaskUpdateSlice() {
7275
public void testBooleanMaskUpdateSliceWithBroadcast() {
7376
try (Graph g = new Graph();
7477
Session sess = new Session(g)) {
75-
Scope scope = new Scope(g);
78+
Scope scope = new JavaScope(g);
7679

77-
Operand<TInt32> input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
80+
Operand<TInt32> input = Constant
81+
.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}});
7882

7983
Operand<TBool> mask = Constant.arrayOf(scope, true, false, false);
8084

8185
Operand<TInt32> value = Constant.vectorOf(scope, new int[]{-1, -1, -1});
8286

8387
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, mask, value);
8488

85-
Operand<TInt32> bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1));
89+
Operand<TInt32> bcastOutput = BooleanMaskUpdate
90+
.create(scope, input, mask, Constant.scalarOf(scope, -1));
8691

8792
List<Tensor> results = sess.runner().fetch(output).fetch(bcastOutput).run();
8893
try (TInt32 result = (TInt32) results.get(0);
@@ -109,15 +114,18 @@ public void testBooleanMaskUpdateSliceWithBroadcast() {
109114
public void testBooleanMaskUpdateAxis() {
110115
try (Graph g = new Graph();
111116
Session sess = new Session(g)) {
112-
Scope scope = new Scope(g);
117+
Scope scope = new JavaScope(g);
113118

114-
Operand<TInt32> input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}});
119+
Operand<TInt32> input = Constant
120+
.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}});
115121

116-
Operand<TBool> mask = Constant.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
122+
Operand<TBool> mask = Constant
123+
.arrayOf(scope, true, true, false, false, true, true, true, false, false, false);
117124

118125
Operand<TInt32> value = Constant.arrayOf(scope, -1, -1, -1, -1, -1);
119126

120-
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, mask, value, BooleanMaskUpdate.axis(2));
127+
Operand<TInt32> output = BooleanMaskUpdate
128+
.create(scope, input, mask, value, BooleanMaskUpdate.axis(2));
121129

122130
Operand<TInt32> bcastOutput = BooleanMaskUpdate
123131
.create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2));

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
3939
import org.tensorflow.ndarray.buffer.IntDataBuffer;
4040
import org.tensorflow.ndarray.buffer.LongDataBuffer;
41+
import org.tensorflow.op.JavaScope;
4142
import org.tensorflow.op.Ops;
4243
import org.tensorflow.op.Scope;
4344
import org.tensorflow.types.TBfloat16;
@@ -62,7 +63,7 @@ public void createInts() {
6263

6364
try (Graph g = new Graph();
6465
Session sess = new Session(g)) {
65-
Scope scope = new Scope(g);
66+
Scope scope = new JavaScope(g);
6667
Constant<TInt32> op1 = Constant.tensorOf(scope, shape, buffer);
6768
Constant<TInt32> op2 = Constant.tensorOf(scope, array);
6869
try (AutoCloseableList<Tensor> t =
@@ -81,7 +82,7 @@ public void createFloats() {
8182

8283
try (Graph g = new Graph();
8384
Session sess = new Session(g)) {
84-
Scope scope = new Scope(g);
85+
Scope scope = new JavaScope(g);
8586
Constant<TFloat32> op1 = Constant.tensorOf(scope, shape, buffer);
8687
Constant<TFloat32> op2 = Constant.tensorOf(scope, array);
8788
try (AutoCloseableList<Tensor> t =
@@ -100,7 +101,7 @@ public void createDoubles() {
100101

101102
try (Graph g = new Graph();
102103
Session sess = new Session(g)) {
103-
Scope scope = new Scope(g);
104+
Scope scope = new JavaScope(g);
104105
Constant<TFloat64> op1 = Constant.tensorOf(scope, shape, buffer);
105106
Constant<TFloat64> op2 = Constant.tensorOf(scope, array);
106107
try (AutoCloseableList<Tensor> t =
@@ -119,7 +120,7 @@ public void createLongs() {
119120

120121
try (Graph g = new Graph();
121122
Session sess = new Session(g)) {
122-
Scope scope = new Scope(g);
123+
Scope scope = new JavaScope(g);
123124
Constant<TInt64> op1 = Constant.tensorOf(scope, shape, buffer);
124125
Constant<TInt64> op2 = Constant.tensorOf(scope, array);
125126
try (AutoCloseableList<Tensor> t =
@@ -138,7 +139,7 @@ public void createStrings() throws IOException {
138139

139140
try (Graph g = new Graph();
140141
Session sess = new Session(g)) {
141-
Scope scope = new Scope(g);
142+
Scope scope = new JavaScope(g);
142143
Constant<TString> op1 = Constant.tensorOf(scope, shape, buffer);
143144
Constant<TString> op2 = Constant.tensorOf(scope, array);
144145
try (AutoCloseableList<Tensor> t =

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import org.tensorflow.Graph;
2222
import org.tensorflow.Session;
2323
import org.tensorflow.ndarray.Shape;
24-
import org.tensorflow.ndarray.index.Indices;
2524
import org.tensorflow.ndarray.index.Index;
25+
import org.tensorflow.ndarray.index.Indices;
26+
import org.tensorflow.op.JavaScope;
2627
import org.tensorflow.op.Scope;
2728
import org.tensorflow.types.TFloat32;
2829

@@ -35,7 +36,7 @@ public class IndexingTest {
3536
Indices.all(),
3637
Indices.newAxis(),
3738
Indices.ellipsis(),
38-
Indices.sliceTo( 4),
39+
Indices.sliceTo(4),
3940
Indices.sliceFrom(4, 2)
4041
};
4142

@@ -55,16 +56,17 @@ public void testIndexMerge() {
5556
}
5657

5758
@Test
58-
public void testStridedSliceIndex(){
59+
public void testStridedSliceIndex() {
5960
try (Graph g = new Graph();
6061
Session sess = new Session(g)) {
61-
Scope scope = new Scope(g);
62+
Scope scope = new JavaScope(g);
6263
long[] shape = {10, 10, 10, 10, 10, 10, 10, 10};
6364
Zeros<TFloat32> op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class);
6465
StridedSlice<TFloat32> output = StridedSliceHelper.stridedSlice(scope, op, slice);
6566
try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) {
6667
// expected shape from Python tensorflow
67-
assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)");
68+
assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(),
69+
"Slice index didn't match expected (Python)");
6870
}
6971
}
7072
}

0 commit comments

Comments
 (0)