Skip to content
This repository was archived by the owner on Jul 30, 2024. It is now read-only.

Commit fc13637

Browse files
committed
csnet-pretraining script and bug fix
1 parent d0c4c4e commit fc13637

File tree

15 files changed

+392
-46
lines changed

15 files changed

+392
-46
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Go to `data/stackoverflow` directory and follow instructions.
4242
```bash
4343
cd pretrain
4444
bash binarize.sh
45-
bash absolute.sh GPU_IDS
45+
bash pretrain.sh GPU_IDS
4646
```
4747

4848
Note. We pre-trained PLBART on 8 `GeForce RTX 2080` (11gb) GPUs (took 11.5 days).

evaluation/nl_eval.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import re
1818
import sys
19+
import json
1920
import math
21+
import argparse
2022
import xml.sax.saxutils
2123

2224
# Added to bypass NIST-style pre-processing of hyp and ref files -- wade
@@ -158,7 +160,7 @@ def splitPuncts(line):
158160
return ' '.join(re.findall(r"[\w]+|[^\s\w]", line))
159161

160162

161-
def computeMaps(prediction_file, goldfile):
163+
def computeMaps(prediction_file, goldfile, is_goldfile_json):
162164
predictionMap = {}
163165
goldMap = {}
164166
predictions = open(prediction_file, 'r', encoding='utf-8')
@@ -173,6 +175,8 @@ def computeMaps(prediction_file, goldfile):
173175
if rid in predictionMap: # Only insert if the id exists for the method
174176
if rid not in goldMap:
175177
goldMap[rid] = []
178+
if is_goldfile_json:
179+
row = ' '.join(json.loads(row.strip())['docstring_tokens'])
176180
goldMap[rid].append(splitPuncts(row.strip().lower()))
177181

178182
predictions.close()
@@ -196,16 +200,15 @@ def bleuFromMaps(m1, m2):
196200

197201

198202
if __name__ == '__main__':
199-
import argparse
200-
201203
parser = argparse.ArgumentParser(description='Evaluate leaderboard predictions for BigCloneBench dataset.')
202204
parser.add_argument('--references', help="filename of the labels, in txt format.")
203205
parser.add_argument('--predictions', help="filename of the leaderboard predictions, in txt format.")
206+
parser.add_argument('--json_refs', action='store_true', help='reference files are JSON files')
204207

205208
args = parser.parse_args()
206209

207210
reference_file = args.references
208211
prediction_file = args.predictions
209-
(goldMap, predictionMap) = computeMaps(prediction_file, reference_file)
212+
(goldMap, predictionMap) = computeMaps(prediction_file, reference_file, args.json_refs)
210213
res = bleuFromMaps(goldMap, predictionMap)
211214
print("BLEU Score:\t%.2f" % res[0])

multilingual/data/prepare.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@ fairseq-preprocess \
5858

5959
mkdir -p $DATA_DIR;
6060
PYTHONPATH=${HOME_DIR} python process.py;
61-
for lang in java python ruby go js php; do
61+
for lang in java python ruby go javascript php; do
6262
spm_preprocess $lang && binarize $lang
6363
done

multilingual/data/process.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,44 @@ def count_file_lines(file_path):
2020
def prepare():
2121
for lang in ['go', 'java', 'python', 'ruby', 'javascript', 'php']:
2222
for split in ['train', 'valid', 'test']:
23-
lang_iso = 'js' if lang == 'javascript' else lang
2423
src_writer = open(
25-
'processed/{}.{}-en_XX.{}'.format(split, lang_iso, lang_iso), 'w', encoding='utf-8'
24+
'processed/{}.{}-en_XX.{}'.format(split, lang, lang), 'w', encoding='utf-8'
2625
)
2726
tgt_writer = open(
28-
'processed/{}.{}-en_XX.en_XX'.format(split, lang_iso), 'w', encoding='utf-8'
27+
'processed/{}.{}-en_XX.en_XX'.format(split, lang), 'w', encoding='utf-8'
2928
)
3029
filename = '{}/{}.jsonl'.format(lang, split)
3130
with open(filename) as f:
3231
for line in tqdm(
3332
f, total=count_file_lines(filename), desc="{}-{}".format(lang, split)
3433
):
3534
ex = json.loads(line.strip())
35+
code = ' '.join(ex['code_tokens'])
36+
code = re.sub("[\n\r\t ]+", " ", code).strip()
37+
docstring = ' '.join(ex['docstring_tokens'])
38+
docstring = re.sub("[\n\r\t ]+", " ", docstring).strip()
39+
if len(code) == 0 or len(docstring) == 0:
40+
continue
41+
42+
tokenized_code = None
43+
if lang == 'python' or lang == 'java':
44+
_tokens = tokenize_python(ex['code']) \
45+
if lang == 'python' else tokenize_java(ex['code'])
46+
tokenized_code = ' '.join(_tokens)
47+
tokenized_code = re.sub("[\n\r\t ]+", " ", tokenized_code).strip()
48+
if len(tokenized_code) == 0:
49+
continue
50+
3651
try:
3752
if lang == 'python' or lang == 'java':
38-
code_tokens = tokenize_python(ex['code']) \
39-
if lang == 'python' else tokenize_java(ex['code'])
40-
if len(code_tokens) > 0:
41-
raise ValueError('Empty tokenized code')
53+
# this line can throw error `UnicodeEncodeError`
54+
src_writer.write(tokenized_code + '\n')
4255
else:
43-
code_tokens = ex['code_tokens']
56+
src_writer.write(code + '\n')
4457
except:
45-
code_tokens = ex['code_tokens']
58+
src_writer.write(code + '\n')
4659

47-
code = ' '.join(code_tokens)
48-
code = re.sub("[\n\r\t ]+", " ", code)
49-
docstring = ' '.join(ex['docstring_tokens'])
50-
docstring = re.sub("[\n\r\t ]+", " ", docstring)
51-
if len(code) > 0 and len(docstring) > 0:
52-
src_writer.write(code.strip() + '\n')
53-
tgt_writer.write(docstring.strip() + '\n')
60+
tgt_writer.write(docstring + '\n')
5461

5562
src_writer.close()
5663
tgt_writer.close()

multilingual/multi_task/run.sh

+5-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [[ $LANGUAGE_GROUP_CORRECT = "" ]] ; then
2020
exit;
2121
fi
2222

23-
GROUP_LISTS="java python go ruby php js"
23+
GROUP_LISTS="java python go ruby php javascript"
2424
LANGUAGE_GROUP_CORRECT=`echo $GROUP_LISTS | grep -w $LANG`;
2525
if [[ $LANGUAGE_GROUP_CORRECT = "" ]]; then
2626
echo "LANGUAGE(3rd parameter) must be one of the following";
@@ -193,21 +193,21 @@ then
193193
evaluate_summarization $LANG;
194194
else
195195
if [[ "$LANGUAGE_GROUP" == 'all' ]]; then
196-
languages=(java python js php ruby go);
196+
languages=(java python javascript php ruby go);
197197
MAX_UPDATE=400000;
198198
WARMUP=5000;
199199
elif [[ "$LANGUAGE_GROUP" == 'compiled' ]]; then
200200
languages=(java ruby go);
201201
elif [[ "$LANGUAGE_GROUP" == 'interpreted' ]]; then
202-
languages=(php python js);
202+
languages=(php python javascript);
203203
elif [[ "$LANGUAGE_GROUP" == 'static' ]]; then
204204
languages=(java go);
205205
elif [[ "$LANGUAGE_GROUP" == 'dynamic' ]]; then
206-
languages=(js python php ruby);
206+
languages=(javascript python php ruby);
207207
elif [[ "$LANGUAGE_GROUP" == 'strong' ]]; then
208208
languages=(java go python ruby);
209209
elif [[ "$LANGUAGE_GROUP" == 'weak' ]]; then
210-
languages=(php js);
210+
languages=(php javascript);
211211
fi
212212

213213
lang_pairs="";

multilingual/plbart/lang_dict.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
java
22
python
33
en_XX
4-
js
4+
javascript
55
php
66
ruby
77
go

multilingual/single_task/generation.sh

+8-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [[ $LANGUAGE_GROUP_CORRECT = "" ]] ; then
2020
exit;
2121
fi
2222

23-
GROUP_LISTS="java python go ruby php js"
23+
GROUP_LISTS="java python go ruby php javascript"
2424
LANGUAGE_GROUP_CORRECT=`echo $GROUP_LISTS | grep -w $LANG`;
2525
if [[ $LANGUAGE_GROUP_CORRECT = "" ]]; then
2626
echo "LANGUAGE(3rd parameter) must be one of the following";
@@ -42,16 +42,16 @@ BATCH_SIZE=8;
4242
UPDATE_FREQ=4;
4343

4444
# CSNET data size is as follows
45-
# java: 165k, python: 252k, php: 241k, go: 167k, js: 58k, ruby:25k
45+
# java: 165k, python: 252k, php: 241k, go: 167k, javascript: 58k, ruby:25k
4646
# So, number of mini-batches for each language would be:
47-
# java: ~5100, python: ~7800, php: ~7500, go: ~5200, js: ~1800, ruby: ~780
47+
# java: ~5100, python: ~7800, php: ~7500, go: ~5200, javascript: ~1800, ruby: ~780
4848

4949
declare -A LANG_WISE_WARMUP
5050
LANG_WISE_WARMUP['java']=5000
5151
LANG_WISE_WARMUP['python']=5000
5252
LANG_WISE_WARMUP['php']=5000
5353
LANG_WISE_WARMUP['go']=5000
54-
LANG_WISE_WARMUP['js']=2000
54+
LANG_WISE_WARMUP['javascript']=2000
5555
LANG_WISE_WARMUP['ruby']=1000
5656

5757

@@ -162,19 +162,19 @@ else
162162
SAVE_DIR=${SAVE_DIR}/generation;
163163
mkdir -p $SAVE_DIR
164164
if [[ "$LANGUAGE_GROUP" == 'all' ]]; then
165-
languages=(java python js php ruby go);
165+
languages=(java python javascript php ruby go);
166166
elif [[ "$LANGUAGE_GROUP" == 'compiled' ]]; then
167167
languages=(java ruby go);
168168
elif [[ "$LANGUAGE_GROUP" == 'interpreted' ]]; then
169-
languages=(php python js);
169+
languages=(php python javascript);
170170
elif [[ "$LANGUAGE_GROUP" == 'static' ]]; then
171171
languages=(java go);
172172
elif [[ "$LANGUAGE_GROUP" == 'dynamic' ]]; then
173-
languages=(js python php ruby);
173+
languages=(javascript python php ruby);
174174
elif [[ "$LANGUAGE_GROUP" == 'strong' ]]; then
175175
languages=(java go python ruby);
176176
elif [[ "$LANGUAGE_GROUP" == 'weak' ]]; then
177-
languages=(php js);
177+
languages=(php javascript);
178178
fi
179179

180180
# a list language pairs to train multilingual models, e.g. "en-java,en-python"

multilingual/single_task/summarization.sh

+8-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [[ $LANGUAGE_GROUP_CORRECT = "" ]] ; then
2020
exit;
2121
fi
2222

23-
GROUP_LISTS="java python go ruby php js"
23+
GROUP_LISTS="java python go ruby php javascript"
2424
LANGUAGE_GROUP_CORRECT=`echo $GROUP_LISTS | grep -w $LANG`;
2525
if [[ $LANGUAGE_GROUP_CORRECT = "" ]]; then
2626
echo "LANGUAGE(2nd parameter) must be one of the following";
@@ -42,16 +42,16 @@ BATCH_SIZE=8;
4242
UPDATE_FREQ=4;
4343

4444
# CSNET data size is as follows
45-
# java: 165k, python: 252k, php: 241k, go: 167k, js: 58k, ruby:25k
45+
# java: 165k, python: 252k, php: 241k, go: 167k, javascript: 58k, ruby:25k
4646
# So, number of mini-batches for each language would be:
47-
# java: ~5100, python: ~7800, php: ~7500, go: ~5200, js: ~1800, ruby: ~780
47+
# java: ~5100, python: ~7800, php: ~7500, go: ~5200, javascript: ~1800, ruby: ~780
4848

4949
declare -A LANG_WISE_WARMUP
5050
LANG_WISE_WARMUP['java']=5000
5151
LANG_WISE_WARMUP['python']=5000
5252
LANG_WISE_WARMUP['php']=5000
5353
LANG_WISE_WARMUP['go']=5000
54-
LANG_WISE_WARMUP['js']=2000
54+
LANG_WISE_WARMUP['javascript']=2000
5555
LANG_WISE_WARMUP['ruby']=1000
5656

5757

@@ -158,19 +158,19 @@ else
158158
SAVE_DIR=${SAVE_DIR}/summarization;
159159
mkdir -p $SAVE_DIR
160160
if [[ "$LANGUAGE_GROUP" == 'all' ]]; then
161-
languages=(java python js php ruby go);
161+
languages=(java python javascript php ruby go);
162162
elif [[ "$LANGUAGE_GROUP" == 'compiled' ]]; then
163163
languages=(java ruby go);
164164
elif [[ "$LANGUAGE_GROUP" == 'interpreted' ]]; then
165-
languages=(php python js);
165+
languages=(php python javascript);
166166
elif [[ "$LANGUAGE_GROUP" == 'static' ]]; then
167167
languages=(java go);
168168
elif [[ "$LANGUAGE_GROUP" == 'dynamic' ]]; then
169-
languages=(js python php ruby);
169+
languages=(javascript python php ruby);
170170
elif [[ "$LANGUAGE_GROUP" == 'strong' ]]; then
171171
languages=(java go python ruby);
172172
elif [[ "$LANGUAGE_GROUP" == 'weak' ]]; then
173-
languages=(php js);
173+
languages=(php javascript);
174174
fi
175175

176176
# a list language pairs to train multilingual models, e.g. "java-en,python-en"

pretrain/csnet/binarize.sh

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#!/usr/bin/env bash
2+
3+
export PYTHONIOENCODING=utf-8;
4+
5+
CURRENT_DIR=`pwd`
6+
HOME_DIR=`realpath ../..`;
7+
8+
SPM_DIR=${HOME_DIR}/sentencepiece
9+
DICT_FILE=${SPM_DIR}/dict.txt
10+
SPM_VOCAB=${SPM_DIR}/sentencepiece.bpe.vocab
11+
SPM_ENC_SCRIPT=${SPM_DIR}/encode.py
12+
13+
DATA_DIR=${CURRENT_DIR}/data
14+
SHARD_DIR=${DATA_DIR}/shard
15+
mkdir -p $SHARD_DIR
16+
cp $DICT_FILE $SHARD_DIR
17+
18+
19+
function preprocess_pl () {
20+
21+
for LANG in java python javascript php go ruby; do
22+
python $SPM_ENC_SCRIPT \
23+
--model-file $SPM_DIR/sentencepiece.bpe.model \
24+
--inputs $DATA_DIR/$LANG/train.functions.tok \
25+
--outputs $DATA_DIR/$LANG/train.functions.spm \
26+
--max_len 510 \
27+
--workers 60;
28+
python $SPM_ENC_SCRIPT \
29+
--model-file $SPM_DIR/sentencepiece.bpe.model \
30+
--inputs $DATA_DIR/$LANG/valid.functions.tok \
31+
--outputs $DATA_DIR/$LANG/valid.functions.spm \
32+
--max_len 510 \
33+
--workers 60;
34+
done
35+
36+
}
37+
38+
function preprocess_nl () {
39+
40+
python $SPM_ENC_SCRIPT \
41+
--model-file $SPM_DIR/sentencepiece.bpe.model \
42+
--inputs $DATA_DIR/train.docstring.tok \
43+
--outputs $DATA_DIR/train.docstring.spm \
44+
--max_len 510 \
45+
--workers 60;
46+
python $SPM_ENC_SCRIPT \
47+
--model-file $SPM_DIR/sentencepiece.bpe.model \
48+
--inputs $DATA_DIR/valid.docstring.tok \
49+
--outputs $DATA_DIR/valid.docstring.spm \
50+
--max_len 510 \
51+
--workers 60;
52+
53+
}
54+
55+
function binarize_pl () {
56+
57+
for LANG in java python javascript php go ruby; do
58+
fairseq-preprocess \
59+
--only-source \
60+
--trainpref $DATA_DIR/$LANG/train.functions.spm \
61+
--validpref $DATA_DIR/$LANG/valid.functions.spm \
62+
--destdir $SHARD_DIR/$LANG \
63+
--srcdict $DICT_FILE \
64+
--workers 60;
65+
done
66+
67+
}
68+
69+
function binarize_nl () {
70+
71+
fairseq-preprocess \
72+
--only-source \
73+
--trainpref $DATA_DIR/train.docstring.spm \
74+
--validpref $DATA_DIR/valid.docstring.spm \
75+
--destdir $SHARD_DIR/en_XX \
76+
--srcdict $DICT_FILE \
77+
--workers 60;
78+
79+
}
80+
81+
preprocess_pl
82+
preprocess_nl
83+
binarize_pl
84+
binarize_nl

0 commit comments

Comments (0)