-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostprocess_predict.py
54 lines (41 loc) · 1.32 KB
/
postprocess_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import json
import os
import re
# Default input path for the prediction file; when run as a script this is
# overridden by the --file_path command-line value.
FILE_PATH: str = "./dev_predict.txt"
def postprocess(text: str) -> str:
    """Normalize a single prediction line by swapping double quotes for single.

    Every ``"`` character in *text* is replaced with ``'``; nothing else is
    changed (no stripping, no fence removal, no validation).

    Args:
        text: One line of the prediction file. Presumably a predicted SQL
            query -- NOTE(review): confirm against the producer of the file.

    Returns:
        *text* with all double quotes replaced by single quotes.
    """
    # Blanket character swap: this rewrites every double quote, including any
    # used for quoted identifiers, not only those around string literals.
    # NOTE(review): confirm that is acceptable for the target SQL dialect.
    return text.replace('"', "'")
def argparser(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Optional list of argument strings to parse. The default
            ``None`` makes argparse read ``sys.argv[1:]``, matching the
            original command-line behaviour; passing a list allows the
            parser to be reused or tested without touching ``sys.argv``.

    Returns:
        argparse.Namespace with a ``file_path`` attribute, defaulting to
        ``"./dev_predict.txt"``.
    """
    parser = argparse.ArgumentParser(
        description="Post-process a prediction file line by line.",
    )
    parser.add_argument(
        "--file_path",
        type=str,
        default="./dev_predict.txt",
        help="Path of the prediction file to post-process "
        "(default: %(default)s).",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # The command-line value replaces the module-level default path.
    FILE_PATH = argparser().file_path

    # Load every line before the output file is opened, so the input is fully
    # consumed even if the output path happens to point at the same file.
    with open(FILE_PATH, "r", encoding="utf-8") as src:
        raw_lines = src.readlines()

    # The corrected predictions are written alongside the input file.
    output_dir = os.path.dirname(FILE_PATH)
    output_path = os.path.join(output_dir, "dev_predict_corrected.txt")

    with open(output_path, "w", encoding="utf-8") as dst:
        # Each line is stripped of surrounding whitespace (including its old
        # newline), post-processed, and re-terminated with a single "\n".
        dst.writelines(postprocess(raw.strip()) + "\n" for raw in raw_lines)