Commit 70d9a55

Add ROUGE score calculation and Python150k reproduction.
1 parent be7b040 commit 70d9a55

7 files changed: +356 −3 lines


Python150kExtractor/README.md (+70 lines)

@@ -0,0 +1,70 @@
# Python150k dataset

## Steps to reproduce

1. Download the parsed Python dataset from [here](https://www.sri.inf.ethz.ch/py150), unarchive it, and place it under `PYTHON150K_DIR`:

```bash
# Replace with desired path.
>>> PYTHON150K_DIR=/path/to/data/dir
>>> mkdir -p $PYTHON150K_DIR
>>> cd $PYTHON150K_DIR
>>> wget http://files.srl.inf.ethz.ch/data/py150.tar.gz
...
>>> tar -xzvf py150.tar.gz
...
```

2. Extract samples to `DATA_DIR`:

```bash
# Replace with desired path.
>>> DATA_DIR=$(pwd)/data/default
>>> SEED=239
>>> python extract.py \
  --data_dir=$PYTHON150K_DIR \
  --output_dir=$DATA_DIR \
  --seed=$SEED
...
```

3. Preprocess for training:

```bash
>>> ./preprocess.sh $DATA_DIR
...
```

4. Train:

```bash
>>> cd ..
>>> DESC=default
>>> CUDA=0
>>> ./train_python150k.sh $DATA_DIR $DESC $CUDA $SEED
...
```

## Test results (seed=239)

### Best scores

**setup#2**: `batch_size=64`
**setup#3**: `embedding_size=256,use_momentum=False`
**setup#4**: `batch_size=32,embedding_size=256,embeddings_dropout_keep_prob=0.5,use_momentum=False`

| params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|---|---|
| default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 |
| setup#2 | 0.40 | 0.31 | 0.34 | 0.08 | 0.41 |
| setup#3 | 0.36 | 0.31 | 0.33 | 0.09 | 0.38 |
| setup#4 | 0.33 | 0.25 | 0.28 | 0.05 | 0.34 |

### Ablation studies

| params | Precision | Recall | F1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|---|---|
| default | 0.37 | 0.27 | 0.31 | 0.06 | 0.38 |
| no ast nodes (5th epoch) | 0.27 | 0.16 | 0.20 | 0.02 | 0.28 |
| no token split (4th epoch) | 0.60 | 0.09 | 0.15 | 0.00 | 0.60 |
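For orientation, the Precision/Recall/F1 columns above follow code2seq's subtoken-level scoring of predicted method names against references, while the ROUGE-2/ROUGE-L columns use the `rouge` package this commit adds as a dependency (see the sketch after the code2seq.py diff below). A minimal sketch of the subtoken-style metric, using made-up names; this is not the repository's exact evaluation code:

```python
# Hedged sketch of subtoken-level precision/recall/F1 over predicted vs.
# reference method names; the example data is hypothetical.
def subtoken_prf(predicted, reference):
    # predicted/reference: lists of subtoken lists, e.g. [['get', 'file', 'name'], ...]
    tp = fp = fn = 0
    for pred, ref in zip(predicted, reference):
        pred_set, ref_set = set(pred), set(ref)
        tp += len(pred_set & ref_set)
        fp += len(pred_set - ref_set)
        fn += len(ref_set - pred_set)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# One prediction 'get|file|name' against reference 'get|file|path'.
print(subtoken_prf([['get', 'file', 'name']], [['get', 'file', 'path']]))
```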

Python150kExtractor/extract.py (+193 lines)

@@ -0,0 +1,193 @@
import argparse
import re
import json
import multiprocessing
import itertools
import tqdm
import joblib
import numpy as np

from pathlib import Path
from sklearn import model_selection as sklearn_model_selection

METHOD_NAME, NUM = 'METHODNAME', 'NUM'

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', required=True, type=str)
parser.add_argument('--valid_p', type=float, default=0.2)
parser.add_argument('--max_path_length', type=int, default=8)
parser.add_argument('--max_path_width', type=int, default=2)
parser.add_argument('--use_method_name', type=bool, default=True)
parser.add_argument('--use_nums', type=bool, default=True)
parser.add_argument('--output_dir', required=True, type=str)
parser.add_argument('--n_jobs', type=int, default=multiprocessing.cpu_count())
parser.add_argument('--seed', type=int, default=239)


def __collect_asts(json_file):
    asts = []
    with open(json_file, 'r', encoding='utf-8') as f:
        for line in f:
            ast = json.loads(line.strip())
            asts.append(ast)

    return asts


def __terminals(ast, node_index, args):
    stack, paths = [], []

    def dfs(v):
        stack.append(v)

        v_node = ast[v]

        if 'value' in v_node:
            if v == node_index:  # Top-level func def node.
                if args.use_method_name:
                    paths.append((stack.copy(), METHOD_NAME))
            else:
                v_type = v_node['type']

                if v_type.startswith('Name'):
                    paths.append((stack.copy(), v_node['value']))
                elif args.use_nums and v_type == 'Num':
                    paths.append((stack.copy(), NUM))
                else:
                    pass

        if 'children' in v_node:
            for child in v_node['children']:
                dfs(child)

        stack.pop()

    dfs(node_index)

    return paths


def __merge_terminals2_paths(v_path, u_path):
    s, n, m = 0, len(v_path), len(u_path)
    while s < min(n, m) and v_path[s] == u_path[s]:
        s += 1

    prefix = list(reversed(v_path[s:]))
    lca = v_path[s - 1]
    suffix = u_path[s:]

    return prefix, lca, suffix


def __raw_tree_paths(ast, node_index, args):
    tnodes = __terminals(ast, node_index, args)

    tree_paths = []
    for (v_path, v_value), (u_path, u_value) in itertools.combinations(
            iterable=tnodes,
            r=2,
    ):
        prefix, lca, suffix = __merge_terminals2_paths(v_path, u_path)
        if (len(prefix) + 1 + len(suffix) <= args.max_path_length) \
                and (abs(len(prefix) - len(suffix)) <= args.max_path_width):
            path = prefix + [lca] + suffix
            tree_path = v_value, path, u_value
            tree_paths.append(tree_path)

    return tree_paths


def __delim_name(name):
    if name in {METHOD_NAME, NUM}:
        return name

    def camel_case_split(identifier):
        matches = re.finditer(
            '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)',
            identifier,
        )
        return [m.group(0) for m in matches]

    blocks = []
    for underscore_block in name.split('_'):
        blocks.extend(camel_case_split(underscore_block))

    return '|'.join(block.lower() for block in blocks)


def __collect_sample(ast, fd_index, args):
    root = ast[fd_index]
    if root['type'] != 'FunctionDef':
        raise ValueError('Wrong node type.')

    target = root['value']

    tree_paths = __raw_tree_paths(ast, fd_index, args)
    contexts = []
    for tree_path in tree_paths:
        start, connector, finish = tree_path

        start, finish = __delim_name(start), __delim_name(finish)
        connector = '|'.join(ast[v]['type'] for v in connector)

        context = f'{start},{connector},{finish}'
        contexts.append(context)

    if len(contexts) == 0:
        return None

    target = __delim_name(target)
    context = ' '.join(contexts)

    return f'{target} {context}'


def __collect_samples(ast, args):
    samples = []
    for node_index, node in enumerate(ast):
        if node['type'] == 'FunctionDef':
            sample = __collect_sample(ast, node_index, args)
            if sample is not None:
                samples.append(sample)

    return samples


def __collect_all_and_save(asts, args, output_file):
    parallel = joblib.Parallel(n_jobs=args.n_jobs)
    func = joblib.delayed(__collect_samples)

    samples = parallel(func(ast, args) for ast in tqdm.tqdm(asts))
    samples = list(itertools.chain.from_iterable(samples))

    with open(output_file, 'w') as f:
        for line_index, line in enumerate(samples):
            f.write(line + ('' if line_index == len(samples) - 1 else '\n'))


def main():
    args = parser.parse_args()
    np.random.seed(args.seed)

    data_dir = Path(args.data_dir)
    trains = __collect_asts(data_dir / 'python100k_train.json')
    evals = __collect_asts(data_dir / 'python50k_eval.json')

    train, valid = sklearn_model_selection.train_test_split(
        trains,
        test_size=args.valid_p,
    )
    test = evals

    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)
    for split_name, split in zip(
            ('train', 'valid', 'test'),
            (train, valid, test),
    ):
        output_file = output_dir / f'{split_name}_output_file.txt'
        __collect_all_and_save(split, args, output_file)


if __name__ == '__main__':
    main()
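To make the extracted sample format concrete: each line written by `__collect_all_and_save` is the `|`-joined target subtokens followed by space-separated contexts of the form `start,Type|Type|...,finish`. The snippet below mirrors the snake_case/camelCase splitting logic of `__delim_name` in isolation; the helper name `split_identifier` is invented for illustration:

```python
# Standalone sketch of the identifier splitting used in __delim_name;
# the function name here is hypothetical.
import re


def split_identifier(name):
    def camel_case_split(identifier):
        matches = re.finditer(
            '.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)',
            identifier,
        )
        return [m.group(0) for m in matches]

    blocks = []
    for underscore_block in name.split('_'):
        blocks.extend(camel_case_split(underscore_block))
    return '|'.join(block.lower() for block in blocks)


print(split_identifier('getFileName'))   # get|file|name
print(split_identifier('read_config'))   # read|config
```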

Python150kExtractor/preprocess.sh (+38 lines)

@@ -0,0 +1,38 @@
#!/usr/bin/env bash

MAX_CONTEXTS=200
MAX_DATA_CONTEXTS=1000
SUBTOKEN_VOCAB_SIZE=186277
TARGET_VOCAB_SIZE=26347

data_dir=${1:-data}
mkdir -p "${data_dir}"
train_data_file=$data_dir/train_output_file.txt
valid_data_file=$data_dir/valid_output_file.txt
test_data_file=$data_dir/test_output_file.txt

echo "Creating histograms from the training data..."
target_histogram_file=$data_dir/histo.tgt.c2s
source_subtoken_histogram=$data_dir/histo.ori.c2s
node_histogram_file=$data_dir/histo.node.c2s
cut <"${train_data_file}" -d' ' -f1 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${target_histogram_file}"
cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f1,3 | tr ',|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${source_subtoken_histogram}"
cut <"${train_data_file}" -d' ' -f2- | tr ' ' '\n' | cut -d',' -f2 | tr '|' '\n' | awk '{n[$0]++} END {for (i in n) print i,n[i]}' >"${node_histogram_file}"

echo "Preprocessing..."
python ../preprocess.py \
  --train_data "${train_data_file}" \
  --val_data "${valid_data_file}" \
  --test_data "${test_data_file}" \
  --max_contexts ${MAX_CONTEXTS} \
  --max_data_contexts ${MAX_DATA_CONTEXTS} \
  --subtoken_vocab_size ${SUBTOKEN_VOCAB_SIZE} \
  --target_vocab_size ${TARGET_VOCAB_SIZE} \
  --target_histogram "${target_histogram_file}" \
  --subtoken_histogram "${source_subtoken_histogram}" \
  --node_histogram "${node_histogram_file}" \
  --output_name "${data_dir}"/"$(basename "${data_dir}")"
rm \
  "${target_histogram_file}" \
  "${source_subtoken_histogram}" \
  "${node_histogram_file}"

README.md (+1 line)

@@ -38,6 +38,7 @@ Table of Contents
 > python3 -c 'import tensorflow as tf; print(tf.\_\_version\_\_)'
 * For [creating a new Java dataset](#creating-and-preprocessing-a-new-java-dataset) or [manually examining a trained model](#step-4-manual-examination-of-a-trained-model) (any operation that requires parsing of a new code example): [JDK](https://openjdk.java.net/install/)
 * For creating a C# dataset: [dotnet-core](https://dotnet.microsoft.com/download) version 2.2 or newer.
+* `pip install rouge` for computing ROUGE scores.
 
 ## Quickstart
 ### Step 0: Cloning this repository

code2seq.py (+8 −1 lines)

@@ -1,4 +1,6 @@
 from argparse import ArgumentParser
+import numpy as np
+import tensorflow as tf
 
 from config import Config
 from interactive_predict import InteractivePredictor
@@ -20,8 +22,12 @@
                              'size.')
     parser.add_argument('--predict', action='store_true')
     parser.add_argument('--debug', action='store_true')
+    parser.add_argument('--seed', type=int, default=239)
     args = parser.parse_args()
 
+    np.random.seed(args.seed)
+    tf.set_random_seed(args.seed)
+
     if args.debug:
         config = Config.get_debug_config(args)
     else:
@@ -32,9 +38,10 @@
     if config.TRAIN_PATH:
         model.train()
     if config.TEST_PATH and not args.data_path:
-        results, precision, recall, f1 = model.evaluate()
+        results, precision, recall, f1, rouge = model.evaluate()
         print('Accuracy: ' + str(results))
         print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1))
+        print('Rouge: ', rouge)
     if args.predict:
         predictor = InteractivePredictor(config, model)
         predictor.predict()
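The extra `rouge` value unpacked from `model.evaluate()` is produced by model code not shown in the files above; the root README change adds `pip install rouge` for it. A plausible sketch of how such a score could be computed with that package, using made-up predictions and references:

```python
# Hypothetical sketch of computing ROUGE with the `rouge` pip package;
# the actual model.evaluate() implementation is not shown in this excerpt.
from rouge import Rouge

predicted_names = ['get|file|name', 'save|model']   # example model outputs
reference_names = ['get|file|path', 'save|model']   # example ground truth

# ROUGE works on whitespace-separated tokens, so join subtokens with spaces.
hyps = [name.replace('|', ' ') for name in predicted_names]
refs = [name.replace('|', ' ') for name in reference_names]

scores = Rouge().get_scores(hyps, refs, avg=True)
print('Rouge: ', scores)   # averaged rouge-1, rouge-2 and rouge-l f/p/r values
```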
