Skip to content

Commit ddba21e

Browse files
authored
Support class methods (#187)
https://docs.python.org/3/library/functions.html#classmethod Includes an extract superclass refactoring between two trampoline target selectors. * Add comment. For #107. * Add javadoc.
1 parent 106d708 commit ddba21e

File tree

15 files changed

+690
-127
lines changed

15 files changed

+690
-127
lines changed

com.ibm.wala.cast.python.jython3/source/com/ibm/wala/cast/python/parser/PythonParser.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
*****************************************************************************/
1111
package com.ibm.wala.cast.python.parser;
1212

13+
import static com.ibm.wala.cast.python.util.Util.CLASS_METHOD_ANNOTATION_NAME;
1314
import static com.ibm.wala.cast.python.util.Util.DYNAMIC_ANNOTATION_KEY;
1415
import static com.ibm.wala.cast.python.util.Util.STATIC_METHOD_ANNOTATION_NAME;
1516
import static com.ibm.wala.cast.python.util.Util.getNameStream;
@@ -1263,8 +1264,12 @@ public int getKind() {
12631264
@Override
12641265
public CAstNode getAST() {
12651266
if (function instanceof FunctionDef) {
1266-
// Only add object metadata for non-static methods.
1267-
if (isMethod && !staticMethod) {
1267+
1268+
boolean classMethod =
1269+
getNameStream(annotations).anyMatch(s -> s.equals(CLASS_METHOD_ANNOTATION_NAME));
1270+
1271+
// Only add object metadata for non-static and non-class methods.
1272+
if (isMethod && !staticMethod && !classMethod) {
12681273
CAst Ast = PythonParser.this.Ast;
12691274

12701275
CAstNode[] newNodes = new CAstNode[nodes.length + 2];

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java

+109
Original file line numberDiff line numberDiff line change
@@ -3500,6 +3500,115 @@ public void testStaticMethod12() throws ClassHierarchyException, CancelException
35003500
expectedTensorParameterValueNumbers);
35013501
}
35023502

3503+
@Test
3504+
public void testClassMethod() throws ClassHierarchyException, CancelException, IOException {
3505+
int expectNumberofTensorParameters;
3506+
int expectedNumberOfTensorVariables;
3507+
int[] expectedTensorParameterValueNumbers;
3508+
3509+
// Class methods are only supported for Jython3.
3510+
if (usesJython3Testing()) {
3511+
expectNumberofTensorParameters = 1;
3512+
expectedNumberOfTensorVariables = 1;
3513+
expectedTensorParameterValueNumbers = new int[] {3};
3514+
} else {
3515+
// NOTE: Remove this case once https://github.com/wala/ML/issues/147 is fixed.
3516+
expectNumberofTensorParameters = 1;
3517+
expectedNumberOfTensorVariables = 1;
3518+
expectedTensorParameterValueNumbers = new int[] {2};
3519+
}
3520+
3521+
test(
3522+
"tf2_test_class_method.py",
3523+
"MyClass.the_class_method",
3524+
expectNumberofTensorParameters,
3525+
expectedNumberOfTensorVariables,
3526+
expectedTensorParameterValueNumbers);
3527+
}
3528+
3529+
@Test
3530+
public void testClassMethod2() throws ClassHierarchyException, CancelException, IOException {
3531+
test("tf2_test_class_method2.py", "MyClass.the_class_method", 1, 1, 3);
3532+
}
3533+
3534+
@Test
3535+
public void testClassMethod3() throws ClassHierarchyException, CancelException, IOException {
3536+
int expectNumberofTensorParameters;
3537+
int expectedNumberOfTensorVariables;
3538+
int[] expectedTensorParameterValueNumbers;
3539+
3540+
// Class methods are only supported for Jython3.
3541+
if (usesJython3Testing()) {
3542+
expectNumberofTensorParameters = 1;
3543+
expectedNumberOfTensorVariables = 1;
3544+
expectedTensorParameterValueNumbers = new int[] {3};
3545+
} else {
3546+
// NOTE: Remove this case once https://github.com/wala/ML/issues/147 is fixed.
3547+
expectNumberofTensorParameters = 0;
3548+
expectedNumberOfTensorVariables = 0;
3549+
expectedTensorParameterValueNumbers = new int[] {};
3550+
}
3551+
3552+
test(
3553+
"tf2_test_class_method3.py",
3554+
"MyClass.f",
3555+
expectNumberofTensorParameters,
3556+
expectedNumberOfTensorVariables,
3557+
expectedTensorParameterValueNumbers);
3558+
}
3559+
3560+
@Test
3561+
public void testClassMethod4() throws ClassHierarchyException, CancelException, IOException {
3562+
int expectNumberofTensorParameters;
3563+
int expectedNumberOfTensorVariables;
3564+
int[] expectedTensorParameterValueNumbers;
3565+
3566+
// Class methods are only supported for Jython3.
3567+
if (usesJython3Testing()) {
3568+
expectNumberofTensorParameters = 1;
3569+
expectedNumberOfTensorVariables = 1;
3570+
expectedTensorParameterValueNumbers = new int[] {3};
3571+
} else {
3572+
// NOTE: Remove this case once https://github.com/wala/ML/issues/147 is fixed.
3573+
expectNumberofTensorParameters = 0;
3574+
expectedNumberOfTensorVariables = 0;
3575+
expectedTensorParameterValueNumbers = new int[] {};
3576+
}
3577+
3578+
test(
3579+
"tf2_test_class_method4.py",
3580+
"MyClass.f",
3581+
expectNumberofTensorParameters,
3582+
expectedNumberOfTensorVariables,
3583+
expectedTensorParameterValueNumbers);
3584+
}
3585+
3586+
@Test
3587+
public void testClassMethod5() throws ClassHierarchyException, CancelException, IOException {
3588+
int expectNumberofTensorParameters;
3589+
int expectedNumberOfTensorVariables;
3590+
int[] expectedTensorParameterValueNumbers;
3591+
3592+
// Class methods are only supported for Jython3.
3593+
if (usesJython3Testing()) {
3594+
expectNumberofTensorParameters = 1;
3595+
expectedNumberOfTensorVariables = 1;
3596+
expectedTensorParameterValueNumbers = new int[] {3};
3597+
} else {
3598+
// NOTE: Remove this case once https://github.com/wala/ML/issues/147 is fixed.
3599+
expectNumberofTensorParameters = 0;
3600+
expectedNumberOfTensorVariables = 0;
3601+
expectedTensorParameterValueNumbers = new int[] {};
3602+
}
3603+
3604+
test(
3605+
"tf2_test_class_method5.py",
3606+
"MyClass.f",
3607+
expectNumberofTensorParameters,
3608+
expectedNumberOfTensorVariables,
3609+
expectedTensorParameterValueNumbers);
3610+
}
3611+
35033612
private void test(
35043613
String filename,
35053614
String functionName,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
class MyClass:
5+
6+
@classmethod
7+
def the_class_method(cls, x):
8+
assert isinstance(x, tf.Tensor)
9+
10+
11+
MyClass.the_class_method(tf.constant(1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
4+
class MyClass:
5+
6+
@classmethod
7+
def the_class_method(cls, x):
8+
assert isinstance(x, tf.Tensor)
9+
10+
11+
MyClass().the_class_method(tf.constant(1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import tensorflow as tf
2+
3+
4+
class MyClass:
5+
6+
def f(x):
7+
assert isinstance(x, tf.Tensor)
8+
9+
@classmethod
10+
def the_class_method(cls, x):
11+
assert isinstance(x, tf.Tensor)
12+
cls.f(x)
13+
14+
15+
MyClass().the_class_method(tf.constant(1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import tensorflow as tf
2+
3+
4+
class MyClass:
5+
6+
def f(x):
7+
assert isinstance(x, tf.Tensor)
8+
9+
@classmethod
10+
def the_class_method(cls, x):
11+
assert isinstance(x, tf.Tensor)
12+
cls.f(x)
13+
14+
15+
MyClass.the_class_method(tf.constant(1))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import tensorflow as tf
2+
3+
4+
class MyClass:
5+
6+
def f(x):
7+
assert isinstance(x, tf.Tensor)
8+
9+
@classmethod
10+
def the_class_method(cls, x):
11+
assert isinstance(x, tf.Tensor)
12+
cls.f(x)
13+
14+
15+
MyClass.the_class_method(tf.constant(1))

com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/client/PythonAnalysisEngine.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import com.ibm.wala.cast.ipa.callgraph.AstContextInsensitiveSSAContextInterpreter;
77
import com.ibm.wala.cast.ir.ssa.AstIRFactory;
88
import com.ibm.wala.cast.loader.AstDynamicField;
9+
import com.ibm.wala.cast.python.ipa.callgraph.PythonClassMethodTrampolineTargetSelector;
910
import com.ibm.wala.cast.python.ipa.callgraph.PythonConstructorTargetSelector;
11+
import com.ibm.wala.cast.python.ipa.callgraph.PythonInstanceMethodTrampolineTargetSelector;
1012
import com.ibm.wala.cast.python.ipa.callgraph.PythonSSAPropagationCallGraphBuilder;
1113
import com.ibm.wala.cast.python.ipa.callgraph.PythonScopeMappingInstanceKeys;
12-
import com.ibm.wala.cast.python.ipa.callgraph.PythonTrampolineTargetSelector;
1314
import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions;
1415
import com.ibm.wala.cast.python.ipa.summaries.PythonComprehensionTrampolines;
1516
import com.ibm.wala.cast.python.ipa.summaries.PythonSuper;
@@ -299,9 +300,10 @@ public boolean isReferenceType() {
299300

300301
protected void addBypassLogic(IClassHierarchy cha, AnalysisOptions options) {
301302
options.setSelector(
302-
new PythonTrampolineTargetSelector<T>(
303-
new PythonConstructorTargetSelector(
304-
new PythonComprehensionTrampolines(options.getMethodTargetSelector())),
303+
new PythonInstanceMethodTrampolineTargetSelector<T>(
304+
new PythonClassMethodTrampolineTargetSelector<T>(
305+
new PythonConstructorTargetSelector(
306+
new PythonComprehensionTrampolines(options.getMethodTargetSelector()))),
305307
this));
306308

307309
BuiltinFunctions builtins = new BuiltinFunctions(cha);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/******************************************************************************
2+
* Copyright (c) 2018 IBM Corporation.
3+
* All rights reserved. This program and the accompanying materials
4+
* are made available under the terms of the Eclipse Public License v1.0
5+
* which accompanies this distribution, and is available at
6+
* http://www.eclipse.org/legal/epl-v10.html
7+
*
8+
* Contributors:
9+
* IBM Corporation - initial API and implementation
10+
*****************************************************************************/
11+
package com.ibm.wala.cast.python.ipa.callgraph;
12+
13+
import static com.ibm.wala.cast.python.types.Util.getGlobalName;
14+
import static com.ibm.wala.cast.python.types.Util.makeGlobalRef;
15+
import static com.ibm.wala.cast.python.util.Util.isClassMethod;
16+
17+
import com.ibm.wala.cast.ir.ssa.AstGlobalRead;
18+
import com.ibm.wala.cast.loader.DynamicCallSiteReference;
19+
import com.ibm.wala.cast.python.ipa.summaries.PythonSummary;
20+
import com.ibm.wala.cast.python.ir.PythonLanguage;
21+
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
22+
import com.ibm.wala.cast.python.types.PythonTypes;
23+
import com.ibm.wala.classLoader.CallSiteReference;
24+
import com.ibm.wala.classLoader.IClass;
25+
import com.ibm.wala.core.util.strings.Atom;
26+
import com.ibm.wala.ipa.callgraph.CGNode;
27+
import com.ibm.wala.ipa.callgraph.MethodTargetSelector;
28+
import com.ibm.wala.ipa.cha.IClassHierarchy;
29+
import com.ibm.wala.ssa.SSAInstructionFactory;
30+
import com.ibm.wala.ssa.SSAReturnInstruction;
31+
import com.ibm.wala.types.FieldReference;
32+
import com.ibm.wala.util.collections.HashMapFactory;
33+
import com.ibm.wala.util.collections.Pair;
34+
import java.util.Map;
35+
import java.util.logging.Logger;
36+
37+
/**
38+
* A trampoline for <a href="https://docs.python.org/3/library/functions.html#classmethod">class
39+
* methods</a> that are not called using object instances.
40+
*
41+
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
42+
*/
43+
public class PythonClassMethodTrampolineTargetSelector<T>
44+
extends PythonMethodTrampolineTargetSelector<T> {
45+
46+
protected static final Logger LOGGER =
47+
Logger.getLogger(PythonClassMethodTrampolineTargetSelector.class.getName());
48+
49+
public PythonClassMethodTrampolineTargetSelector(MethodTargetSelector base) {
50+
super(base);
51+
}
52+
53+
@Override
54+
protected boolean shouldProcess(CGNode caller, CallSiteReference site, IClass receiver) {
55+
IClassHierarchy cha = receiver.getClassHierarchy();
56+
57+
// Are we calling a class method?
58+
boolean classMethodReceiver = isClassMethod(receiver);
59+
60+
// Is the caller a trampoline?
61+
boolean trampoline =
62+
caller
63+
.getMethod()
64+
.getSelector()
65+
.getName()
66+
.startsWith(Atom.findOrCreateAsciiAtom("trampoline"));
67+
68+
return classMethodReceiver
69+
&& !cha.isSubclassOf(receiver, cha.lookupClass(PythonTypes.trampoline))
70+
&& !trampoline;
71+
}
72+
73+
@SuppressWarnings("unchecked")
74+
@Override
75+
protected void populate(
76+
PythonSummary x, int v, IClass receiver, PythonInvokeInstruction call, Logger logger) {
77+
Map<Integer, Atom> names = HashMapFactory.make();
78+
SSAInstructionFactory insts = PythonLanguage.Python.instructionFactory();
79+
80+
// Read the class from the global scope.
81+
String globalName = getGlobalName(receiver.getReference());
82+
FieldReference globalRef = makeGlobalRef(receiver.getClassLoader(), globalName);
83+
int globalReadRes = v++;
84+
int pc = 0;
85+
86+
x.addStatement(new AstGlobalRead(pc++, globalReadRes, globalRef));
87+
88+
int getInstRes = v++;
89+
90+
// Read the field from the class corresponding to the called method.
91+
FieldReference method =
92+
FieldReference.findOrCreate(
93+
PythonTypes.Root, Atom.findOrCreateUnicodeAtom("the_class_method"), PythonTypes.Root);
94+
95+
x.addStatement(insts.GetInstruction(pc++, getInstRes, globalReadRes, method));
96+
97+
int i = 0;
98+
int paramSize = Math.max(2, call.getNumberOfPositionalParameters() + 1);
99+
int[] params = new int[paramSize];
100+
params[i++] = getInstRes;
101+
params[i++] = globalReadRes;
102+
103+
for (int j = 1; j < call.getNumberOfPositionalParameters(); j++) params[i++] = j + 1;
104+
105+
int ki = 0, ji = call.getNumberOfPositionalParameters() + 1;
106+
Pair<String, Integer>[] keys = new Pair[0];
107+
108+
if (call.getKeywords() != null) {
109+
keys = new Pair[call.getKeywords().size()];
110+
111+
for (String k : call.getKeywords()) {
112+
names.put(ji, Atom.findOrCreateUnicodeAtom(k));
113+
keys[ki++] = Pair.make(k, ji++);
114+
}
115+
}
116+
117+
CallSiteReference ref = new DynamicCallSiteReference(call.getCallSite().getDeclaredTarget(), 2);
118+
119+
int except = v++;
120+
int invokeResult = v++;
121+
122+
x.addStatement(new PythonInvokeInstruction(pc++, invokeResult, except, ref, params, keys));
123+
x.addStatement(new SSAReturnInstruction(pc++, invokeResult, false));
124+
x.setValueNames(names);
125+
}
126+
127+
@Override
128+
protected Logger getLogger() {
129+
return LOGGER;
130+
}
131+
}

0 commit comments

Comments
 (0)