Skip to content

Commit fb3c103

Browse files
committed
Rename Shape.size(int) to get, add toListOrNull
Signed-off-by: Ryan Nett <[email protected]>
1 parent 4a6e726 commit fb3c103

File tree

15 files changed

+47
-73
lines changed

15 files changed

+47
-73
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Signature.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private static TensorInfo toTensorInfo(Output<?> operand) {
121121
Shape shape = operand.shape();
122122
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
123123
for (int i = 0; i < shape.numDimensions(); ++i) {
124-
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.size(i)));
124+
tensorShapeBuilder.addDim(Dim.newBuilder().setSize(shape.get(i)));
125125
}
126126
return TensorInfo.newBuilder()
127127
.setDtype(operand.dataType())

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ public static <T extends TNumber> Operand<T> sigmoidCrossEntropyWithLogits(
9595
private static boolean isCompatible(Shape shape, Shape other) {
9696
if (shape.numDimensions() != other.numDimensions()) return false;
9797
for (int i = 0; i < shape.numDimensions(); i++) {
98-
long aShapeDim = shape.size(i);
99-
long bShapeDim = other.size(i);
98+
long aShapeDim = shape.get(i);
99+
long bShapeDim = other.get(i);
100100
if (aShapeDim == bShapeDim
101101
|| (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) {
102102
continue;

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.tensorflow.types.TFloat32;
1515
import org.tensorflow.types.TInt64;
1616
import org.tensorflow.types.family.TNumber;
17-
import org.tensorflow.types.family.TType;
1817

1918
import java.util.Arrays;
2019
import java.util.List;
@@ -124,10 +123,10 @@ public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntr
124123
axis = shape.numDimensions() + axis;
125124
}
126125
for (int i = 0; i < axis; i++) {
127-
newArray[i] = shape.size(i);
126+
newArray[i] = shape.get(i);
128127
}
129128
for (int i = axis + 1; i < shape.numDimensions(); i++) {
130-
newArray[i - 1] = shape.size(i);
129+
newArray[i - 1] = shape.get(i);
131130
}
132131
cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray));
133132
}
@@ -152,15 +151,15 @@ private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Oper
152151
long product = 1L;
153152
boolean productValid = true;
154153
for (int i = ndims - 2; i >= 0; i--) {
155-
long d = shape.size(i);
154+
long d = shape.get(i);
156155
if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) {
157156
productValid = false;
158157
break;
159158
}
160159
product *= d;
161160
}
162161
if (productValid) {
163-
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1)));
162+
return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.get(-1)));
164163
}
165164
}
166165

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public static <T extends TNumber, U extends TNumber> Operand sparseSoftmaxCrossE
140140
}
141141

142142
// Reshape logits to 2 dims, labels to 1 dim.
143-
long numClassses = logitsShape.size(-1);
143+
long numClassses = logitsShape.get(-1);
144144

145145
preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses));
146146
labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1));

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ public void outputDataTypeAndShape() {
5757
.setAttr("value", t)
5858
.build();
5959
assertEquals(DataType.DT_INT32, op.dtype(0));
60-
assertEquals(2, op.shape(0).size(0));
61-
assertEquals(3, op.shape(0).size(1));
60+
assertEquals(2, op.shape(0).get(0));
61+
assertEquals(3, op.shape(0).get(1));
6262
}
6363
}
6464

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ public void setAttrShape() {
129129
.build()
130130
.output(0);
131131
assertEquals(2, n.shape().numDimensions());
132-
assertEquals(-1, n.shape().size(0));
133-
assertEquals(784, n.shape().size(1));
132+
assertEquals(-1, n.shape().get(0));
133+
assertEquals(784, n.shape().get(1));
134134
assertEquals(DataType.DT_FLOAT, n.dataType());
135135
}
136136
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public void exportFunctionWithVariables() throws IOException {
146146
assertNotNull(inputInfo);
147147
assertEquals(xyShape.numDimensions(), inputInfo.getTensorShape().getDimCount());
148148
for (int i = 0; i < xyShape.numDimensions(); ++i) {
149-
assertEquals(xyShape.size(i), inputInfo.getTensorShape().getDim(i).getSize());
149+
assertEquals(xyShape.get(i), inputInfo.getTensorShape().getDim(i).getSize());
150150
}
151151

152152
TensorInfo outputInfo = signatureDef.getOutputsMap().get("reducedSum");

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ public void nDimensional() {
325325
assertEquals(TFloat64.class, t.type());
326326
assertEquals(DataType.DT_DOUBLE, t.dataType());
327327
assertEquals(1, t.shape().numDimensions());
328-
assertEquals(3, t.shape().size(0));
328+
assertEquals(3, t.shape().get(0));
329329
assertEquals(vector, t);
330330
}
331331

@@ -334,8 +334,8 @@ public void nDimensional() {
334334
assertEquals(TInt32.class, t.type());
335335
assertEquals(DataType.DT_INT32, t.dataType());
336336
assertEquals(2, t.shape().numDimensions());
337-
assertEquals(2, t.shape().size(0));
338-
assertEquals(3, t.shape().size(1));
337+
assertEquals(2, t.shape().get(0));
338+
assertEquals(3, t.shape().get(1));
339339
assertEquals(matrix, t);
340340
}
341341

@@ -346,9 +346,9 @@ public void nDimensional() {
346346
assertEquals(TInt64.class, t.type());
347347
assertEquals(DataType.DT_INT64, t.dataType());
348348
assertEquals(3, t.shape().numDimensions());
349-
assertEquals(2, t.shape().size(0));
350-
assertEquals(5, t.shape().size(1));
351-
assertEquals(1, t.shape().size(2));
349+
assertEquals(2, t.shape().get(0));
350+
assertEquals(5, t.shape().get(1));
351+
assertEquals(1, t.shape().get(2));
352352
assertEquals(threeD, t);
353353
}
354354

@@ -361,10 +361,10 @@ public void nDimensional() {
361361
assertEquals(TBool.class, t.type());
362362
assertEquals(DataType.DT_BOOL, t.dataType());
363363
assertEquals(4, t.shape().numDimensions());
364-
assertEquals(3, t.shape().size(0));
365-
assertEquals(1, t.shape().size(1));
366-
assertEquals(2, t.shape().size(2));
367-
assertEquals(4, t.shape().size(3));
364+
assertEquals(3, t.shape().get(0));
365+
assertEquals(1, t.shape().get(1));
366+
assertEquals(2, t.shape().get(2));
367+
assertEquals(4, t.shape().get(3));
368368
assertEquals(fourD, t);
369369
}
370370
}
@@ -381,8 +381,8 @@ public void testNDimensionalStringTensor() {
381381
assertEquals(TString.class, t.type());
382382
assertEquals(DataType.DT_STRING, t.dataType());
383383
assertEquals(2, t.shape().numDimensions());
384-
assertEquals(4, t.shape().size(0));
385-
assertEquals(3, t.shape().size(1));
384+
assertEquals(4, t.shape().get(0));
385+
assertEquals(3, t.shape().get(1));
386386
assertEquals(matrix, t);
387387
}
388388

@@ -392,8 +392,8 @@ public void testNDimensionalStringTensor() {
392392
assertEquals(TString.class, t.type());
393393
assertEquals(DataType.DT_STRING, t.dataType());
394394
assertEquals(2, t.shape().numDimensions());
395-
assertEquals(4, t.shape().size(0));
396-
assertEquals(3, t.shape().size(1));
395+
assertEquals(4, t.shape().get(0));
396+
assertEquals(3, t.shape().get(1));
397397
assertEquals(byteMatrix, t.asBytes());
398398
assertEquals(matrix, t);
399399
}
@@ -406,7 +406,7 @@ public void testUint8TensorFromArray() {
406406
assertEquals(TUint8.class, t.type());
407407
assertEquals(DataType.DT_UINT8, t.dataType());
408408
assertEquals(1, t.shape().numDimensions());
409-
assertEquals(4, t.shape().size(0));
409+
assertEquals(4, t.shape().get(0));
410410

411411
byte[] got = new byte[4];
412412
t.read(DataBuffers.of(got));
@@ -421,7 +421,7 @@ public void testCreateFromArrayOfBoxed() {
421421
assertEquals(TInt32.class, t.type());
422422
assertEquals(DataType.DT_INT32, t.dataType());
423423
assertEquals(1, t.shape().numDimensions());
424-
assertEquals(4, t.shape().size(0));
424+
assertEquals(4, t.shape().get(0));
425425

426426
Integer[] got = new Integer[4];
427427
t.read(DataBuffers.ofObjects(got));

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
7070
if (shape.numDimensions() != 2) {
7171
throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
7272
}
73-
boolean isSquare = shape.size(0) == shape.size(1);
74-
long diagSize = Math.min(shape.size(0), shape.size(1));
73+
boolean isSquare = shape.get(0) == shape.get(1);
74+
long diagSize = Math.min(shape.get(0), shape.get(1));
7575
Shape diagShape = Shape.of(diagSize);
7676

7777
Operand<T> op;
@@ -83,8 +83,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
8383
tf.linalg.matrixDiag(
8484
diagOnes,
8585
tf.constant(0), // don't cast here, expecting TInt32
86-
tf.constant((int) shape.size(0)),
87-
tf.constant((int) shape.size(1)),
86+
tf.constant((int) shape.get(0)),
87+
tf.constant((int) shape.get(1)),
8888
zero);
8989
} else {
9090
Operand<T> zeroMatrix = tf.zeros(dims, type);

tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
9090
}
9191
long numRows = 1;
9292
int i = 0;
93-
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.size(i);
94-
long numCols = dimsShape.size(i);
93+
for (; i < dimsShape.numDimensions() - 1; i++) numRows *= dimsShape.get(i);
94+
long numCols = dimsShape.get(i);
9595
Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
9696
long[] seeds = {seed, 0};
9797
Operand<T> op =

0 commit comments

Comments
 (0)