Skip to content

Commit 948e6bf

Browse files
authored
Deal with multiple possible callables (#208)
* Return `null` when there are multiple possible callables. * Add test to exercise call string imprecision. Based on the call string length. See wala/WALA#1417 (reply in thread). * Expect the test to fail. In the past, we could add 0's to the parameters, but since we are not enforcing the existing of the node in the CG, we can no longer do that. Still, this test should now fail if #207 is fixed.
1 parent b8eccd0 commit 948e6bf

File tree

6 files changed

+153
-1
lines changed

6 files changed

+153
-1
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

+35
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,41 @@ public void testModelCall4()
12041204
test("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 1, 3);
12051205
}
12061206

1207+
/**
1208+
* Test call string imprecision as described in
1209+
* https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. This should fail due
1210+
* to https://github.com/wala/ML/issues/207.
1211+
*/
1212+
@Test(expected = java.lang.AssertionError.class)
1213+
public void testModelCall5()
1214+
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
1215+
test(
1216+
new String[] {
1217+
"proj66/src/tf2_test_model_call5b.py",
1218+
"proj66/tf2_test_model_call5.py",
1219+
"proj66/tf2_test_model_call5a.py"
1220+
},
1221+
"tf2_test_model_call5.py",
1222+
"SequentialModel.__call__",
1223+
"proj66",
1224+
1,
1225+
1,
1226+
3);
1227+
1228+
test(
1229+
new String[] {
1230+
"proj66/src/tf2_test_model_call5b.py",
1231+
"proj66/tf2_test_model_call5.py",
1232+
"proj66/tf2_test_model_call5a.py"
1233+
},
1234+
"tf2_test_model_call5a.py",
1235+
"SequentialModel.__call__",
1236+
"proj66",
1237+
1,
1238+
1,
1239+
3);
1240+
}
1241+
12071242
@Test
12081243
public void testModelAttributes()
12091244
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
4+
def f(m, d):
5+
return m.predict(d)
6+
7+
8+
def g(m, d):
9+
return f(m, d)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
import tensorflow as tf
4+
from src.tf2_test_model_call5b import g
5+
6+
# Create an override model to classify pictures
7+
8+
9+
class SequentialModel(tf.keras.Model):
10+
11+
def __init__(self, **kwargs):
12+
super(SequentialModel, self).__init__(**kwargs)
13+
14+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
15+
16+
# Add a lot of small layers
17+
num_layers = 100
18+
self.my_layers = [
19+
tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers)
20+
]
21+
22+
self.dropout = tf.keras.layers.Dropout(0.2)
23+
self.dense_2 = tf.keras.layers.Dense(10)
24+
25+
def __call__(self, x):
26+
print("Raffi 1")
27+
x = self.flatten(x)
28+
29+
for layer in self.my_layers:
30+
x = layer(x)
31+
32+
x = self.dropout(x)
33+
x = self.dense_2(x)
34+
35+
return x
36+
37+
def predict(self, x):
38+
return self(x)
39+
40+
41+
input_data = tf.random.uniform([20, 28, 28])
42+
43+
model = SequentialModel()
44+
result = g(model, input_data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680.
2+
3+
import tensorflow as tf
4+
from src.tf2_test_model_call5b import g
5+
6+
# Create an override model to classify pictures
7+
8+
9+
class SequentialModel(tf.keras.Model):
10+
11+
def __init__(self, **kwargs):
12+
super(SequentialModel, self).__init__(**kwargs)
13+
14+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
15+
16+
# Add a lot of small layers
17+
num_layers = 100
18+
self.my_layers = [
19+
tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers)
20+
]
21+
22+
self.dropout = tf.keras.layers.Dropout(0.2)
23+
self.dense_2 = tf.keras.layers.Dense(10)
24+
25+
def __call__(self, x):
26+
print("Raffi 2")
27+
x = self.flatten(x)
28+
29+
for layer in self.my_layers:
30+
x = layer(x)
31+
32+
x = self.dropout(x)
33+
x = self.dense_2(x)
34+
35+
return x
36+
37+
def predict(self, x):
38+
return self(x)
39+
40+
41+
input_data = tf.random.uniform([20, 28, 28])
42+
43+
model = SequentialModel()
44+
result = g(model, input_data)

com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java

+20-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
import com.ibm.wala.util.collections.HashMapFactory;
4444
import com.ibm.wala.util.collections.Pair;
4545
import com.ibm.wala.util.intset.OrdinalSet;
46+
import java.util.HashSet;
4647
import java.util.Map;
48+
import java.util.Set;
4749
import java.util.logging.Logger;
4850

4951
public class PythonInstanceMethodTrampolineTargetSelector<T>
@@ -87,6 +89,8 @@ protected boolean shouldProcess(CGNode caller, CallSiteReference site, IClass re
8789

8890
@Override
8991
public IMethod getCalleeTarget(CGNode caller, CallSiteReference site, IClass receiver) {
92+
// TODO: Callable detection may need to be moved. See https://github.com/wala/ML/issues/207. If
93+
// it stays here, we should further document the receiver swapping process.
9094
if (isCallable(receiver)) {
9195
LOGGER.fine("Encountered callable.");
9296

@@ -223,6 +227,9 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr
223227
PointerKey receiver = pkf.getPointerKeyForLocal(caller, call.getUse(0));
224228
OrdinalSet<InstanceKey> objs = builder.getPointerAnalysis().getPointsToSet(receiver);
225229

230+
// The set of potential callables to be returned.
231+
Set<IClass> callableSet = new HashSet<>();
232+
226233
for (InstanceKey o : objs) {
227234
AllocationSiteInNode instanceKey = getAllocationSiteInNode(o);
228235
if (instanceKey != null) {
@@ -254,10 +261,22 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr
254261
LOGGER.info("Applying callable workaround for https://github.com/wala/ML/issues/118.");
255262
}
256263

257-
if (callable != null) return callable;
264+
callableSet.add(callable);
258265
}
259266
}
260267

268+
// if there's only one possible option.
269+
if (callableSet.size() == 1) {
270+
IClass callable = callableSet.iterator().next();
271+
assert callable != null : "Callable should be non-null.";
272+
return callable;
273+
}
274+
275+
// if we have multiple candidates.
276+
if (callableSet.size() > 1)
277+
// we cannot accurately select one.
278+
LOGGER.warning("Multiple (" + callableSet.size() + ") callable targets found.");
279+
261280
return null;
262281
}
263282

0 commit comments

Comments
 (0)