Skip to content

Commit 9574624

Browse files
authored
Deal with multiple possible callables (#121)
* Return `null` when there are multiple possible callables. * Add test to exercise call string imprecision. Based on the call string length. See wala/WALA#1417 (reply in thread). * Expect the test to fail. In the past, we could add 0's to the parameters, but since we are not enforcing the existing of the node in the CG, we can no longer do that. Still, this test should now fail if wala#207 is fixed.
1 parent f328537 commit 9574624

File tree

7 files changed

+157
-1
lines changed

7 files changed

+157
-1
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

+35
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,41 @@ public void testModelCall4()
11811181
test("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 1, 3);
11821182
}
11831183

1184+
/**
1185+
* Test call string imprecision as described in
1186+
* https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. This should fail due
1187+
* to https://github.com/wala/ML/issues/207.
1188+
*/
1189+
@Test(expected = java.lang.AssertionError.class)
1190+
public void testModelCall5()
1191+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1192+
test(
1193+
new String[] {
1194+
"proj66/src/tf2_test_model_call5b.py",
1195+
"proj66/tf2_test_model_call5.py",
1196+
"proj66/tf2_test_model_call5a.py"
1197+
},
1198+
"tf2_test_model_call5.py",
1199+
"SequentialModel.__call__",
1200+
"proj66",
1201+
1,
1202+
1,
1203+
3);
1204+
1205+
test(
1206+
new String[] {
1207+
"proj66/src/tf2_test_model_call5b.py",
1208+
"proj66/tf2_test_model_call5.py",
1209+
"proj66/tf2_test_model_call5a.py"
1210+
},
1211+
"tf2_test_model_call5a.py",
1212+
"SequentialModel.__call__",
1213+
"proj66",
1214+
1,
1215+
1,
1216+
3);
1217+
}
1218+
11841219
@Test
11851220
public void testModelAttributes()
11861221
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

com.ibm.wala.cast.python.test/.pydevproject

+1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
<path>/${PROJECT_DIR_NAME}/data/proj52</path>
2121
<path>/${PROJECT_DIR_NAME}/data/proj55</path>
2222
<path>/${PROJECT_DIR_NAME}/data/proj56</path>
23+
<path>/${PROJECT_DIR_NAME}/data/proj66</path>
2324
</pydev_pathproperty>
2425
</pydev_project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
4+
def f(m, d):
5+
return m.predict(d)
6+
7+
8+
def g(m, d):
9+
return f(m, d)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
import tensorflow as tf
4+
from src.tf2_test_model_call5b import g
5+
6+
# Create an override model to classify pictures
7+
8+
9+
class SequentialModel(tf.keras.Model):
10+
11+
def __init__(self, **kwargs):
12+
super(SequentialModel, self).__init__(**kwargs)
13+
14+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
15+
16+
# Add a lot of small layers
17+
num_layers = 100
18+
self.my_layers = [
19+
tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers)
20+
]
21+
22+
self.dropout = tf.keras.layers.Dropout(0.2)
23+
self.dense_2 = tf.keras.layers.Dense(10)
24+
25+
def __call__(self, x):
26+
print("Raffi 1")
27+
x = self.flatten(x)
28+
29+
for layer in self.my_layers:
30+
x = layer(x)
31+
32+
x = self.dropout(x)
33+
x = self.dense_2(x)
34+
35+
return x
36+
37+
def predict(self, x):
38+
return self(x)
39+
40+
41+
input_data = tf.random.uniform([20, 28, 28])
42+
43+
model = SequentialModel()
44+
result = g(model, input_data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
import tensorflow as tf
4+
from src.tf2_test_model_call5b import g
5+
6+
# Create an override model to classify pictures
7+
8+
9+
class SequentialModel(tf.keras.Model):
10+
11+
def __init__(self, **kwargs):
12+
super(SequentialModel, self).__init__(**kwargs)
13+
14+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
15+
16+
# Add a lot of small layers
17+
num_layers = 100
18+
self.my_layers = [
19+
tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers)
20+
]
21+
22+
self.dropout = tf.keras.layers.Dropout(0.2)
23+
self.dense_2 = tf.keras.layers.Dense(10)
24+
25+
def __call__(self, x):
26+
print("Raffi 2")
27+
x = self.flatten(x)
28+
29+
for layer in self.my_layers:
30+
x = layer(x)
31+
32+
x = self.dropout(x)
33+
x = self.dense_2(x)
34+
35+
return x
36+
37+
def predict(self, x):
38+
return self(x)
39+
40+
41+
input_data = tf.random.uniform([20, 28, 28])
42+
43+
model = SequentialModel()
44+
result = g(model, input_data)

com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java

+23-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import com.ibm.wala.util.collections.HashMapFactory;
4343
import com.ibm.wala.util.collections.Pair;
4444
import com.ibm.wala.util.intset.OrdinalSet;
45+
import java.util.HashMap;
4546
import java.util.Map;
4647
import java.util.logging.Logger;
4748

@@ -222,6 +223,8 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr
222223
PointerKey receiver = pkf.getPointerKeyForLocal(caller, call.getUse(0));
223224
OrdinalSet<InstanceKey> objs = builder.getPointerAnalysis().getPointsToSet(receiver);
224225

226+
Map<InstanceKey, IClass> instanceToCallable = new HashMap<>();
227+
225228
for (InstanceKey o : objs) {
226229
AllocationSiteInNode instanceKey = getAllocationSiteInNode(o);
227230
if (instanceKey != null) {
@@ -253,10 +256,29 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr
253256
LOGGER.info("Applying callable workaround for https://github.com/wala/ML/issues/118.");
254257
}
255258

256-
if (callable != null) return callable;
259+
if (callable != null) {
260+
if (instanceToCallable.containsKey(instanceKey))
261+
throw new IllegalStateException("Exisitng mapping found for: " + instanceKey);
262+
263+
IClass previousValue = instanceToCallable.put(instanceKey, callable);
264+
assert previousValue == null : "Not expecting a previous mapping.";
265+
}
257266
}
258267
}
259268

269+
// if there's only one possible option.
270+
if (instanceToCallable.values().size() == 1) {
271+
IClass callable = instanceToCallable.values().iterator().next();
272+
assert callable != null : "Callable should be non-null.";
273+
return callable;
274+
}
275+
276+
// if we have multiple candidates.
277+
if (instanceToCallable.values().size() > 1)
278+
// we cannot accurately select one.
279+
LOGGER.warning(
280+
"Multiple (" + instanceToCallable.values().size() + ") callable targets found.");
281+
260282
return null;
261283
}
262284

0 commit comments

Comments
 (0)