Skip to content

Commit 7d1350c

Browse files
Merge branch 'master' of https://github.com/petablox/code2seq
2 parents 8846ebb + b617c8f commit 7d1350c

File tree

10 files changed

+702
-8
lines changed

10 files changed

+702
-8
lines changed

clean_and_split.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import glob
2+
import os
3+
import sys
4+
import random
5+
import javalang
6+
import numpy as np
7+
from tqdm import tqdm
8+
from math import ceil
9+
from shutil import copyfile as cp
10+
11+
# Fraction of the shuffled input files assigned to the training set.
TRAIN_SPLIT = 0.8
# Fraction of the shuffled input files assigned to the validation set; the
# files left over after the training and validation slices form the test set.
TEST_VAL_SPLIT = 0.1

# Minimum number of times a method name (or name sub-token) must be seen for
# it to be kept in the cleaned vocabulary.
MIN_NUM = 5
16+
17+
def copy_files(files, folder):
    """Copy each path in *files* into ``out_dir/folder`` as ``<index>.java``.

    The destination name is the file's position in *files* (``0.java``,
    ``1.java``, ...), not its original basename. ``out_dir`` is the
    module-level output directory.
    """
    for index, source_path in enumerate(files):
        destination = os.path.join(out_dir, folder, str(index) + ".java")
        cp(source_path, destination)
20+
21+
def add_to_method_map(m_name):
    """Count one more occurrence of *m_name* in the module-level ``m_names``.

    A name seen for the first time starts at 1; otherwise its existing count
    is incremented.
    """
    m_names[m_name] = m_names.get(m_name, 0) + 1
26+
27+
def get_all_methods(f_name):
    """Parse the Java file at *f_name* and return its method declarations.

    The file is read as raw bytes and handed to ``javalang.parse.parse``.
    The result is the list produced by filtering the parse tree for
    ``javalang.tree.MethodDeclaration`` (callers unpack each entry as a
    ``(path, node)`` pair). If lexing, parsing or filtering raises one of the
    errors listed below, the file is skipped and an empty list is returned.
    """
    with open(f_name, "rb") as source_file:
        raw_source = source_file.read()

    # Failures that mean "this file cannot be processed"; they are swallowed
    # on purpose so one bad file does not abort the whole run.
    skippable_errors = (
        javalang.parser.JavaSyntaxError,
        AttributeError,
        javalang.tokenizer.LexerError,
        TypeError,
        RecursionError,
        StopIteration,
    )

    try:
        compilation_unit = javalang.parse.parse(raw_source)
        return list(compilation_unit.filter(javalang.tree.MethodDeclaration))
    except skippable_errors:
        return []
40+
41+
def split_by_token(name):
    """Split an identifier into sub-tokens on camelCase and ``_`` boundaries.

    A new token starts at an upper-case character that directly follows a
    lower-case one (``getFooBar`` -> ``["get", "Foo", "Bar"]``) and at every
    underscore. Underscores are pure delimiters: they never appear in the
    returned tokens, and leading, trailing or repeated underscores do not
    produce empty tokens. Character case is preserved.

    Args:
        name: The identifier to split.

    Returns:
        The list of non-empty sub-tokens, in order; ``[]`` for an empty or
        underscore-only name.
    """
    tokens = []
    token = ""
    prev = ""

    for c in name:
        if c == "_":
            # Delimiter: close the current token and drop the underscore
            # itself so it does not leak into the next token.
            if token:
                tokens.append(token)
            token = ""
        elif c.isupper() and prev.islower() and token:
            # lower -> upper transition marks a camelCase boundary.
            tokens.append(token)
            token = c
        else:
            token += c

        prev = c

    if token:
        tokens.append(token)

    return tokens
61+
62+
# Command line: IN_DIR OUT_DIR MODE [NAME_MODE]
#   MODE       "split" copies the input files into training/test/validation
#              folders; "clean" counts method names and writes the frequent
#              ones to clean_names.txt.
#   NAME_MODE  only read in "clean" mode: "vec" counts each method name as a
#              whole, "seq" (the default) counts its sub-tokens.
if len(sys.argv) < 4:
    print("USAGE: python clean_and_split.py IN_DIR OUT_DIR split|clean [vec|seq]")
    # Stop here: the positional arguments read below are missing.
    sys.exit(1)

data_dir = sys.argv[1]
out_dir = sys.argv[2]

split_or_clean = sys.argv[3]
split, clean, vec = False, False, False

if split_or_clean == "split":
    split = True
elif split_or_clean == "clean":
    clean = True
    # Fall back to "seq" when the optional argument is omitted; that is the
    # behaviour every previous invocation effectively got.
    vec_or_seq = sys.argv[4] if len(sys.argv) > 4 else "seq"
    if vec_or_seq == "vec":
        vec = True
    elif vec_or_seq != "seq":
        print("command not accepted")
        sys.exit(1)
else:
    print("command not accepted")
    sys.exit(1)
81+
82+
83+
all_files = []
m_names = {}

# Collect every file below data_dir (recursively), whatever its extension.
for dirpath, _dirnames, filenames in os.walk(data_dir):
    all_files.extend(os.path.join(dirpath, _file) for _file in filenames)

if clean:
    for _file in tqdm(all_files):
        # Each entry is a (path, node) pair; only the declaration node is used.
        for _path, node in get_all_methods(_file):
            # "vec" mode counts the whole method name, otherwise its sub-tokens.
            names = [node.name] if vec else split_by_token(node.name)

            for name in names:
                add_to_method_map(name)

    # Keep only the names seen at least MIN_NUM times.
    m_clean = {k: v for k, v in m_names.items() if v >= MIN_NUM}
    print("total", len(m_names), "clean", len(m_clean))

    # One name per line. The encoding is explicit because Java identifiers may
    # contain non-ASCII characters that the platform default cannot encode.
    with open("clean_names.txt", "w", encoding="utf-8") as f:
        f.writelines(k + "\n" for k in m_clean)
107+
108+
109+
# TODO: clean the files here by writing each method to its own file?
110+
111+
if split:
    # Random file-level split: the first TRAIN_SPLIT share goes to training,
    # the next TEST_VAL_SPLIT share to validation, the remainder to test.
    random.shuffle(all_files)

    total = len(all_files)
    train_end = ceil(TRAIN_SPLIT * total)
    val_end = train_end + ceil(TEST_VAL_SPLIT * total)

    train = all_files[:train_end]
    val = all_files[train_end:val_end]
    test = all_files[val_end:]

    # Create each target folder unconditionally: out_dir may already exist
    # without its sub-folders, and may itself need missing parent directories.
    for folder in ("training", "test", "validation"):
        os.makedirs(os.path.join(out_dir, folder), exist_ok=True)

    copy_files(train, "training")
    copy_files(test, "test")
    copy_files(val, "validation")

config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def get_default_config(args):
55
config.NUM_EPOCHS = 3000
66
config.SAVE_EVERY_EPOCHS = 1
77
config.PATIENCE = 10
8-
config.BATCH_SIZE = 512
8+
config.BATCH_SIZE = 450
99
config.TEST_BATCH_SIZE = 256
1010
config.READER_NUM_PARALLEL_BATCHES = 1
1111
config.SHUFFLE_BUFFER_SIZE = 10000

file_level_split.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import glob
2+
import os
3+
import random
4+
from math import ceil
5+
from shutil import copyfile as cp
6+
7+
8+
# Root directory that is walked recursively for input files.
data_dir = "/data2/edinella/java-small-og/"
# Directory that receives the training/test/validation sub-folders.
out_dir = "/data2/edinella/java-small-og-fs/"

# Share of the shuffled files used for training, and for validation; the
# files left over after those two slices form the test set.
TRAIN_SPLIT = 0.8
TEST_VAL_SPLIT = 0.1
13+
14+
def copy_files(files, folder):
    """Copy each path in *files* into ``out_dir/folder`` as ``<index>.java``.

    The destination name is the file's position in *files* (``0.java``,
    ``1.java``, ...), not its original basename. ``out_dir`` is the
    module-level output directory.
    """
    for index, source_path in enumerate(files):
        destination = os.path.join(out_dir, folder, str(index) + ".java")
        cp(source_path, destination)
17+
18+
all_files = []

# Collect every file below data_dir (recursively), whatever its extension.
for dirpath, _dirnames, filenames in os.walk(data_dir):
    all_files.extend(os.path.join(dirpath, _file) for _file in filenames)

# Random file-level split: the first TRAIN_SPLIT share goes to training, the
# next TEST_VAL_SPLIT share to validation, the remainder to test.
random.shuffle(all_files)

total = len(all_files)
train_end = ceil(TRAIN_SPLIT * total)
val_end = train_end + ceil(TEST_VAL_SPLIT * total)

train = all_files[:train_end]
val = all_files[train_end:val_end]
test = all_files[val_end:]

# Create each target folder unconditionally: out_dir may already exist without
# its sub-folders, and may itself need missing parent directories.
for folder in ("training", "test", "validation"):
    os.makedirs(os.path.join(out_dir, folder), exist_ok=True)

copy_files(train, "training")
copy_files(test, "test")
copy_files(val, "validation")
45+

java-parser/pom.xml

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <!-- Maven coordinates of the produced jar. -->
  <groupId>method_analyzer</groupId>
  <artifactId>MethodLines</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <packaging>jar</packaging>

  <name>java-parser</name>
  <url>http://maven.apache.org</url>

  <!-- Compile for Java 8 and treat sources as UTF-8. -->
  <properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <build>
    <plugins>
      <!-- Records method_analyzer.MethodLines as the jar's main class. -->
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-jar-plugin</artifactId>
        <version>2.4</version>
        <configuration>
          <archive>
            <manifest>
              <mainClass>method_analyzer.MethodLines</mainClass>
            </manifest>
          </archive>
        </configuration>
      </plugin>
      <!-- Runs the shade goal during the package phase. -->
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-shade-plugin</artifactId>
        <version>3.2.1</version>
        <executions>
          <execution>
            <phase>package</phase>
            <goals>
              <goal>shade</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>commons-io</groupId>
      <artifactId>commons-io</artifactId>
      <version>2.2</version>
    </dependency>
    <!-- Test-scope only. -->
    <dependency>
      <groupId>junit</groupId>
      <artifactId>junit</artifactId>
      <version>3.8.1</version>
      <scope>test</scope>
    </dependency>
    <!-- Java source parser used by the method_analyzer classes. -->
    <dependency>
      <groupId>com.github.javaparser</groupId>
      <artifactId>javaparser-symbol-solver-core</artifactId>
      <version>3.15.6</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.8.6</version>
    </dependency>
  </dependencies>
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package method_analyzer;
2+
3+
import java.io.File;
4+
import java.io.FileNotFoundException;
5+
import java.io.FilenameFilter;
6+
import java.io.IOException;
7+
import java.io.PrintStream;
8+
import java.nio.file.Files;
9+
import java.nio.file.Path;
10+
import java.nio.file.Paths;
11+
import java.nio.file.StandardOpenOption;
12+
import java.util.ArrayList;
13+
14+
import com.github.javaparser.StaticJavaParser;
15+
import com.github.javaparser.ast.CompilationUnit;
16+
import com.github.javaparser.ast.DataKey;
17+
import com.github.javaparser.ast.Node;
18+
import com.github.javaparser.ast.body.MethodDeclaration;
19+
import com.github.javaparser.ast.visitor.VoidVisitorAdapter;
20+
21+
import method_analyzer.Utils;
22+
23+
@SuppressWarnings({"WeakerAccess", "unused"})
24+
public final class Common {
25+
26+
static File inputFile;
27+
static String outputPath = "";
28+
static String mSavePath = "";
29+
30+
static MethodDeclaration before;
31+
static MethodDeclaration after;
32+
33+
static final DataKey<Integer> VariableId = new DataKey<Integer>() {};
34+
static final DataKey<String> VariableName = new DataKey<String>() {};
35+
36+
static ArrayList<Path> getFilePaths(String rootPath) {
37+
ArrayList<Path> listOfPaths = new ArrayList<>();
38+
final FilenameFilter filter = (dir, name) -> dir.isDirectory() && name.toLowerCase().endsWith(".txt");
39+
File[] listOfFiles = new File(rootPath).listFiles(filter);
40+
if (listOfFiles == null) return new ArrayList<>();
41+
for (File file : listOfFiles) {
42+
Path codePath = Paths.get(file.getPath());
43+
listOfPaths.add(codePath);
44+
}
45+
return listOfPaths;
46+
}
47+
48+
public static void inspectSourceCode(Object obj, File javaFile) {
49+
}
50+
51+
static void setOutputPath(Object obj, File javaFile) {
52+
//assume '/transforms' in output path
53+
Common.mSavePath = Common.outputPath.replace("/transforms",
54+
"/transforms/"+obj.getClass().getSimpleName());
55+
}
56+
57+
static CompilationUnit getParseUnit(File javaFile) {
58+
CompilationUnit root = null;
59+
try {
60+
String txtCode = new String(Files.readAllBytes(javaFile.toPath()));
61+
root = StaticJavaParser.parse(txtCode);
62+
} catch (Exception ex) {
63+
/*
64+
System.out.println("\n" + "Exception: " + javaFile.getPath());
65+
ex.printStackTrace();*/
66+
String error_dir = Common.mSavePath + "java_parser_error.txt";
67+
Common.saveErrText(error_dir, javaFile);
68+
}
69+
return root;
70+
}
71+
72+
73+
static synchronized void saveTransformation(CompilationUnit aRoot) {
74+
aRoot.accept(new VoidVisitorAdapter<Void>() {
75+
@Override
76+
public void visit(MethodDeclaration md, Void arg) {
77+
int numberOfFiles = Utils.getNumberOfFiles(outputPath);
78+
String newCodePath = outputPath + numberOfFiles + ".java";
79+
Common.writeSourceCode(md, newCodePath);
80+
Utils.incrementNumberOfFiles(outputPath);
81+
}
82+
}, null);
83+
}
84+
85+
static void saveErrText(String error_dir, File javaFile) {
86+
try {
87+
File targetFile = new File(error_dir);
88+
if ((targetFile.getParentFile() != null) && (targetFile.getParentFile().exists() || targetFile.getParentFile().mkdirs())) {
89+
if (targetFile.exists() || targetFile.createNewFile()) {
90+
Files.write(Paths.get(error_dir),
91+
(javaFile.getPath() + "\n").getBytes(),
92+
StandardOpenOption.APPEND);
93+
}
94+
}
95+
} catch (IOException ioEx) {
96+
ioEx.printStackTrace();
97+
}
98+
}
99+
100+
static void writeSourceCode(MethodDeclaration md, String codePath) {
101+
102+
try (PrintStream ps = new PrintStream(codePath)) {
103+
if (md.isDefault()){
104+
md.setDefault(false);
105+
}
106+
107+
String tfSourceCode = md.toString();
108+
String surroundingClassDef = "class AABBCC { \n\n" + tfSourceCode + "\n\n}";
109+
ps.println(surroundingClassDef);
110+
111+
} catch (FileNotFoundException ex) {
112+
ex.printStackTrace();
113+
}
114+
}
115+
116+
}

0 commit comments

Comments
 (0)