diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 4aa57619e..fe1a33a9f 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1,5 +1,7 @@ package com.ibm.wala.cast.python.ml.test; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.INT32; import static com.ibm.wala.cast.python.ml.types.TensorType.mnistInput; import static com.ibm.wala.cast.python.util.Util.addPytestEntrypoints; import static java.util.Arrays.asList; @@ -56,8 +58,67 @@ public class TestTensorflow2Model extends TestPythonMLCallGraphShape { private static final Logger LOGGER = Logger.getLogger(TestTensorflow2Model.class.getName()); + private static final String FLOAT_32 = FLOAT32.name().toLowerCase(); + + private static final String INT_32 = INT32.name().toLowerCase(); + private static final TensorType MNIST_INPUT = mnistInput(); + private static final TensorType SCALAR_TENSOR_OF_INT32 = new TensorType(INT_32, emptyList()); + + private static final TensorType SCALAR_TENSOR_OF_FLOAT32 = new TensorType(FLOAT_32, emptyList()); + + private static final TensorType TENSOR_1_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(1), new NumericDim(2))); + + private static final TensorType TENSOR_1_2_INT32 = + new TensorType(INT_32, asList(new NumericDim(1), new NumericDim(2))); + + private static final TensorType TENSOR_2_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(2))); + + private static final TensorType TENSOR_2_2_INT32 = + new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(2))); + + private static final TensorType 
TENSOR_3_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2))); + + private static final TensorType TENSOR_2_1_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2), new NumericDim(1))); + + private static final TensorType TENSOR_2_3_3_INT32 = + new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(3))); + + private static final TensorType TENSOR_2_3_4_INT32 = + new TensorType(INT_32, asList(new NumericDim(2), new NumericDim(3), new NumericDim(4))); + + private static final TensorType TENSOR_20_28_28_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28))); + + private static final TensorType TENSOR_20_28_28_INT32 = + new TensorType(INT_32, asList(new NumericDim(20), new NumericDim(28), new NumericDim(28))); + + private static final TensorType TENSOR_2_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(2))); + + private static final TensorType TENSOR_2_INT32 = + new TensorType(INT_32, asList(new NumericDim(2))); + + private static final TensorType TENSOR_3_INT32 = + new TensorType(INT_32, asList(new NumericDim(3))); + + private static final TensorType TENSOR_3_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(3))); + + private static final TensorType TENSOR_4_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(4))); + + private static final TensorType TENSOR_5_FLOAT32 = + new TensorType(FLOAT_32, asList(new NumericDim(5))); + + private static final TensorType TENSOR_5_INT32 = + new TensorType(INT_32, asList(new NumericDim(5))); + @Test public void testValueIndex() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -126,70 +187,156 @@ public void testFunction4() test("tf2_test_function4.py", "func2", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); } + @Test + public void testFunction5() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + 
test("tf2_test_function5.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); + } + + @Test + public void testFunction6() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function6.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_1_FLOAT32))); + } + + @Test + public void testFunction7() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function7.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_FLOAT32))); + } + + @Test + public void testFunction8() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function8.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_2_FLOAT32))); + } + + @Test + public void testFunction9() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function9.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + } + + @Test + public void testFunction10() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function10.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_4_INT32))); + } + + @Test + public void testFunction11() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test("tf2_test_function11.py", "func", 1, 1, Map.of(2, Set.of(TENSOR_2_3_3_INT32))); + } + + /** Test https://github.com/wala/ML/issues/308. */ + @Test + public void testFunction12() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function12.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_3_2_FLOAT32))); + } + + /** + * Test https://github.com/wala/ML/issues/308. + * + *

This one has lexical scoping. + */ + @Test + public void testFunction13() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_function13.py", + "func", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_1_FLOAT32, TENSOR_3_2_FLOAT32))); + } + @Test public void testDecorator() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator2.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator2.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator3.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_2_FLOAT32))); } @Test public void testDecorator4() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator4.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator4.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator5() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator5.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator5.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator6() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator6.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + 
test("tf2_test_decorator6.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator7() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator7.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator7.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator8() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator8.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator8.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator9.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator9.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator10() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator10.py", "returned", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorator10.py", "returned", 1, 1, Map.of(2, Set.of(TENSOR_5_INT32))); } @Test public void testDecorator11() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_decorator11.py", "C.returned", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_decorator11.py", "C.returned", 1, 1, Map.of(3, Set.of(TENSOR_5_INT32))); + } + + @Test + public void testDecorator12() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_decorator12.py", + "returned", + 1, + 1, + Map.of(2, Set.of(TENSOR_2_FLOAT32, TENSOR_2_INT32))); } @Test @@ -582,7 +729,11 @@ public void testTensorList() "add", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of( + 2, + 
Set.of(TENSOR_1_2_FLOAT32, TENSOR_2_2_FLOAT32), + 3, + Set.of(TENSOR_1_2_FLOAT32, TENSOR_2_2_FLOAT32))); } @Test @@ -617,19 +768,33 @@ public void testTensorList5() public void testModelCall() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test( - "tf2_test_model_call.py", "SequentialModel.__call__", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + "tf2_test_model_call.py", + "SequentialModel.__call__", + 1, + 1, + Map.of(3, Set.of(TENSOR_20_28_28_FLOAT32))); } @Test public void testModelCall2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_model_call2.py", "SequentialModel.call", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test( + "tf2_test_model_call2.py", + "SequentialModel.call", + 1, + 1, + Map.of(3, Set.of(TENSOR_20_28_28_FLOAT32))); } @Test public void testModelCall3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_model_call3.py", "SequentialModel.call", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test( + "tf2_test_model_call3.py", + "SequentialModel.call", + 1, + 1, + Map.of(3, Set.of(TENSOR_20_28_28_FLOAT32))); } @Test @@ -640,7 +805,7 @@ public void testModelCall4() "SequentialModel.__call__", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_20_28_28_FLOAT32))); } /** @@ -678,6 +843,22 @@ public void testModelCall5() Map.of(3, Set.of(MNIST_INPUT))); } + /** + * Test https://github.com/wala/ML/issues/267. + * + *

Explicit dtype. + */ + @Test + public void testModelCall6() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_model_call6.py", + "SequentialModel.__call__", + 1, + 1, + Map.of(3, Set.of(TENSOR_20_28_28_INT32))); + } + @Test public void testModelAttributes() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -717,13 +898,13 @@ public void testModelAttributes6() @Test public void testCallbacks() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_callbacks.py", "replica_fn", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_callbacks.py", "replica_fn", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_FLOAT32))); } @Test public void testCallbacks2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_callbacks2.py", "replica_fn", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_callbacks2.py", "replica_fn", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_FLOAT32))); } @Test @@ -822,13 +1003,13 @@ public void testAutoencoder4() @Test public void testSigmoid() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_sigmoid.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32))); } @Test public void testSigmoid2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_sigmoid2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_4_FLOAT32))); } @Test @@ -870,49 +1051,89 @@ public void testAdd6() @Test public void testAdd7() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add7.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add7.py", + 
"add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd8() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add8.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add8.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add9.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add9.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd10() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add10.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add10.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd11() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add11.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add11.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd12() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add12.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add12.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd13() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add13.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, 
Set.of(MNIST_INPUT))); + test( + "tf2_test_add13.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd14() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add14.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add14.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -966,31 +1187,56 @@ public void testAdd22() @Test public void testAdd23() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add23.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add23.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd24() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add24.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add24.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd25() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add25.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add25.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd26() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add26.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add26.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testAdd27() throws ClassHierarchyException, IllegalArgumentException, CancelException, 
IOException { - test("tf2_test_add27.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add27.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1020,13 +1266,23 @@ public void testAdd31() @Test public void testAdd32() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add32.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add32.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test public void testAdd33() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add33.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add33.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_INT32), 3, Set.of(TENSOR_2_INT32))); } @Test @@ -1126,7 +1382,12 @@ public void testAdd47() @Test public void testAdd48() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_add48.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add48.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1134,7 +1395,12 @@ public void testAdd49() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { // NOTE: Set the expected number of tensor variables to 3 once // https://github.com/wala/ML/issues/135 is fixed. 
- test("tf2_test_add49.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + test( + "tf2_test_add49.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1533,6 +1799,56 @@ public void testAdd115() test("tf2_test_add115.py", "add", 2, 2, Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); } + @Test + public void testAdd116() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add116.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32), 3, Set.of(TENSOR_2_2_FLOAT32))); + } + + @Test + public void testAdd117() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add117.py", + "add", + 2, + 2, + Map.of( + 2, + Set.of( + TENSOR_1_2_FLOAT32, + new TensorType(FLOAT_32, asList(new NumericDim(3), new NumericDim(2)))), + 3, + Set.of(TENSOR_2_2_FLOAT32))); + } + + @Test + public void testAdd118() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add118.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_1_2_INT32), 3, Set.of(TENSOR_2_2_INT32))); + } + + @Test + public void testAdd119() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + "tf2_test_add119.py", + "add", + 2, + 2, + Map.of(2, Set.of(TENSOR_2_FLOAT32), 3, Set.of(TENSOR_2_FLOAT32))); + } + @Test public void testMultiGPUTraining() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { @@ -1557,19 +1873,19 @@ public void testMultiGPUTraining2() @Test public void testReduceMean() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testReduceMean2() throws 
ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test public void testReduceMean3() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_reduce_mean.py", "h", 1, 1, Map.of(2, Set.of(TENSOR_2_2_FLOAT32))); } @Test @@ -1604,13 +1920,13 @@ public void testSparseSoftmaxCrossEntropyWithLogits() "f", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_3_INT32))); } @Test public void testRelu() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_relu.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_3_FLOAT32))); } @Test @@ -1622,7 +1938,7 @@ public void testTFRange() @Test public void testTFRange2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_tf_range2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_tf_range2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -1634,41 +1950,41 @@ public void testTFRange3() @Test public void testImport() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport2() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import2.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport3() throws ClassHierarchyException, 
IllegalArgumentException, CancelException, IOException { - test("tf2_test_import3.py", "f", 1, 2, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import3.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import3.py", "f", 1, 2, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + test("tf2_test_import3.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport4() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import4.py", "f", 1, 2, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import4.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import4.py", "f", 1, 2, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + test("tf2_test_import4.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport5() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_import5.py", "f", 0, 1); - test("tf2_test_import5.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import5.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test public void testImport6() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { test("tf2_test_import6.py", "f", 0, 1); - test("tf2_test_import6.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import6.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1696,8 +2012,8 @@ public void testImport8() @Test public void testImport9() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { - test("tf2_test_import9.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_import9.py", "g", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_import9.py", "f", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); + test("tf2_test_import9.py", "g", 1, 1, Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1710,7 +2026,7 @@ public void testModule() "", 1, 1, - Map.of(2, 
Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj`. */ @@ -1726,7 +2042,7 @@ public void testModule2() "proj", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1742,7 +2058,7 @@ public void testModule3() "proj2", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1764,7 +2080,7 @@ public void testModule4() "proj3", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -1778,7 +2094,7 @@ public void testModule4() "proj3", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1791,7 +2107,7 @@ public void testModule5() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj4`. */ @@ -1807,7 +2123,7 @@ public void testModule6() "proj4", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1823,7 +2139,7 @@ public void testModule7() "proj5", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1845,7 +2161,7 @@ public void testModule8() "proj6", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -1859,7 +2175,7 @@ public void testModule8() "proj6", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1872,7 +2188,7 @@ public void testModule9() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -1885,7 +2201,7 @@ public void testModule10() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test needs a PYTHONPATH that points to `proj7`. 
*/ @@ -1904,7 +2220,7 @@ public void testModule11() "proj7", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1923,7 +2239,7 @@ public void testModule12() "proj8", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1942,7 +2258,7 @@ public void testModule13() "proj9", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1960,7 +2276,7 @@ public void testModule14() "proj10", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -1978,7 +2294,7 @@ public void testModule15() "proj11", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** This test should not need a PYTHONPATH. */ @@ -1992,7 +2308,7 @@ public void testModule16() "proj12", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2013,7 +2329,7 @@ public void testModule17() "proj13", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2039,7 +2355,7 @@ public void testModule18() "proj14", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2054,7 +2370,7 @@ public void testModule18() "proj14", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2074,7 +2390,7 @@ public void testModule19() "proj15", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2092,7 +2408,7 @@ public void testModule20() "proj16", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2112,7 +2428,7 @@ public void testModule21() "proj17", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2130,7 +2446,7 @@ public void testModule22() "proj18", 1, 1, - Map.of(2, 
Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2154,7 +2470,7 @@ public void testModule23() "proj19", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2172,7 +2488,7 @@ public void testModule24() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2190,7 +2506,7 @@ public void testModule25() "proj20", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2208,7 +2524,7 @@ public void testModule26() "", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2232,7 +2548,7 @@ public void testModule27() "proj21", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2247,7 +2563,7 @@ public void testModule27() "proj21", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2265,7 +2581,7 @@ public void testModule28() "proj22", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2283,7 +2599,7 @@ public void testModule29() "proj23", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2301,7 +2617,7 @@ public void testModule30() "proj24", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2321,7 +2637,7 @@ public void testModule31() "proj25", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2339,7 +2655,7 @@ public void testModule32() "proj26", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2359,7 +2675,7 @@ public void testModule33() "proj27", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2379,7 +2695,7 @@ public void testModule34() "proj28", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** 
@@ -2397,7 +2713,7 @@ public void testModule35() "proj29", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2415,7 +2731,7 @@ public void testModule36() "proj30", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2433,7 +2749,7 @@ public void testModule37() "proj31", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2451,7 +2767,7 @@ public void testModule38() "proj32", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2469,7 +2785,7 @@ public void testModule39() "proj33", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2487,7 +2803,7 @@ public void testModule40() "proj34", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2512,7 +2828,7 @@ public void testModule41() "proj35", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2537,7 +2853,7 @@ public void testModule42() "proj36", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2562,7 +2878,7 @@ public void testModule43() "proj37", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2587,7 +2903,7 @@ public void testModule44() "proj38", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2605,7 +2921,7 @@ public void testModule45() "proj39", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2623,7 +2939,7 @@ public void testModule46() "proj40", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2641,7 +2957,7 @@ public void testModule47() "proj41", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2659,7 +2975,7 @@ public void testModule48() "proj42", 1, 1, - 
Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2684,7 +3000,7 @@ public void testModule49() "proj43", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2709,7 +3025,7 @@ public void testModule50() "proj44", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2734,7 +3050,7 @@ public void testModule51() "proj45", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2759,7 +3075,7 @@ public void testModule52() "proj46", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -2786,7 +3102,7 @@ public void testModule53() "proj47", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); test( new String[] { @@ -2804,7 +3120,7 @@ public void testModule53() "proj47", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2818,7 +3134,7 @@ public void testModule54() "proj51", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2832,7 +3148,7 @@ public void testModule55() "proj52", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2846,7 +3162,7 @@ public void testModule56() "proj53", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2860,7 +3176,7 @@ public void testModule57() "proj54", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. 
*/ @@ -2874,7 +3190,7 @@ public void testModule58() "proj55", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2888,7 +3204,7 @@ public void testModule59() "proj51", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2902,7 +3218,7 @@ public void testModule60() "proj52", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2916,7 +3232,7 @@ public void testModule61() "proj56", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2930,7 +3246,7 @@ public void testModule62() "proj57", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2944,7 +3260,7 @@ public void testModule63() "proj58", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2958,7 +3274,7 @@ public void testModule64() "proj59", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2972,7 +3288,7 @@ public void testModule65() "proj60", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -2986,7 +3302,7 @@ public void testModule66() "proj61", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/202. */ @@ -3000,7 +3316,7 @@ public void testModule67() "proj62", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/205. 
*/ @@ -3014,7 +3330,7 @@ public void testModule68() "proj63", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/205. */ @@ -3028,7 +3344,7 @@ public void testModule69() "proj64", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3042,7 +3358,7 @@ public void testModule70() "proj65", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3056,7 +3372,7 @@ public void testModule71() "proj67", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3070,7 +3386,7 @@ public void testModule72() "proj68", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3084,7 +3400,7 @@ public void testModule73() "proj69", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/210. */ @@ -3098,7 +3414,7 @@ public void testModule74() "proj70", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3112,7 +3428,7 @@ public void testModule75() "proj71", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3126,7 +3442,7 @@ public void testModule76() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. */ @@ -3140,7 +3456,7 @@ public void testModule77() "proj72", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/211. 
*/ @@ -3154,7 +3470,7 @@ public void testModule78() "", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } /** Test https://github.com/wala/ML/issues/209. */ @@ -3174,7 +3490,7 @@ public void testModule79() "proj73", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); test( new String[] { @@ -3189,7 +3505,7 @@ public void testModule79() "proj73", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/209. */ @@ -3209,7 +3525,7 @@ public void testModule80() "proj74", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); test( new String[] { @@ -3224,7 +3540,7 @@ public void testModule80() "proj74", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3234,7 +3550,7 @@ public void testStaticMethod() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3244,7 +3560,7 @@ public void testStaticMethod2() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3254,7 +3570,7 @@ public void testStaticMethod3() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3264,7 +3580,7 @@ public void testStaticMethod4() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3274,7 +3590,7 @@ public void testStaticMethod5() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, 
Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3284,7 +3600,7 @@ public void testStaticMethod6() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(2, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3294,7 +3610,7 @@ public void testStaticMethod7() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3304,7 +3620,7 @@ public void testStaticMethod8() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3314,7 +3630,7 @@ public void testStaticMethod9() throws ClassHierarchyException, CancelException, "MyClass.the_static_method", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32), 3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3324,17 +3640,39 @@ public void testStaticMethod10() throws ClassHierarchyException, CancelException "MyClass.the_static_method", 2, 2, - Map.of(2, Set.of(MNIST_INPUT), 3, Set.of(MNIST_INPUT))); + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32), 3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testStaticMethod11() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_static_method11.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_static_method11.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testStaticMethod12() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_static_method12.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_static_method12.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + } + + @Test(expected = IllegalStateException.class) + public void testStaticMethod13() throws ClassHierarchyException, CancelException, IOException { 
+ // NOTE: This test will no longer throw an exception once data types other than lists are + // supported for shape arguments. + test( + "tf2_test_static_method13.py", + "MyClass.the_static_method", + 1, + 1, + Map.of(2, Set.of(TENSOR_5_FLOAT32))); + } + + @Test + public void testStaticMethod14() throws ClassHierarchyException, CancelException, IOException { + test( + "tf2_test_static_method14.py", + "MyClass.the_static_method", + 1, + 1, + Map.of(2, Set.of(TENSOR_1_2_FLOAT32))); } @Test @@ -3344,7 +3682,7 @@ public void testClassMethod() throws ClassHierarchyException, CancelException, I "MyClass.the_class_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3354,44 +3692,44 @@ public void testClassMethod2() throws ClassHierarchyException, CancelException, "MyClass.the_class_method", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod3() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method3.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method3.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod4() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method4.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method4.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testClassMethod5() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_class_method5.py", "MyClass.f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_class_method5.py", "MyClass.f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method.py", "D.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); - 
test("tf2_test_abstract_method.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method.py", "D.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); + test("tf2_test_abstract_method.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod2() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method2.py", "D.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); - test("tf2_test_abstract_method2.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method2.py", "D.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); + test("tf2_test_abstract_method2.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testAbstractMethod3() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_abstract_method3.py", "C.f", 1, 1, Map.of(3, Set.of(MNIST_INPUT))); + test("tf2_test_abstract_method3.py", "C.f", 1, 1, Map.of(3, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/188. 
*/ @@ -3408,27 +3746,27 @@ public void testDecoratedMethod3() throws ClassHierarchyException, CancelExcepti @Test public void testDecoratedMethod4() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method4.py", "raffi", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method4.py", "raffi", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod5() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method5.py", "raffi", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method5.py", "raffi", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod6() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method6.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method6.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod7() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method7.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method7.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test public void testDecoratedMethod8() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method8.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method8.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3450,7 +3788,7 @@ public void testDecoratedMethod10() throws ClassHierarchyException, CancelExcept @Test public void testDecoratedMethod11() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_method11.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("tf2_test_decorated_method11.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } @Test @@ -3468,12 +3806,42 @@ public void 
testDecoratedMethod13() throws ClassHierarchyException, CancelExcept @Test public void testDecoratedFunctions() throws ClassHierarchyException, CancelException, IOException { - test("tf2_test_decorated_functions.py", "dummy_fun", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "dummy_test", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function2", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function3", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); - test("tf2_test_decorated_functions.py", "test_function4", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test( + "tf2_test_decorated_functions.py", + "dummy_fun", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "dummy_test", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function2", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function3", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); + test( + "tf2_test_decorated_functions.py", + "test_function4", + 1, + 1, + Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test a pytest with decorators. */ @@ -3504,21 +3872,21 @@ public void testDecoratedFunctions3() "proj48", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test a pytest without decorators. This is a "control." 
*/ @Test public void testDecoratedFunctions4() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions2.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions2.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test a pytest with a decorator. */ @Test public void testDecoratedFunctions5() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions3.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions3.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3541,14 +3909,14 @@ public void testDecoratedFunctions6() "proj49", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** Test a Pytest with a decorator without parameters. */ @Test public void testDecoratedFunctions7() throws ClassHierarchyException, CancelException, IOException { - test("test_decorated_functions4.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("test_decorated_functions4.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** @@ -3571,7 +3939,7 @@ public void testDecoratedFunctions8() "proj50", 1, 1, - Map.of(3, Set.of(MNIST_INPUT))); + Map.of(3, Set.of(TENSOR_1_2_FLOAT32))); } /** @@ -3580,7 +3948,7 @@ public void testDecoratedFunctions8() @Test public void testDecoratedFunctions9() throws ClassHierarchyException, CancelException, IOException { - test("decorated_function_test.py", "f", 1, 1, Map.of(2, Set.of(MNIST_INPUT))); + test("decorated_function_test.py", "f", 1, 1, Map.of(2, Set.of(SCALAR_TENSOR_OF_INT32))); } /** Test https://github.com/wala/ML/issues/195. 
*/ diff --git a/com.ibm.wala.cast.python.ml/.settings/org.eclipse.jdt.core.prefs b/com.ibm.wala.cast.python.ml/.settings/org.eclipse.jdt.core.prefs index db19bfd66..58cc75197 100644 --- a/com.ibm.wala.cast.python.ml/.settings/org.eclipse.jdt.core.prefs +++ b/com.ibm.wala.cast.python.ml/.settings/org.eclipse.jdt.core.prefs @@ -1,4 +1,5 @@ eclipse.preferences.version=1 +org.eclipse.jdt.core.builder.annotationPath.allLocations=disabled org.eclipse.jdt.core.builder.cleanOutputFolder=clean org.eclipse.jdt.core.builder.duplicateResourceTask=warning org.eclipse.jdt.core.builder.invalidClasspath=abort @@ -9,12 +10,119 @@ org.eclipse.jdt.core.classpath.exclusionPatterns=enabled org.eclipse.jdt.core.classpath.mainOnlyProjectHasTestOnlyDependency=error org.eclipse.jdt.core.classpath.multipleOutputLocations=enabled org.eclipse.jdt.core.classpath.outputOverlappingAnotherSource=error +org.eclipse.jdt.core.compiler.annotation.inheritNullAnnotations=disabled +org.eclipse.jdt.core.compiler.annotation.missingNonNullByDefaultAnnotation=ignore +org.eclipse.jdt.core.compiler.annotation.nonnull=org.eclipse.jdt.annotation.NonNull +org.eclipse.jdt.core.compiler.annotation.nonnull.secondary= +org.eclipse.jdt.core.compiler.annotation.nonnullbydefault=org.eclipse.jdt.annotation.NonNullByDefault +org.eclipse.jdt.core.compiler.annotation.nonnullbydefault.secondary= +org.eclipse.jdt.core.compiler.annotation.notowning=org.eclipse.jdt.annotation.NotOwning +org.eclipse.jdt.core.compiler.annotation.nullable=org.eclipse.jdt.annotation.Nullable +org.eclipse.jdt.core.compiler.annotation.nullable.secondary= +org.eclipse.jdt.core.compiler.annotation.nullanalysis=disabled +org.eclipse.jdt.core.compiler.annotation.owning=org.eclipse.jdt.annotation.Owning +org.eclipse.jdt.core.compiler.annotation.resourceanalysis=disabled org.eclipse.jdt.core.compiler.codegen.targetPlatform=25 org.eclipse.jdt.core.compiler.compliance=25 org.eclipse.jdt.core.compiler.maxProblemPerUnit=100 
+org.eclipse.jdt.core.compiler.problem.APILeak=warning +org.eclipse.jdt.core.compiler.problem.annotatedTypeArgumentToUnannotated=info +org.eclipse.jdt.core.compiler.problem.annotationSuperInterface=warning +org.eclipse.jdt.core.compiler.problem.autoboxing=ignore +org.eclipse.jdt.core.compiler.problem.comparingIdentical=warning +org.eclipse.jdt.core.compiler.problem.deadCode=warning +org.eclipse.jdt.core.compiler.problem.deprecation=warning +org.eclipse.jdt.core.compiler.problem.deprecationInDeprecatedCode=disabled +org.eclipse.jdt.core.compiler.problem.deprecationWhenOverridingDeprecatedMethod=disabled +org.eclipse.jdt.core.compiler.problem.discouragedReference=warning +org.eclipse.jdt.core.compiler.problem.emptyStatement=ignore org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled +org.eclipse.jdt.core.compiler.problem.explicitlyClosedAutoCloseable=ignore +org.eclipse.jdt.core.compiler.problem.fallthroughCase=ignore +org.eclipse.jdt.core.compiler.problem.fatalOptionalError=disabled +org.eclipse.jdt.core.compiler.problem.fieldHiding=ignore +org.eclipse.jdt.core.compiler.problem.finalParameterBound=warning +org.eclipse.jdt.core.compiler.problem.finallyBlockNotCompletingNormally=warning org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning +org.eclipse.jdt.core.compiler.problem.hiddenCatchBlock=warning +org.eclipse.jdt.core.compiler.problem.includeNullInfoFromAsserts=disabled +org.eclipse.jdt.core.compiler.problem.incompatibleNonInheritedInterfaceMethod=warning +org.eclipse.jdt.core.compiler.problem.incompatibleOwningContract=warning +org.eclipse.jdt.core.compiler.problem.incompleteEnumSwitch=warning +org.eclipse.jdt.core.compiler.problem.indirectStaticAccess=ignore +org.eclipse.jdt.core.compiler.problem.insufficientResourceAnalysis=warning +org.eclipse.jdt.core.compiler.problem.localVariableHiding=ignore +org.eclipse.jdt.core.compiler.problem.methodWithConstructorName=warning +org.eclipse.jdt.core.compiler.problem.missingDefaultCase=ignore 
+org.eclipse.jdt.core.compiler.problem.missingDeprecatedAnnotation=ignore +org.eclipse.jdt.core.compiler.problem.missingEnumCaseDespiteDefault=disabled +org.eclipse.jdt.core.compiler.problem.missingHashCodeMethod=ignore +org.eclipse.jdt.core.compiler.problem.missingOverrideAnnotation=ignore +org.eclipse.jdt.core.compiler.problem.missingOverrideAnnotationForInterfaceMethodImplementation=enabled +org.eclipse.jdt.core.compiler.problem.missingSerialVersion=warning +org.eclipse.jdt.core.compiler.problem.missingSynchronizedOnInheritedMethod=ignore +org.eclipse.jdt.core.compiler.problem.noEffectAssignment=warning +org.eclipse.jdt.core.compiler.problem.noImplicitStringConversion=warning +org.eclipse.jdt.core.compiler.problem.nonExternalizedStringLiteral=ignore +org.eclipse.jdt.core.compiler.problem.nonnullParameterAnnotationDropped=warning +org.eclipse.jdt.core.compiler.problem.nonnullTypeVariableFromLegacyInvocation=warning +org.eclipse.jdt.core.compiler.problem.nullAnnotationInferenceConflict=error +org.eclipse.jdt.core.compiler.problem.nullReference=warning +org.eclipse.jdt.core.compiler.problem.nullSpecViolation=error +org.eclipse.jdt.core.compiler.problem.nullUncheckedConversion=warning +org.eclipse.jdt.core.compiler.problem.overridingPackageDefaultMethod=warning +org.eclipse.jdt.core.compiler.problem.parameterAssignment=ignore +org.eclipse.jdt.core.compiler.problem.pessimisticNullAnalysisForFreeTypeVariables=warning +org.eclipse.jdt.core.compiler.problem.possibleAccidentalBooleanAssignment=ignore +org.eclipse.jdt.core.compiler.problem.potentialNullReference=ignore +org.eclipse.jdt.core.compiler.problem.potentiallyUnclosedCloseable=ignore +org.eclipse.jdt.core.compiler.problem.rawTypeReference=warning +org.eclipse.jdt.core.compiler.problem.redundantNullAnnotation=warning +org.eclipse.jdt.core.compiler.problem.redundantNullCheck=ignore +org.eclipse.jdt.core.compiler.problem.redundantSpecificationOfTypeArguments=ignore 
+org.eclipse.jdt.core.compiler.problem.redundantSuperinterface=ignore +org.eclipse.jdt.core.compiler.problem.reportMethodCanBePotentiallyStatic=ignore +org.eclipse.jdt.core.compiler.problem.reportMethodCanBeStatic=ignore org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=ignore +org.eclipse.jdt.core.compiler.problem.specialParameterHidingField=disabled +org.eclipse.jdt.core.compiler.problem.staticAccessReceiver=warning +org.eclipse.jdt.core.compiler.problem.suppressOptionalErrors=disabled +org.eclipse.jdt.core.compiler.problem.suppressWarnings=enabled +org.eclipse.jdt.core.compiler.problem.suppressWarningsNotFullyAnalysed=info +org.eclipse.jdt.core.compiler.problem.syntacticNullAnalysisForFields=disabled +org.eclipse.jdt.core.compiler.problem.syntheticAccessEmulation=ignore +org.eclipse.jdt.core.compiler.problem.terminalDeprecation=warning +org.eclipse.jdt.core.compiler.problem.typeParameterHiding=warning +org.eclipse.jdt.core.compiler.problem.unavoidableGenericTypeProblems=enabled +org.eclipse.jdt.core.compiler.problem.uncheckedTypeOperation=warning +org.eclipse.jdt.core.compiler.problem.unclosedCloseable=warning +org.eclipse.jdt.core.compiler.problem.undocumentedEmptyBlock=ignore +org.eclipse.jdt.core.compiler.problem.unhandledWarningToken=warning +org.eclipse.jdt.core.compiler.problem.unlikelyCollectionMethodArgumentType=warning +org.eclipse.jdt.core.compiler.problem.unlikelyCollectionMethodArgumentTypeStrict=disabled +org.eclipse.jdt.core.compiler.problem.unlikelyEqualsArgumentType=info +org.eclipse.jdt.core.compiler.problem.unnecessaryElse=ignore +org.eclipse.jdt.core.compiler.problem.unnecessaryTypeCheck=ignore +org.eclipse.jdt.core.compiler.problem.unqualifiedFieldAccess=ignore +org.eclipse.jdt.core.compiler.problem.unstableAutoModuleName=warning +org.eclipse.jdt.core.compiler.problem.unusedDeclaredThrownException=ignore +org.eclipse.jdt.core.compiler.problem.unusedDeclaredThrownExceptionExemptExceptionAndThrowable=enabled 
+org.eclipse.jdt.core.compiler.problem.unusedDeclaredThrownExceptionIncludeDocCommentReference=enabled +org.eclipse.jdt.core.compiler.problem.unusedDeclaredThrownExceptionWhenOverriding=disabled +org.eclipse.jdt.core.compiler.problem.unusedExceptionParameter=ignore +org.eclipse.jdt.core.compiler.problem.unusedImport=warning +org.eclipse.jdt.core.compiler.problem.unusedLabel=warning +org.eclipse.jdt.core.compiler.problem.unusedLambdaParameter=warning +org.eclipse.jdt.core.compiler.problem.unusedLocal=warning +org.eclipse.jdt.core.compiler.problem.unusedObjectAllocation=ignore +org.eclipse.jdt.core.compiler.problem.unusedParameter=ignore +org.eclipse.jdt.core.compiler.problem.unusedParameterIncludeDocCommentReference=enabled +org.eclipse.jdt.core.compiler.problem.unusedParameterWhenImplementingAbstract=disabled +org.eclipse.jdt.core.compiler.problem.unusedParameterWhenOverridingConcrete=disabled +org.eclipse.jdt.core.compiler.problem.unusedPrivateMember=warning +org.eclipse.jdt.core.compiler.problem.unusedTypeParameter=ignore +org.eclipse.jdt.core.compiler.problem.unusedWarningToken=info +org.eclipse.jdt.core.compiler.problem.varargsArgumentNeedCast=warning org.eclipse.jdt.core.compiler.release=enabled org.eclipse.jdt.core.compiler.source=25 org.eclipse.jdt.core.incompatibleJDKLevel=ignore diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index da59452d1..bac064e68 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -154,6 +154,9 @@ + + + @@ -218,6 +221,14 @@ + + + + + + + + @@ -315,6 +326,17 @@ + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java index 681cf7361..ce6437ef7 100644 --- 
a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/analysis/TensorTypeAnalysis.java @@ -447,11 +447,11 @@ public String toString() { }; } - private final Map init; + private final Map> init; public TensorTypeAnalysis( Graph G, - Map init, + Map> init, Map reshapeTypes, Map set_shapes, Set conv2ds, @@ -480,7 +480,8 @@ protected TensorVariable[] makeStmtRHS(int size) { protected void initializeVariables() { super.initializeVariables(); for (PointsToSetVariable src : init.keySet()) { - getOut(src).state.add(init.get(src)); + Set tensorTypes = init.get(src); + getOut(src).state.addAll(tensorTypes); } } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java new file mode 100644 index 000000000..b034bfce8 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Constant.java @@ -0,0 +1,60 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +/** + * Represents a call to the constant() function in TensorFlow. + * + * @see constant(). 
+ * @author Raffi Khatchadourian + */ +public class Constant extends TensorGenerator { + + private static final int VALUE_NUMBER_FOR_VALUE_ARGUMENT = 2; + + private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3; + + private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 4; + + public Constant(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + // If the shape argument is not specified, then the shape is inferred from the shape of value. + // TODO: Handle keyword arguments. + return getShapes(builder, this.getValueNumberForValueArgument()); + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + // If the dtype argument is not specified, then the type is inferred from the type of value. + // TODO: Handle keyword arguments. + return getDTypes(builder, this.getValueNumberForValueArgument()); + } + + @Override + protected int getValueNumberForDTypeArgument() { + return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; + } + + protected int getValueNumberForValueArgument() { + return VALUE_NUMBER_FOR_VALUE_ARGUMENT; + } + + @Override + protected int getValueNumberForShapeArgument() { + // Shapes can also be specified as an explicit argument. Here, we examine the third explicit + // argument (recall that the first argument is implicit and corresponds to the called + // function's name). 
+ return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java new file mode 100644 index 000000000..df3d1a263 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Ones.java @@ -0,0 +1,53 @@ +package com.ibm.wala.cast.python.ml.client; + +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; + +/** + * A generator for tensors created by the `ones()` function in TensorFlow. + * + * @see TensorFlow ones() API. + * @author Raffi Khatchadourian + */ +public class Ones extends TensorGenerator { + + private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = 2; + + private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 3; + + public Ones(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + LOGGER.info( + "No dtype specified for source: " + source + ". Using default dtype of: " + FLOAT32 + " ."); + + // Use the default dtype of float32. 
+ return EnumSet.of(FLOAT32); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + throw new UnsupportedOperationException("Shape is mandatory and must be provided explicitly."); + } + + @Override + protected int getValueNumberForShapeArgument() { + return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; + } + + @Override + protected int getValueNumberForDTypeArgument() { + return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 710b3b44a..1bd29592a 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -667,10 +667,9 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) Set sources = getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis()); - TensorType mnistData = TensorType.mnistInput(); - Map init = HashMapFactory.make(); + Map> init = HashMapFactory.make(); - for (PointsToSetVariable v : sources) init.put(v, mnistData); + for (PointsToSetVariable v : sources) init.put(v, getTensorTypes(v, builder)); Map placeholders = null; try { @@ -681,7 +680,7 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) logger.fine("Placeholders: " + placeholders); for (Map.Entry e : placeholders.entrySet()) - init.put(e.getKey(), e.getValue()); + init.put(e.getKey(), Set.of(e.getValue())); Map setCalls = HashMapFactory.make(); Map set_shapes = getShapeSourceCalls(set_shape, builder, 1); @@ -722,6 +721,24 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) return tt; } + /** + * Returns the set of possible {@link TensorType}s that the given {@link 
PointsToSetVariable} can + * take on. + * + * @param source The dataflow source to analyze. + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph and pointer + * analysis. + * @return A set of {@link TensorType}s that the given {@link PointsToSetVariable} can take on. + * Empty set is returned if the possible tensor types cannot be determined. + */ + private Set getTensorTypes( + PointsToSetVariable source, PropagationCallGraphBuilder builder) { + logger.info("Getting tensor types for source: " + source + "."); + + TensorGenerator generator = TensorGeneratorFactory.getGenerator(source); + return generator.getTensorTypes(builder); + } + private Map handleShapeSourceOp( PropagationCallGraphBuilder builder, Graph dataflow, diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java new file mode 100644 index 000000000..3c6647827 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Range.java @@ -0,0 +1,166 @@ +package com.ibm.wala.cast.python.ml.client; + +import static java.util.function.Function.identity; + +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.ConstantKey; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.debug.UnimplementedError; +import com.ibm.wala.util.intset.OrdinalSet; +import 
java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.StreamSupport; + +/** + * A representation of the TensorFlow range operation. + * + *

This class is used to generate a tensor that contains a sequence of numbers, similar to the + * range function in Python. + * + * @see TensorFlow range + * documentation. + * @author Raffi Khatchadourian + */ +public class Range extends TensorGenerator { + + public Range(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected Set>> getShapes(PropagationCallGraphBuilder builder) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + // The shape of a range tensor is always a 1D tensor with the length equal to the number of + // elements in the range. + // For example, `tf.range(5)` produces a tensor with shape (5,). + + double start = 0; // Default start value. + double limit = start; // Default limit value. + double delta = 1; // Default step value. + + // There are two versions of the `range` function: + // 1. `tf.range(limit)` - generates a range from 0 to limit + // 2. `tf.range(start, limit, delta)` - generates a range from start to limit with a step of + // delta. + + // First, decide which version of the `range` function is being called based on the number of + // numeric arguments.j + // TODO: Handle keyword arguments. + + int numOfNumericPositionalArgs = getNumberOfNumericPositionalArgs(pointerAnalysis); + + if (numOfNumericPositionalArgs == 1) { + // it must *just* be `limit`. + PointerKey limitPK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, 2); + OrdinalSet limitPointsToSet = pointerAnalysis.getPointsToSet(limitPK); + + assert !limitPointsToSet.isEmpty() : "Expected a non-empty points-to set for limit."; + + for (InstanceKey limitIK : limitPointsToSet) + if (limitIK instanceof ConstantKey) { + limit = ((Number) ((ConstantKey) limitIK).getValue()).doubleValue(); + int shape = (int) Math.ceil((limit - start) / delta); + ret.add(List.of(new NumericDim(shape))); // Add the shape as a 1D tensor. 
+ } else + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for limit, but got: " + limitIK + "."); + } else + // TODO: Handle more cases. + throw new UnimplementedError( + "Currently cannot handle more than one numeric positional argument for range()."); + + return ret; + } + + private int getNumberOfNumericPositionalArgs(PointerAnalysis pointerAnalysis) { + int ret = 0; + int explicitArgumentIndex = 2; // Start from the first explicit argument. + + while (true) { + PointerKey pk = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, explicitArgumentIndex); + OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(pk); + + if (pointsToSet.isEmpty()) break; // End of positional arguments. + + // Check if the pointsToSet contains numeric values. + boolean allNumeric = + StreamSupport.stream(pointsToSet.spliterator(), false) + .filter(ik -> ik instanceof ConstantKey) + .map(ik -> (ConstantKey) ik) + .map(ConstantKey::getValue) + .allMatch(v -> v instanceof Number); // Check if all values are numeric. + + if (!allNumeric) break; // There's some argument that is not numeric for this argument. + + ret++; // Increment the count of numeric positional arguments. + explicitArgumentIndex++; // Move to the next explicit argument. + } + + return ret; + } + + @Override + protected EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder) { + // The dtype of the resulting tensor is inferred from the inputs unless it is provided + // explicitly. + + // TODO: Handle keyword arguments. + int numberOfNumericPositionalArgs = + getNumberOfNumericPositionalArgs(builder.getPointerAnalysis()); + + EnumSet types = + IntStream.range(0, numberOfNumericPositionalArgs) + .map(i -> i + 2) // Positional arguments start at index 2. 
+ .mapToObj(val -> getDTypes(builder, val).stream()) + .flatMap(identity()) + .distinct() + .collect(Collectors.toCollection(() -> EnumSet.noneOf(DType.class))); + + // FIXME: We can't tell the difference here between varying dtypes in a single call and that of + // possible varying dtypes values from the points-to graph. Below, we are treating it as these + // values lie in a single call, but that may not be the case. + + if (types.contains(DType.FLOAT64)) return EnumSet.of(DType.FLOAT64); + else if (types.contains(DType.FLOAT32)) return EnumSet.of(DType.FLOAT32); + else if (types.contains(DType.INT64)) return EnumSet.of(DType.INT64); + else if (types.contains(DType.INT32)) return EnumSet.of(DType.INT32); + + throw new IllegalStateException( + "Expected at least one numeric dtype for range(), but got: " + types + "."); + } + + @Override + protected Set>> getDefaultShapes(PropagationCallGraphBuilder builder) { + throw new UnsupportedOperationException( + "Shapes for range() are derived from mandatory numeric arguments and must be provided" + + " explicitly."); + } + + @Override + protected int getValueNumberForShapeArgument() { + throw new UnsupportedOperationException( + "Range does not have a shape argument. Its shape is derived from the numeric arguments."); + } + + @Override + protected int getValueNumberForDTypeArgument() { + // TODO: We need a value number for the dtype argument. Also, that value number can differ + // depending on the version of the `range` function being called. + + return -1; // Positional dtype argument for range() is not yet implemented. 
+ } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java new file mode 100644 index 000000000..c619867cf --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGenerator.java @@ -0,0 +1,581 @@ +package com.ibm.wala.cast.python.ml.client; + +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.INT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.STRING; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.FIELD_REFERENCE_TO_DTYPE; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.TENSORFLOW; +import static com.ibm.wala.cast.python.types.PythonTypes.list; +import static com.ibm.wala.cast.python.util.Util.getAllocationSiteInNode; +import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom; +import static com.ibm.wala.ipa.callgraph.propagation.cfa.CallStringContextSelector.CALL_STRING; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; + +import com.ibm.wala.cast.ipa.callgraph.AstPointerKeyFactory; +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes; +import com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType; +import com.ibm.wala.cast.python.ml.types.TensorType; +import com.ibm.wala.cast.python.ml.types.TensorType.Dimension; +import com.ibm.wala.cast.python.ml.types.TensorType.NumericDim; +import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.classLoader.IClass; +import com.ibm.wala.classLoader.IField; +import com.ibm.wala.classLoader.IMethod; +import com.ibm.wala.classLoader.NewSiteReference; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.ContextItem; +import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode; +import 
com.ibm.wala.ipa.callgraph.propagation.ConstantKey; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; +import com.ibm.wala.ipa.callgraph.propagation.cfa.CallString; +import com.ibm.wala.types.Descriptor; +import com.ibm.wala.types.FieldReference; +import com.ibm.wala.types.MethodReference; +import com.ibm.wala.types.TypeReference; +import com.ibm.wala.util.collections.HashSetFactory; +import com.ibm.wala.util.intset.OrdinalSet; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.logging.Logger; + +public abstract class TensorGenerator { + + protected static final Logger LOGGER = Logger.getLogger(TensorGenerator.class.getName()); + + private static final MethodReference IMPORT = + MethodReference.findOrCreate( + TENSORFLOW, + findOrCreateAsciiAtom("import"), + Descriptor.findOrCreate(null, TENSORFLOW.getName())); + + protected PointsToSetVariable source; + + protected CGNode node; + + public TensorGenerator(PointsToSetVariable source, CGNode node) { + this.source = source; + this.node = node; + } + + public Set getTensorTypes(PropagationCallGraphBuilder builder) { + Set>> shapes = getShapes(builder); + EnumSet dTypes = getDTypes(builder); + + Set ret = HashSetFactory.make(); + + // Create a tensor type for each possible shape and dtype combination. + for (List> dimensionList : shapes) + for (DType dtype : dTypes) ret.add(new TensorType(dtype.name().toLowerCase(), dimensionList)); + + return ret; + } + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. 
+ * @param pointsToSet The points-to set of the shape argument. + * @return A set of possible shapes of the tensor returned by this generator. + */ + protected Set>> getShapesFromShapeArgument( + PropagationCallGraphBuilder builder, Iterable pointsToSet) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey instanceKey : pointsToSet) { + AllocationSiteInNode asin = getAllocationSiteInNode(instanceKey); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { // TODO: This can also be a tuple of tensors. + // We have a list of integers that represent the shape. + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + // We expect the object catalog to contain a list of integers. Each element in the array + // corresponds to the set of possible dimensions for that index. + @SuppressWarnings({"unchecked", "rawtypes"}) + Set>[] possibleDimensions = new Set[objectCatalogPointsToSet.size()]; + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, findOrCreateAsciiAtom(fieldIndex.toString()), PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + // We can now get the pointer key for the instance field. + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine("Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + // Get the points-to set for the instance field. 
+ OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + // If the instance field points to a constant, we can use it as the shape. + // TODO: Is it possible to also do it for (simple) expressions? + Set> tensorDimensions = HashSetFactory.make(); + + for (InstanceKey instanceFieldIK : instanceFieldPointsToSet) { + if (instanceFieldIK instanceof ConstantKey) { + // We have a constant key. + ConstantKey instanceFieldConstant = (ConstantKey) instanceFieldIK; + Object instanceFieldValue = instanceFieldConstant.getValue(); + + // We have a shape value. + Long shapeValue = (Long) instanceFieldValue; + LOGGER.fine( + "Found shape value: " + shapeValue + " for " + source.getPointerKey() + "."); + + Dimension dimension = new NumericDim(shapeValue.intValue()); + + LOGGER.fine("Adding dimension: " + dimension + "."); + tensorDimensions.add(dimension); + } else + throw new IllegalStateException( + "Expected a constant key for instance field: " + + pointerKeyForInstanceField + + ", but got: " + + instanceFieldIK + + "."); + } + + LOGGER.info( + "Found possible shape dimensions: " + + tensorDimensions + + " for field: " + + pointerKeyForInstanceField + + " for source: " + + source + + "."); + + // Add the shape dimensions. 
+ assert possibleDimensions[fieldIndex] == null + : "Duplicate field index: " + + fieldIndex + + " in object catalog: " + + objectCatalogPointsToSet + + "."; + + possibleDimensions[fieldIndex] = tensorDimensions; + LOGGER.fine( + "Added shape dimensions: " + + tensorDimensions + + " for field index: " + + fieldIndex + + "."); + } + + for (int i = 0; i < possibleDimensions.length; i++) + for (Dimension iDim : possibleDimensions[i]) { + @SuppressWarnings({"unchecked", "rawtypes"}) + Dimension[] dimensions = new Dimension[possibleDimensions.length]; + + dimensions[i] = iDim; + + for (int j = 0; j < possibleDimensions.length; j++) + if (i != j) + for (Dimension jDim : possibleDimensions[j]) dimensions[j] = jDim; + + ret.add(asList(dimensions)); + } + } else + throw new IllegalStateException( + "Expected a " + PythonTypes.list + " for the shape, but got: " + reference + "."); + } + + return ret; + } + + /** + * Returns the default shapes if no shape argument is provided. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @return The default shapes if no shape argument is provided. + */ + protected abstract Set>> getDefaultShapes(PropagationCallGraphBuilder builder); + + /** + * Returns the value number for the shape argument in the function call. A return value of a + * number less than or equal to zero signifies that there is no shape parameter. + * + * @return The value number for the shape argument in the function call. May return a number less + * than or equal to 0 if there is no shape parameter. + */ + protected abstract int getValueNumberForShapeArgument(); + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. 
+ * @return a set of shapes, where each shape is represented as a list of dimensions + */ + protected Set>> getShapes(PropagationCallGraphBuilder builder) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + // Get the shape from the explicit argument. + // FIXME: Handle keyword arguments. + int shapeArgValueNum = this.getValueNumberForShapeArgument(); + OrdinalSet pointsToSet = null; + + if (shapeArgValueNum > 0) { + PointerKey pointerKey = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, shapeArgValueNum); + pointsToSet = pointerAnalysis.getPointsToSet(pointerKey); + } + + // If the argument shape is not specified. + if (pointsToSet == null || pointsToSet.isEmpty()) return getDefaultShapes(builder); + else + // The shape points-to set is non-empty, meaning that the shape was explicitly set. + return getShapesFromShapeArgument(builder, pointsToSet); + } + + /** + * Returns the possible shapes of the tensor returned by this generator. The shape is inferred + * from the argument represented by the given value number. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valueNumber The value number of the argument from which to infer the shape. + * @return A set of possible shapes of the tensor returned by this generator. + */ + protected Set>> getShapes( + PropagationCallGraphBuilder builder, int valueNumber) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + PointerKey valuePK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valueNumber); + OrdinalSet valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK); + return getShapesOfValue(builder, valuePointsToSet); + } + + /** + * Returns the possible shapes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param pointsToSet The points-to set of the value from which the shape will be derived. 
+ * @return A set of possible shapes of the tensor returned by this generator. + */ + private Set>> getShapesOfValue( + PropagationCallGraphBuilder builder, OrdinalSet valuePointsToSet) { + Set>> ret = HashSetFactory.make(); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey valueIK : valuePointsToSet) + if (valueIK instanceof ConstantKey) ret.add(emptyList()); // Scalar value. + else if (valueIK instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = getAllocationSiteInNode(valueIK); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + LOGGER.fine( + "The object catalog points-to set size is: " + objectCatalogPointsToSet.size() + "."); + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, + findOrCreateAsciiAtom(fieldIndex.toString()), + PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine( + "Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + Set>> shapesOfField = + getShapesOfValue(builder, instanceFieldPointsToSet); + + for (List> shapeList : shapesOfField) { + List> shape = new ArrayList<>(); + + shape.add(new NumericDim(objectCatalogPointsToSet.size())); + 
shape.addAll(shapeList); + + ret.add(shape); + } + } + } else throw new IllegalStateException("Unknown type reference: " + reference + "."); + } else + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for value, but got: " + valueIK + "."); + + return ret; + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param pointsToSet The points-to set of the dtype argument, which is expected to be a set of + * type literals. + * @return A set of possible dtypes of the tensor returned by this generator. + */ + protected EnumSet getDTypesFromDTypeArgument( + PropagationCallGraphBuilder builder, Iterable pointsToSet) { + EnumSet ret = EnumSet.noneOf(DType.class); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey instanceKey : pointsToSet) { + IClass concreteType = instanceKey.getConcreteType(); + TypeReference typeReference = concreteType.getReference(); + + if (typeReference.equals(TensorFlowTypes.D_TYPE)) { + // we have a dtype. + // let's see if it's float32. + Set importNodes = builder.getCallGraph().getNodes(IMPORT); + + // find the import node from this file. + Optional importNode = + importNodes.stream() + .filter( + in -> { + ContextItem contextItem = in.getContext().get(CALL_STRING); + CallString cs = (CallString) contextItem; + + // We expect the first method in the call string to be the import. + assert cs.getMethods().length == 1 + : "Expected a single method in the call string, but got: " + + cs.getMethods().length + + " for node: " + + in; + + IMethod method = cs.getMethods()[0]; + + CallString nodeCS = (CallString) node.getContext().get(CALL_STRING); + + // We expect the first method in the call string to be the import. 
+ assert nodeCS.getMethods().length == 1 + : "Expected a single method in the call string, but got: " + + nodeCS.getMethods().length + + " for node: " + + in; + + return method.equals(nodeCS.getMethods()[0]); + }) + .findFirst(); + + InstanceKey tensorFlowIK = + pointerAnalysis + .getHeapModel() + .getInstanceKeyForAllocation( + importNode.get(), NewSiteReference.make(0, TENSORFLOW)); + + // Check dtype literals. + boolean found = false; + + for (Entry entry : FIELD_REFERENCE_TO_DTYPE.entrySet()) { + FieldReference fieldRef = entry.getKey(); + DType dtype = entry.getValue(); + IField field = builder.getClassHierarchy().resolveField(fieldRef); + + PointerKey pk = + pointerAnalysis.getHeapModel().getPointerKeyForInstanceField(tensorFlowIK, field); + + for (InstanceKey ik : pointerAnalysis.getPointsToSet(pk)) + if (ik.equals(instanceKey)) { + ret.add(dtype); + LOGGER.info( + "Found dtype: " + + dtype + + " for source: " + + source + + " from dType: " + + instanceKey + + "."); + found = true; + } + } + + if (!found) throw new IllegalStateException("Unknown dtype: " + instanceKey + "."); + } else + throw new IllegalStateException( + "Expected a " + + TensorFlowTypes.D_TYPE + + " for the dtype, but got: " + + typeReference + + "."); + } + + return ret; + } + + /** + * Returns a set of possible dtypes of the tensor returned by this generator when an explicit + * dtype isn't provided as an argument. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @return The set of possible dtypes of the tensor returned by this generator when an explicit + * dtype isn't provided as an argument. + */ + protected abstract EnumSet getDefaultDTypes(PropagationCallGraphBuilder builder); + + /** + * Returns the value number for the dtype argument in the function call. + * + * @return The value number for the dtype argument in the function call or -1 if the dtype + * argument is not supported. 
+ */ + protected abstract int getValueNumberForDTypeArgument(); + + protected EnumSet getDTypes(PropagationCallGraphBuilder builder) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + int valNum = this.getValueNumberForDTypeArgument(); + OrdinalSet pointsToSet = null; + + if (valNum > 0) { + // The dtype is in an explicit argument. + // FIXME: Handle keyword arguments. + PointerKey pointerKey = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valNum); + pointsToSet = pointerAnalysis.getPointsToSet(pointerKey); + } + + // If the argument dtype is not specified. + if (pointsToSet == null || pointsToSet.isEmpty()) return getDefaultDTypes(builder); + else + // The dtype points-to set is non-empty, meaning that the dtype was explicitly set. + return getDTypesFromDTypeArgument(builder, pointsToSet); + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. The dtype is inferred + * from the argument represented by the given value number. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param valueNumber The value number of the argument from which to infer the dtype. + * @return A set of possible dtypes of the tensor returned by this generator. + */ + protected EnumSet getDTypes(PropagationCallGraphBuilder builder, int valueNumber) { + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + PointerKey valuePK = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, valueNumber); + OrdinalSet valuePointsToSet = pointerAnalysis.getPointsToSet(valuePK); + return getDTypesOfValue(builder, valuePointsToSet); + } + + /** + * Returns the possible dtypes of the tensor returned by this generator. The dtype is inferred + * from the given points-to set. + * + * @param builder The {@link PropagationCallGraphBuilder} used to build the call graph. + * @param pointsToSet The points-to set of the value from which the dtype will be derived. 
+ * @return A set of possible dtypes of the tensor returned by this generator. + */ + private EnumSet getDTypesOfValue( + PropagationCallGraphBuilder builder, OrdinalSet valuePointsToSet) { + EnumSet ret = EnumSet.noneOf(DType.class); + PointerAnalysis pointerAnalysis = builder.getPointerAnalysis(); + + for (InstanceKey valueIK : valuePointsToSet) + if (valueIK instanceof ConstantKey) { // It's a scalar value. + ConstantKey constantKey = (ConstantKey) valueIK; + Object value = constantKey.getValue(); + if (value instanceof Float || value instanceof Double) { + ret.add(FLOAT32); + LOGGER.info( + "Inferred dtype: " + + FLOAT32 + + " for source: " + + source + + " from value: " + + value + + "."); + } else if (value instanceof Integer || value instanceof Long) { + ret.add(INT32); + LOGGER.info( + "Inferred dtype: " + + INT32 + + " for source: " + + source + + " from value: " + + value + + "."); + } else if (value instanceof String) { + ret.add(STRING); + LOGGER.info( + "Inferred dtype: " + + STRING + + " for source: " + + source + + " from value: " + + value + + "."); + } else throw new IllegalStateException("Unknown constant type: " + value.getClass() + "."); + } else if (valueIK instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = getAllocationSiteInNode(valueIK); + TypeReference reference = asin.getConcreteType().getReference(); + + if (reference.equals(list)) { + OrdinalSet objectCatalogPointsToSet = + pointerAnalysis.getPointsToSet( + ((AstPointerKeyFactory) builder.getPointerKeyFactory()) + .getPointerKeyForObjectCatalog(asin)); + + LOGGER.fine( + "The object catalog points-to set size is: " + objectCatalogPointsToSet.size() + "."); + + for (InstanceKey catalogIK : objectCatalogPointsToSet) { + ConstantKey constantKey = (ConstantKey) catalogIK; + Object constantKeyValue = constantKey.getValue(); + + Integer fieldIndex = (Integer) constantKeyValue; + + FieldReference subscript = + FieldReference.findOrCreate( + PythonTypes.Root, + 
findOrCreateAsciiAtom(fieldIndex.toString()), + PythonTypes.Root); + + IField f = builder.getClassHierarchy().resolveField(subscript); + LOGGER.fine("Found field: " + f); + + PointerKey pointerKeyForInstanceField = builder.getPointerKeyForInstanceField(asin, f); + LOGGER.fine( + "Found pointer key for instance field: " + pointerKeyForInstanceField + "."); + + OrdinalSet instanceFieldPointsToSet = + pointerAnalysis.getPointsToSet(pointerKeyForInstanceField); + LOGGER.fine("Points-to set for instance field: " + instanceFieldPointsToSet + "."); + + EnumSet dTypesOfField = getDTypesOfValue(builder, instanceFieldPointsToSet); + ret.addAll(dTypesOfField); + } + } else throw new IllegalStateException("Unknown type reference: " + reference + "."); + } else + // TODO: More cases. + throw new IllegalStateException( + "Expected a " + ConstantKey.class + " for value, but got: " + valueIK + "."); + return ret; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java new file mode 100644 index 000000000..e1d58f451 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/TensorGeneratorFactory.java @@ -0,0 +1,82 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.cast.types.AstMethodReference; +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; +import com.ibm.wala.types.MethodReference; +import com.ibm.wala.types.TypeName; +import com.ibm.wala.types.TypeReference; +import java.util.logging.Logger; + +public class TensorGeneratorFactory { + + private static final Logger LOGGER = Logger.getLogger(TensorGeneratorFactory.class.getName()); + + 
/** https://www.tensorflow.org/api_docs/python/tf/ones. */ + private static final MethodReference ONES = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/ones")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/constant. */ + private static final MethodReference CONSTANT = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/constant")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/range. */ + private static final MethodReference RANGE = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/range")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/random/uniform. */ + private static final MethodReference UNIFORM = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/uniform")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/zeros. */ + private static final MethodReference ZEROS = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, TypeName.string2TypeName("Ltensorflow/functions/zeros")), + AstMethodReference.fnSelector); + + /** https://www.tensorflow.org/api_docs/python/tf/zeros_like. */ + private static final MethodReference ZEROS_LIKE = + MethodReference.findOrCreate( + TypeReference.findOrCreate( + PythonTypes.pythonLoader, + TypeName.string2TypeName("Ltensorflow/functions/zeros_like")), + AstMethodReference.fnSelector); + + public static TensorGenerator getGenerator(PointsToSetVariable source) { + // Get the pointer key for the source. 
+ PointerKey pointerKey = source.getPointerKey(); + + LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; + CGNode node = localPointerKey.getNode(); + + TypeReference calledFunction = node.getMethod().getDeclaringClass().getReference(); + LOGGER.info("Getting tensor generator for call to: " + calledFunction.getName() + "."); + + if (calledFunction.equals(ONES.getDeclaringClass())) return new Ones(source, node); + else if (calledFunction.equals(CONSTANT.getDeclaringClass())) return new Constant(source, node); + else if (calledFunction.equals(RANGE.getDeclaringClass())) return new Range(source, node); + else if (calledFunction.equals(UNIFORM.getDeclaringClass())) return new Uniform(source, node); + else if (calledFunction.equals(ZEROS.getDeclaringClass())) return new Zeros(source, node); + else if (calledFunction.equals(ZEROS_LIKE.getDeclaringClass())) + return new ZerosLike(source, node); + else + throw new IllegalArgumentException( + "Unknown call: " + calledFunction + " for source: " + source + "."); + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Uniform.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Uniform.java new file mode 100644 index 000000000..99175fb73 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Uniform.java @@ -0,0 +1,25 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; + +/** + * A generator for tensors created by the `uniform()` function in TensorFlow. + * + * @see TensorFlow uniform() + * API. 
+ * @author Raffi Khatchadourian + */ +public class Uniform extends Ones { + + private static final int VALUE_NUMBER_FOR_DTYPE_ARGUMENT = 5; + + public Uniform(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected int getValueNumberForDTypeArgument() { + return VALUE_NUMBER_FOR_DTYPE_ARGUMENT; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Zeros.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Zeros.java new file mode 100644 index 000000000..659ab3b76 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/Zeros.java @@ -0,0 +1,17 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; + +/** + * A generator for tensors created by the `zeros()` function in TensorFlow. + * + * @see TensorFlow zeros() API. + * @author Raffi Khatchadourian + */ +public class Zeros extends Ones { + + public Zeros(PointsToSetVariable source, CGNode node) { + super(source, node); + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/ZerosLike.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/ZerosLike.java new file mode 100644 index 000000000..2eade7a19 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/ZerosLike.java @@ -0,0 +1,29 @@ +package com.ibm.wala.cast.python.ml.client; + +import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; + +/** + * A generator for tensors created by the `zeros_like()` function in TensorFlow. + * + * @see TensorFlow zeros_like() + * API. + * @author Raffi Khatchadourian + */ +public class ZerosLike extends Constant { + + /** + * The shape argument is not explicitly provided to zeros_like(); rather, the shape is inferred + * from the `input` argument. 
+ */ + private static final int VALUE_NUMBER_FOR_SHAPE_ARGUMENT = -1; + + public ZerosLike(PointsToSetVariable source, CGNode node) { + super(source, node); + } + + @Override + protected int getValueNumberForShapeArgument() { + return VALUE_NUMBER_FOR_SHAPE_ARGUMENT; + } +} diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java index dbe791c2b..437f3374a 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java @@ -1,8 +1,14 @@ package com.ibm.wala.cast.python.ml.types; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.FLOAT32; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DType.INT32; +import static com.ibm.wala.core.util.strings.Atom.findOrCreateAsciiAtom; + import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.types.FieldReference; import com.ibm.wala.types.TypeName; import com.ibm.wala.types.TypeReference; +import java.util.Map; /** * Types found in the TensorFlow library. @@ -11,11 +17,59 @@ */ public class TensorFlowTypes extends PythonTypes { + /** + * Defined data types used in TensorFlow. + * + * @see TensorFlow + * dtypes. + */ + public enum DType { + FLOAT32, + FLOAT64, + INT32, + INT64, + STRING; + } + public static final TypeReference TENSORFLOW = TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow")); public static final TypeReference DATASET = TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/data/Dataset")); + /** + * Represents the TensorFlow data type. + * + * @see TensorFlow DType. 
+   */
+  public static final TypeReference D_TYPE =
+      TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/dtypes/DType"));
+
+  /**
+   * Represents the TensorFlow float32 data type.
+   *
+   * @see <a href="https://www.tensorflow.org/api_docs/python/tf#float32">TensorFlow
+   *     float32 DType</a>.
+   */
+  public static final FieldReference FLOAT_32 =
+      FieldReference.findOrCreate(
+          PythonTypes.Root, findOrCreateAsciiAtom(FLOAT32.name().toLowerCase()), D_TYPE);
+
+  /**
+   * Represents the TensorFlow int32 data type.
+   *
+   * @see <a href="https://www.tensorflow.org/api_docs/python/tf#int32">TensorFlow
+   *     int32 DType</a>.
+   */
+  public static final FieldReference INT_32 =
+      FieldReference.findOrCreate(
+          PythonTypes.Root, findOrCreateAsciiAtom(INT32.name().toLowerCase()), D_TYPE);
+
+  /** A mapping from a field reference to its associated {@link DType}, if any. */
+  public static final Map<FieldReference, DType> FIELD_REFERENCE_TO_DTYPE =
+      Map.of(FLOAT_32, FLOAT32, INT_32, INT32);
+
   private TensorFlowTypes() {}
 }
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add116.py b/com.ibm.wala.cast.python.test/data/tf2_test_add116.py
new file mode 100644
index 000000000..55ebbe20e
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_add116.py
@@ -0,0 +1,17 @@
+import tensorflow as tf
+
+
+def add(a, b):
+    assert a.shape == (1, 2), f"Expected shape (1, 2), got {a.shape}"
+    assert b.shape == (2, 2), f"Expected shape (2, 2), got {b.shape}"
+
+    assert a.dtype == tf.float32, f"Expected dtype float32, got {a.dtype}"
+    assert b.dtype == tf.float32, f"Expected dtype float32, got {b.dtype}"
+
+    return a + b
+
+
+c = add(tf.ones([1, 2], tf.float32), tf.ones([2, 2], tf.float32))
+
+assert c.shape == (2, 2), f"Expected shape (2, 2), got {c.shape}"
+assert c.dtype == tf.float32, f"Expected dtype float32, got {c.dtype}"
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add117.py b/com.ibm.wala.cast.python.test/data/tf2_test_add117.py
new file mode 100644
index 000000000..fe3255a50
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_add117.py
@@ -0,0 +1,14 @@
+import tensorflow as tf
+import 
random + + +def add(a, b): + return a + b + + +if random.random() < 0.5: + a = 1 +else: + a = 3 + +c = add(tf.ones([a, 2]), tf.ones([2, 2])) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add118.py b/com.ibm.wala.cast.python.test/data/tf2_test_add118.py new file mode 100644 index 000000000..418491829 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add118.py @@ -0,0 +1,13 @@ +import tensorflow +from tensorflow.python.ops.array_ops import zeros + + +def add(a, b): + return a + b + + +arg = tensorflow.zeros([1, 2], tensorflow.int32) +assert arg.shape == (1, 2) +assert arg.dtype == tensorflow.int32 + +c = add(arg, zeros([2, 2], tensorflow.int32)) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add119.py b/com.ibm.wala.cast.python.test/data/tf2_test_add119.py new file mode 100644 index 000000000..53da9276c --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add119.py @@ -0,0 +1,12 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +arg = tf.zeros_like([1, 2], tf.float32) +assert arg.shape == (2,) +assert arg.dtype == tf.float32 + +c = add(arg, tf.zeros_like([2, 2], tf.float32)) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add27.py b/com.ibm.wala.cast.python.test/data/tf2_test_add27.py index 7dbdeb005..87f7d1a69 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_add27.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add27.py @@ -6,4 +6,8 @@ def add(a, b): return a + b -c = add(tensorflow.zeros([1, 2]), zeros([2, 2])) +arg = tensorflow.zeros([1, 2]) +assert arg.shape == (1, 2) +assert arg.dtype == tensorflow.float32 + +c = add(arg, zeros([2, 2])) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add32.py b/com.ibm.wala.cast.python.test/data/tf2_test_add32.py index 59a39751f..0c3e451c5 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_add32.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add32.py @@ -5,4 +5,8 @@ def add(a, b): return a + b -c = add(tf.zeros_like([1, 2]), 
tf.zeros_like([2, 2])) +arg = tf.zeros_like([1, 2]) +assert arg.shape == (2,) +assert arg.dtype == tf.int32 + +c = add(arg, tf.zeros_like([2, 2])) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_add7.py b/com.ibm.wala.cast.python.test/data/tf2_test_add7.py index a21250c2e..dc4eb2017 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_add7.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_add7.py @@ -2,7 +2,16 @@ def add(a, b): + assert a.shape == (1, 2), f"Expected shape (1, 2), got {a.shape}" + assert b.shape == (2, 2), f"Expected shape (2, 2), got {b.shape}" + + assert a.dtype == tf.float32, f"Expected dtype float32, got {a.dtype}" + assert b.dtype == tf.float32, f"Expected dtype float32, got {b.dtype}" + return a + b -c = add(tf.ones([1, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]] +c = add(tf.ones([1, 2]), tf.ones([2, 2])) + +assert c.shape == (2, 2), f"Expected shape (2, 2), got {c.shape}" +assert c.dtype == tf.float32, f"Expected dtype float32, got {c.dtype}" diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py new file mode 100644 index 000000000..8d6329d8f --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator12.py @@ -0,0 +1,14 @@ +import tensorflow as tf + + +@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),)) +@tf.function(reduce_retracing=True) +def returned(a): + return a + + +a = tf.constant([1, 1.0]) +b = returned(a) + +assert a.shape == (2,) +assert a.dtype == tf.float32 diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py index 3df833404..60531f3e7 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator2.py @@ -7,4 +7,8 @@ def returned(a): a = tf.range(5) + +assert a.shape == (5,) +assert a.dtype == tf.int32 + b = returned(a) diff --git 
a/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py b/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py index e65180b79..f7dff82fc 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_decorator3.py @@ -9,3 +9,6 @@ def returned(a): a = tf.constant([1.0, 1.0]) b = returned(a) + +assert a.shape == (2,) +assert a.dtype == tf.float32 diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function.py b/com.ibm.wala.cast.python.test/data/tf2_test_function.py index 05006b94c..c4e000163 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_function.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function.py @@ -8,7 +8,11 @@ def func2(t): @tf.function def func(): a = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + assert a.shape == (2, 2) + b = tf.constant([[1.0, 1.0], [0.0, 1.0]]) + assert b.shape == (2, 2) + c = tf.matmul(a, b) tensor = tf.experimental.numpy.ndarray(c.op, 0, tf.float32) func2(tensor) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function10.py b/com.ibm.wala.cast.python.test/data/tf2_test_function10.py new file mode 100644 index 000000000..51ce7319b --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function10.py @@ -0,0 +1,17 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant( + [ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]], + ] +) +assert a.shape == (2, 3, 4) +assert a.dtype == tf.int32 + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function11.py b/com.ibm.wala.cast.python.test/data/tf2_test_function11.py new file mode 100644 index 000000000..2e150de13 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function11.py @@ -0,0 +1,17 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant( + [ + [[1, 2, 3], [5, 6, 7], [9, 10, 11]], + [[13, 14, 15], [17, 18, 19], [21, 22, 23]], + ] +) +assert a.shape == (2, 3, 3) 
+assert a.dtype == tf.int32 + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function12.py b/com.ibm.wala.cast.python.test/data/tf2_test_function12.py new file mode 100644 index 000000000..cf82dfd35 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function12.py @@ -0,0 +1,21 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +a = None + +if n > 0.5: + a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + assert a.shape == (3, 2) +else: + a = tf.constant([[1.0], [3.0]]) + assert a.shape == (2, 1) + +assert a.shape == (3, 2) or a.shape == (2, 1) +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function13.py b/com.ibm.wala.cast.python.test/data/tf2_test_function13.py new file mode 100644 index 000000000..9149486d0 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function13.py @@ -0,0 +1,19 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +if n > 0.5: + a = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + assert a.shape == (3, 2) +else: + a = tf.constant([[1.0], [3.0]]) + assert a.shape == (2, 1) + +assert a.shape == (3, 2) or a.shape == (2, 1) +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function5.py b/com.ibm.wala.cast.python.test/data/tf2_test_function5.py new file mode 100644 index 000000000..375130982 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function5.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([[1.0, 2.0], [3.0, 4.0]]) +assert a.shape == (2, 2) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function6.py b/com.ibm.wala.cast.python.test/data/tf2_test_function6.py new file mode 100644 index 000000000..61c99f02a --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function6.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([[1.0], [3.0]]) 
+assert a.shape == (2, 1) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function7.py b/com.ibm.wala.cast.python.test/data/tf2_test_function7.py new file mode 100644 index 000000000..3acf6acff --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function7.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([1.0, 3.0]) +assert a.shape == (2,) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function8.py b/com.ibm.wala.cast.python.test/data/tf2_test_function8.py new file mode 100644 index 000000000..ae30e4842 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function8.py @@ -0,0 +1,19 @@ +import tensorflow as tf +from random import random + + +def func(t): + pass + + +n = random() + +if n > 0.5: + l = [[1.0], [3.0]] +else: + l = [1.0, 3.0] + +a = tf.constant(l) +assert a.shape == (2, 1) or a.shape == (2,) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_function9.py b/com.ibm.wala.cast.python.test/data/tf2_test_function9.py new file mode 100644 index 000000000..43c43b9e4 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_function9.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def func(t): + pass + + +a = tf.constant([[1.0, 3.0]]) +assert a.shape == (1, 2) + +func(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_model_call.py b/com.ibm.wala.cast.python.test/data/tf2_test_model_call.py index 5e5ffce77..3c7d88fea 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_model_call.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_model_call.py @@ -31,6 +31,8 @@ def __call__(self, x): input_data = tf.random.uniform([20, 28, 28]) +assert input_data.shape == (20, 28, 28) +assert input_data.dtype == tf.float32 model = SequentialModel() result = model(input_data) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_model_call6.py b/com.ibm.wala.cast.python.test/data/tf2_test_model_call6.py new file mode 100644 index 
000000000..36ece641b --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_model_call6.py @@ -0,0 +1,38 @@ +import tensorflow as tf + +# Create an override model to classify pictures + + +class SequentialModel(tf.keras.Model): + def __init__(self, **kwargs): + super(SequentialModel, self).__init__(**kwargs) + + self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28)) + + # Add a lot of small layers + num_layers = 100 + self.my_layers = [ + tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers) + ] + + self.dropout = tf.keras.layers.Dropout(0.2) + self.dense_2 = tf.keras.layers.Dense(10) + + def __call__(self, x): + x = self.flatten(x) + + for layer in self.my_layers: + x = layer(x) + + x = self.dropout(x) + x = self.dense_2(x) + + return x + + +input_data = tf.random.uniform([20, 28, 28], 0, 10, tf.int32) +assert input_data.shape == (20, 28, 28) +assert input_data.dtype == tf.int32 + +model = SequentialModel() +result = model(input_data) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py b/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py new file mode 100644 index 000000000..a60222983 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_static_method13.py @@ -0,0 +1,16 @@ +import tensorflow as tf + + +class MyClass: + + @staticmethod + def the_static_method(x): + assert isinstance(x, tf.Tensor) + + +a = tf.constant(1, tf.float32, (5,)) + +assert a.shape == (5,) +assert a.dtype == tf.float32 + +MyClass.the_static_method(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py b/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py new file mode 100644 index 000000000..2c2457989 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_static_method14.py @@ -0,0 +1,16 @@ +import tensorflow as tf + + +class MyClass: + + @staticmethod + def the_static_method(x): + assert isinstance(x, tf.Tensor) + + +a = tf.constant(1, tf.float32, ([1, 2])) + +assert 
a.shape == (1, 2) +assert a.dtype == tf.float32 + +MyClass.the_static_method(a) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py index 86167f7db..723ece779 100644 --- a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py @@ -7,5 +7,11 @@ def add(a, b): list = [tf.ones([1, 2]), tf.ones([2, 2])] +assert list[0].shape == (1, 2) +assert list[1].shape == (2, 2) + +assert list[0].dtype == tf.float32 +assert list[1].dtype == tf.float32 + for element in list: c = add(element, element)