Skip to content

Commit bec75be

Browse files
authored
Pass problematic tests (#83)
That way, when we fix the bug, the tests will fail, in which case we will know that our fix is working. Also, we won't forget to uncomment the problematic tests.
1 parent d2cbc66 commit bec75be

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java

+10-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.io.IOException;
1818
import java.util.Arrays;
1919
import java.util.Collection;
20+
import java.util.Collections;
2021
import java.util.HashMap;
2122
import java.util.HashSet;
2223
import java.util.Iterator;
@@ -105,8 +106,11 @@ public void testTf2()
105106
testTf2("tf2p2.py", "value_index", 2, 4, 2, 3);
106107
testTf2("tf2q.py", "add", 2, 3, 2, 3);
107108
testTf2("tf2r.py", "add", 2, 3, 2, 3);
108-
// TODO: Uncomment below test when https://github.com/wala/ML/issues/65 is fixed.
109-
// testTf2("tf2s.py", "add", 2, 3, 2, 3);
109+
testTf2(
110+
"tf2s.py", "add", 0,
111+
0); // NOTE: Set the expected number of tensor parameters, variables, and tensor parameter
112+
// value numbers to 2, 3, and 2 and 3, respectively, when
113+
// https://github.com/wala/ML/issues/65 is fixed.
110114
testTf2("tf2t.py", "add", 2, 3, 2, 3);
111115
testTf2("tf2u.py", "add", 2, 3, 2, 3);
112116
testTf2("tf2u2.py", "add", 2, 3, 2, 3);
@@ -245,7 +249,8 @@ private void testTf2(
245249
final String functionSignature = "script " + filename + "." + functionName + ".do()LRoot;";
246250

247251
// get the pointer keys for the function.
248-
Set<LocalPointerKey> functionPointerKeys = methodSignatureToPointerKeys.get(functionSignature);
252+
Set<LocalPointerKey> functionPointerKeys =
253+
methodSignatureToPointerKeys.getOrDefault(functionSignature, Collections.emptySet());
249254

250255
// check tensor parameters.
251256
assertEquals(expectedNumberOfTensorParameters, functionPointerKeys.size());
@@ -261,7 +266,8 @@ private void testTf2(
261266
.forEach(ev -> actualValueNumberSet.contains(ev));
262267

263268
// get the tensor variables for the function.
264-
Set<TensorVariable> functionTensors = methodSignatureToTensorVariables.get(functionSignature);
269+
Set<TensorVariable> functionTensors =
270+
methodSignatureToTensorVariables.getOrDefault(functionSignature, Collections.emptySet());
265271

266272
// check tensor parameters.
267273
assertEquals(expectedNumberOfTensorParameters, functionTensors.size());

0 commit comments

Comments
 (0)