From 0e0f870226ae15095de093599136b5d23696dd13 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Tue, 5 Jan 2021 17:54:33 -0500 Subject: [PATCH 01/12] make repo usable as a library (#2) * use relative imports * refactor evaluation so it can be used as a library --- .gitignore | 2 + evaluate_classical.py | 2 +- evaluation.py | 453 ++++++++++++++++++++++-------------------- exec_eval.py | 7 +- 4 files changed, 241 insertions(+), 223 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..08684dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +database/* +!database/readme.txt \ No newline at end of file diff --git a/evaluate_classical.py b/evaluate_classical.py index f5437e0..c1239ff 100644 --- a/evaluate_classical.py +++ b/evaluate_classical.py @@ -2,7 +2,7 @@ from typing import List, Dict, Any, Tuple import pickle as pkl import tqdm -from exec_eval import exec_on_db, result_eq +from .exec_eval import exec_on_db, result_eq import os from collections import defaultdict import time diff --git a/evaluation.py b/evaluation.py index 8e99534..9dac97f 100644 --- a/evaluation.py +++ b/evaluation.py @@ -24,10 +24,24 @@ import sqlite3 import argparse -from process_sql import get_schema, Schema, get_sql -from exec_eval import eval_exec_match +from .process_sql import get_schema, Schema, get_sql +from .exec_eval import eval_exec_match # Flag to disable value evaluation +LEVELS = ["easy", "medium", "hard", "extra", "all", "joint_all"] +TURNS = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"] +PARTIAL_TYPES = [ + "select", + "select(no AGG)", + "where", + "where(no OP)", + "group(no Having)", + "group", + "order", + "and/or", + "IUEN", + "keywords", +] DISABLE_VALUE = True # Flag to disable distinct in select evaluation DISABLE_DISTINCT = True @@ -260,7 +274,8 @@ def eval_nested(pred, label): if label is not None: label_total += 1 if pred is not None and label is not None: - cnt += Evaluator().eval_exact_match(pred, 
label) + partial_scores = Evaluator.eval_partial_match(pred, label) + cnt += Evaluator.eval_exact_match(pred, label, partial_scores) return label_total, pred_total, cnt @@ -415,8 +430,46 @@ def count_others(sql): class Evaluator: """A simple evaluator""" - def __init__(self): - self.partial_scores = None + def __init__( + self, + db_dir, + kmaps, + etype, + plug_value, + keep_distinct, + progress_bar_for_each_datapoint + ): + self.db_dir = db_dir + self.kmaps = kmaps + self.etype = etype + self.plug_value = plug_value + self.keep_distinct = keep_distinct + self.progress_bar_for_each_datapoint = progress_bar_for_each_datapoint + + self.db_paths = {} + self.schemas = {} + for db_name in self.kmaps.keys(): + db_path = os.path.join(db_dir, db_name, db_name + ".sqlite") + self.db_paths[db_name] = db_path + self.schemas[db_name] = Schema(get_schema(db_path)) + + self.scores = {} + + for turn in TURNS: + self.scores[turn] = {"count": 0, "exact": 0.0} + self.scores[turn]["exec"] = 0 + + for level in LEVELS: + self.scores[level] = {"count": 0, "partial": {}, "exact": 0.0} + self.scores[level]["exec"] = 0 + for type_ in PARTIAL_TYPES: + self.scores[level]["partial"][type_] = { + "acc": 0.0, + "rec": 0.0, + "f1": 0.0, + "acc_count": 0, + "rec_count": 0, + } def eval_hardness(self, sql): count_comp1_ = count_component1(sql) @@ -438,10 +491,8 @@ def eval_hardness(self, sql): else: return "extra" - def eval_exact_match(self, pred, label): - partial_scores = self.eval_partial_match(pred, label) - self.partial_scores = partial_scores - + @classmethod + def eval_exact_match(cls, pred, label, partial_scores): for key, score in partial_scores.items(): if score["f1"] != 1: return 0 @@ -452,7 +503,8 @@ def eval_exact_match(self, pred, label): return label_tables == pred_tables return 1 - def eval_partial_match(self, pred, label): + @classmethod + def eval_partial_match(cls, pred, label): res = {} label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) @@ -553,6 +605,164 @@ def 
eval_partial_match(self, pred, label): return res + def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): + schema = self.schemas[db_name] + g_sql = get_sql(schema, gold) + hardness = self.eval_hardness(g_sql) + if idx > 3: + idx = "> 4" + else: + idx += 1 + turn_id = "turn " + str(idx) + self.scores[turn_id]["count"] += 1 + self.scores[hardness]["count"] += 1 + self.scores["all"]["count"] += 1 + + try: + p_sql = get_sql(schema, predicted) + except: + # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql + p_sql = { + "except": None, + "from": {"conds": [], "table_units": []}, + "groupBy": [], + "having": [], + "intersect": None, + "limit": None, + "orderBy": [], + "select": [False, []], + "union": None, + "where": [], + } + + if self.etype in ["all", "exec"]: + exec_score = eval_exec_match( + db=self.db_paths[db_name], + p_str=predicted, + g_str=gold, + plug_value=self.plug_value, + keep_distinct=self.keep_distinct, + progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint, + ) + if exec_score: + self.scores[hardness]["exec"] += 1 + self.scores[turn_id]["exec"] += 1 + self.scores["all"]["exec"] += 1 + turn_scores["exec"].append(1) + else: + turn_scores["exec"].append(0) + + if self.etype in ["all", "match"]: + # rebuild sql for value evaluation + kmap = self.kmaps[db_name] + g_valid_col_units = build_valid_col_units( + g_sql["from"]["table_units"], schema + ) + g_sql = rebuild_sql_val(g_sql) + g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) + p_valid_col_units = build_valid_col_units( + p_sql["from"]["table_units"], schema + ) + p_sql = rebuild_sql_val(p_sql) + p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) + partial_scores = self.eval_partial_match(p_sql, g_sql) + exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores) + if exact_score == 0: + turn_scores["exact"].append(0) + print("{} pred: {}".format(hardness, predicted)) + print("{} gold: {}".format(hardness, gold)) + 
print("") + else: + turn_scores["exact"].append(1) + self.scores[turn_id]["exact"] += exact_score + self.scores[hardness]["exact"] += exact_score + self.scores["all"]["exact"] += exact_score + for type_ in PARTIAL_TYPES: + if partial_scores[type_]["pred_total"] > 0: + self.scores[hardness]["partial"][type_]["acc"] += partial_scores[ + type_ + ]["acc"] + self.scores[hardness]["partial"][type_]["acc_count"] += 1 + if partial_scores[type_]["label_total"] > 0: + self.scores[hardness]["partial"][type_]["rec"] += partial_scores[ + type_ + ]["rec"] + self.scores[hardness]["partial"][type_]["rec_count"] += 1 + self.scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][ + "f1" + ] + if partial_scores[type_]["pred_total"] > 0: + self.scores["all"]["partial"][type_]["acc"] += partial_scores[type_][ + "acc" + ] + self.scores["all"]["partial"][type_]["acc_count"] += 1 + if partial_scores[type_]["label_total"] > 0: + self.scores["all"]["partial"][type_]["rec"] += partial_scores[type_][ + "rec" + ] + self.scores["all"]["partial"][type_]["rec_count"] += 1 + self.scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] + + return { + "predictSQL": predicted, + "goldSQL": gold, + "hardness": hardness, + "exact": exact_score, + "partial": partial_scores, + } + + def finalize(self): + scores = self.scores + for turn in TURNS: + if scores[turn]["count"] == 0: + continue + if self.etype in ["all", "exec"]: + scores[turn]["exec"] /= scores[turn]["count"] + + if self.etype in ["all", "match"]: + scores[turn]["exact"] /= scores[turn]["count"] + + for level in LEVELS: + if scores[level]["count"] == 0: + continue + if self.etype in ["all", "exec"]: + scores[level]["exec"] /= scores[level]["count"] + + if self.etype in ["all", "match"]: + scores[level]["exact"] /= scores[level]["count"] + for type_ in PARTIAL_TYPES: + if scores[level]["partial"][type_]["acc_count"] == 0: + scores[level]["partial"][type_]["acc"] = 0 + else: + scores[level]["partial"][type_]["acc"] = ( + 
scores[level]["partial"][type_]["acc"] + / scores[level]["partial"][type_]["acc_count"] + * 1.0 + ) + if scores[level]["partial"][type_]["rec_count"] == 0: + scores[level]["partial"][type_]["rec"] = 0 + else: + scores[level]["partial"][type_]["rec"] = ( + scores[level]["partial"][type_]["rec"] + / scores[level]["partial"][type_]["rec_count"] + * 1.0 + ) + if ( + scores[level]["partial"][type_]["acc"] == 0 + and scores[level]["partial"][type_]["rec"] == 0 + ): + scores[level]["partial"][type_]["f1"] = 1 + else: + scores[level]["partial"][type_]["f1"] = ( + 2.0 + * scores[level]["partial"][type_]["acc"] + * scores[level]["partial"][type_]["rec"] + / ( + scores[level]["partial"][type_]["rec"] + + scores[level]["partial"][type_]["acc"] + ) + ) + def isValidSQL(sql, db): conn = sqlite3.connect(db) @@ -570,22 +780,11 @@ def print_formated_s(row_name, l, element_format): def print_scores(scores, etype, include_turn_acc=True): - turns = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"] + turns = TURNS levels = ["easy", "medium", "hard", "extra", "all"] if include_turn_acc: levels.append("joint_all") - partial_types = [ - "select", - "select(no AGG)", - "where", - "where(no OP)", - "group(no Having)", - "group", - "order", - "and/or", - "IUEN", - "keywords", - ] + partial_types = PARTIAL_TYPES print_formated_s("", levels, "{:20}") counts = [scores[level]["count"] for level in levels] @@ -684,217 +883,33 @@ def evaluate( assert len(plist) == len(glist), "number of sessions must equal" - evaluator = Evaluator() - turns = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"] - levels = ["easy", "medium", "hard", "extra", "all", "joint_all"] - - partial_types = [ - "select", - "select(no AGG)", - "where", - "where(no OP)", - "group(no Having)", - "group", - "order", - "and/or", - "IUEN", - "keywords", - ] - entries = [] - scores = {} - - for turn in turns: - scores[turn] = {"count": 0, "exact": 0.0} - scores[turn]["exec"] = 0 - - for level in levels: - scores[level] = 
{"count": 0, "partial": {}, "exact": 0.0} - scores[level]["exec"] = 0 - for type_ in partial_types: - scores[level]["partial"][type_] = { - "acc": 0.0, - "rec": 0.0, - "f1": 0.0, - "acc_count": 0, - "rec_count": 0, - } + evaluator = Evaluator(db_dir, kmaps, etype, plug_value, keep_distinct, progress_bar_for_each_datapoint) + results = [] for i, (p, g) in enumerate(zip(plist, glist)): if (i + 1) % 10 == 0: print("Evaluating %dth prediction" % (i + 1)) - scores["joint_all"]["count"] += 1 + evaluator.scores["joint_all"]["count"] += 1 turn_scores = {"exec": [], "exact": []} for idx, pg in enumerate(zip(p, g)): p, g = pg p_str = p[0] p_str = p_str.replace("value", "1") - g_str, db = g - db_name = db - db = os.path.join(db_dir, db, db + ".sqlite") - schema = Schema(get_schema(db)) - g_sql = get_sql(schema, g_str) - hardness = evaluator.eval_hardness(g_sql) - if idx > 3: - idx = "> 4" - else: - idx += 1 - turn_id = "turn " + str(idx) - scores[turn_id]["count"] += 1 - scores[hardness]["count"] += 1 - scores["all"]["count"] += 1 - - try: - p_sql = get_sql(schema, p_str) - except: - # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql - p_sql = { - "except": None, - "from": {"conds": [], "table_units": []}, - "groupBy": [], - "having": [], - "intersect": None, - "limit": None, - "orderBy": [], - "select": [False, []], - "union": None, - "where": [], - } + g_str, db_name = g - if etype in ["all", "exec"]: - exec_score = eval_exec_match( - db=db, - p_str=p_str, - g_str=g_str, - plug_value=plug_value, - keep_distinct=keep_distinct, - progress_bar_for_each_datapoint=progress_bar_for_each_datapoint, - ) - if exec_score: - scores[hardness]["exec"] += 1 - scores[turn_id]["exec"] += 1 - scores["all"]["exec"] += 1 - turn_scores["exec"].append(1) - else: - turn_scores["exec"].append(0) - - if etype in ["all", "match"]: - # rebuild sql for value evaluation - kmap = kmaps[db_name] - g_valid_col_units = build_valid_col_units( - 
g_sql["from"]["table_units"], schema - ) - g_sql = rebuild_sql_val(g_sql) - g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) - p_valid_col_units = build_valid_col_units( - p_sql["from"]["table_units"], schema - ) - p_sql = rebuild_sql_val(p_sql) - p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) - exact_score = evaluator.eval_exact_match(p_sql, g_sql) - partial_scores = evaluator.partial_scores - if exact_score == 0: - turn_scores["exact"].append(0) - print("{} pred: {}".format(hardness, p_str)) - print("{} gold: {}".format(hardness, g_str)) - print("") - else: - turn_scores["exact"].append(1) - scores[turn_id]["exact"] += exact_score - scores[hardness]["exact"] += exact_score - scores["all"]["exact"] += exact_score - for type_ in partial_types: - if partial_scores[type_]["pred_total"] > 0: - scores[hardness]["partial"][type_]["acc"] += partial_scores[ - type_ - ]["acc"] - scores[hardness]["partial"][type_]["acc_count"] += 1 - if partial_scores[type_]["label_total"] > 0: - scores[hardness]["partial"][type_]["rec"] += partial_scores[ - type_ - ]["rec"] - scores[hardness]["partial"][type_]["rec_count"] += 1 - scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][ - "f1" - ] - if partial_scores[type_]["pred_total"] > 0: - scores["all"]["partial"][type_]["acc"] += partial_scores[type_][ - "acc" - ] - scores["all"]["partial"][type_]["acc_count"] += 1 - if partial_scores[type_]["label_total"] > 0: - scores["all"]["partial"][type_]["rec"] += partial_scores[type_][ - "rec" - ] - scores["all"]["partial"][type_]["rec_count"] += 1 - scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] - - entries.append( - { - "predictSQL": p_str, - "goldSQL": g_str, - "hardness": hardness, - "exact": exact_score, - "partial": partial_scores, - } - ) + results.append(evaluator.evaluate_one(db_name, g_str, p_str, turn_scores, idx)) if all(v == 1 for v in turn_scores["exec"]): - scores["joint_all"]["exec"] += 1 + evaluator.scores["joint_all"]["exec"] += 
1 if all(v == 1 for v in turn_scores["exact"]): - scores["joint_all"]["exact"] += 1 - - for turn in turns: - if scores[turn]["count"] == 0: - continue - if etype in ["all", "exec"]: - scores[turn]["exec"] /= scores[turn]["count"] - - if etype in ["all", "match"]: - scores[turn]["exact"] /= scores[turn]["count"] - - for level in levels: - if scores[level]["count"] == 0: - continue - if etype in ["all", "exec"]: - scores[level]["exec"] /= scores[level]["count"] - - if etype in ["all", "match"]: - scores[level]["exact"] /= scores[level]["count"] - for type_ in partial_types: - if scores[level]["partial"][type_]["acc_count"] == 0: - scores[level]["partial"][type_]["acc"] = 0 - else: - scores[level]["partial"][type_]["acc"] = ( - scores[level]["partial"][type_]["acc"] - / scores[level]["partial"][type_]["acc_count"] - * 1.0 - ) - if scores[level]["partial"][type_]["rec_count"] == 0: - scores[level]["partial"][type_]["rec"] = 0 - else: - scores[level]["partial"][type_]["rec"] = ( - scores[level]["partial"][type_]["rec"] - / scores[level]["partial"][type_]["rec_count"] - * 1.0 - ) - if ( - scores[level]["partial"][type_]["acc"] == 0 - and scores[level]["partial"][type_]["rec"] == 0 - ): - scores[level]["partial"][type_]["f1"] = 1 - else: - scores[level]["partial"][type_]["f1"] = ( - 2.0 - * scores[level]["partial"][type_]["acc"] - * scores[level]["partial"][type_]["rec"] - / ( - scores[level]["partial"][type_]["rec"] - + scores[level]["partial"][type_]["acc"] - ) - ) + evaluator.scores["joint_all"]["exact"] += 1 - print_scores(scores, etype, include_turn_acc=include_turn_acc) + print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc) + return { + "per_item": results, + "total_scores": evaluator.scores + } # Rebuild SQL functions for value evaluation diff --git a/exec_eval.py b/exec_eval.py index d266adb..007fd76 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -5,7 +5,7 @@ from collections import defaultdict import tqdm import random -from parse import 
get_all_preds_for_execution, remove_distinct +from .parse import get_all_preds_for_execution, remove_distinct import time import pickle as pkl import subprocess @@ -14,7 +14,7 @@ threadLock = threading.Lock() TIMEOUT = 60 -EXEC_TMP_DIR = "tmp/" +EXEC_TMP_DIR = os.path.join(os.path.dirname(__file__), "tmp") def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: @@ -139,8 +139,9 @@ def exec_on_db( f_prefix = os.path.join(EXEC_TMP_DIR, process_id) pkl.dump((sqlite_path, query), open(f_prefix + ".in", "wb")) try: + exec_script_fpath = os.path.join(os.path.dirname(__file__), "exec_subprocess.py") subprocess.call( - ["python3", "exec_subprocess.py", f_prefix], + ["python3", exec_script_fpath, f_prefix], timeout=timeout, stderr=open("runerr.log", "a"), ) From 4ae53ab3e77f906c520952e30cdc16e8198f2320 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Wed, 6 Jan 2021 15:46:10 -0500 Subject: [PATCH 02/12] get db schemas when needed instead of at initialization --- evaluation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/evaluation.py b/evaluation.py index 9dac97f..f6e5011 100644 --- a/evaluation.py +++ b/evaluation.py @@ -448,10 +448,6 @@ def __init__( self.db_paths = {} self.schemas = {} - for db_name in self.kmaps.keys(): - db_path = os.path.join(db_dir, db_name, db_name + ".sqlite") - self.db_paths[db_name] = db_path - self.schemas[db_name] = Schema(get_schema(db_path)) self.scores = {} @@ -606,6 +602,11 @@ def eval_partial_match(cls, pred, label): return res def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): + if db_name not in self.db_paths: + db_path = os.path.join(self.db_dir, db_name, db_name + ".sqlite") + self.db_paths[db_name] = db_path + self.schemas[db_name] = Schema(get_schema(db_path)) + schema = self.schemas[db_name] g_sql = get_sql(schema, gold) hardness = self.eval_hardness(g_sql) From ca70640bfbdc5a3c6bc4ee3ff3a6f0e68e332e42 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Wed, 6 Jan 2021 22:45:28 -0500 
Subject: [PATCH 03/12] add missing .finalize() --- evaluation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/evaluation.py b/evaluation.py index f6e5011..04b75b6 100644 --- a/evaluation.py +++ b/evaluation.py @@ -906,6 +906,7 @@ def evaluate( if all(v == 1 for v in turn_scores["exact"]): evaluator.scores["joint_all"]["exact"] += 1 + evaluator.finalize() print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc) return { "per_item": results, From 2528301bac1997bab444a10cf5fefda603a4e8cb Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Thu, 7 Jan 2021 17:55:42 -0500 Subject: [PATCH 04/12] add script to alter michigan databases --- alter_michigan_databases.sh | 47 +++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 alter_michigan_databases.sh diff --git a/alter_michigan_databases.sh b/alter_michigan_databases.sh new file mode 100644 index 0000000..85a416c --- /dev/null +++ b/alter_michigan_databases.sh @@ -0,0 +1,47 @@ +#~/usr/bin/env bash +set -e + +DATABASE_DIR=database + +copy_databases () { + db=$1 + # Copy to *_test directory + altered=$DATABASE_DIR/${db}_test + cp -r "$DATABASE_DIR/$db" "$altered" + + # Rename .sqlite files + cd "$altered" + for f in ${db}*.sqlite + do + mv "$f" "${db}_test${f#${db}}" + done + cd ../.. +} + +alter_yelp () { + for f in "$DATABASE_DIR/yelp_test"/*.sqlite + do + echo "ALTER TABLE neighbourhood RENAME TO neighborhood" | sqlite3 "$f" + echo "ALTER TABLE neighborhood RENAME COLUMN neighbourhood_name TO neighborhood_name" | sqlite3 "$f" + done +} + +alter_imdb () { + for f in "$DATABASE_DIR/imdb_test"/*.sqlite + do + echo "ALTER TABLE cast RENAME TO cast2" | sqlite3 "$f" + done +} + + +for DB in "imdb" "yelp" +do + echo $DB + if [ ! 
-d "$DATABASE_DIR/${DB}_test" ] + then + copy_databases $DB + alter_"$DB" + else + echo "$DATABASE_DIR/${DB}_test already exists" + fi +done From 7c03f03691a71d367cc21b65254888a671d9c8b3 Mon Sep 17 00:00:00 2001 From: RaymondLi0 Date: Fri, 8 Jan 2021 16:58:58 -0500 Subject: [PATCH 05/12] move created file to tmp --- exec_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exec_eval.py b/exec_eval.py index 007fd76..914b7d5 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -143,7 +143,7 @@ def exec_on_db( subprocess.call( ["python3", exec_script_fpath, f_prefix], timeout=timeout, - stderr=open("runerr.log", "a"), + stderr=open(os.path.join(os.path.dirname(__file__), "tmp", "runerr.log"), "a"), ) except Exception as e: print(e) From 0b5ece18e04ebfb2c45a08281faa9d2518844881 Mon Sep 17 00:00:00 2001 From: rizar Date: Thu, 11 Mar 2021 20:04:03 +0000 Subject: [PATCH 06/12] git ignore more files --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 08684dc..60db9dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ database/* -!database/readme.txt \ No newline at end of file +*/*.pyc +tmp +!database/readme.txt +!tmp/readme.txt From b9858f0efbd78f02cc835bdb5d5a8ac09323a2df Mon Sep 17 00:00:00 2001 From: Dzmitry Bahdanau Date: Wed, 17 Mar 2021 14:52:30 +0000 Subject: [PATCH 07/12] add geo, scholar & academic to db alternation script --- alter_michigan_databases.sh | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/alter_michigan_databases.sh b/alter_michigan_databases.sh index 85a416c..32aeabf 100644 --- a/alter_michigan_databases.sh +++ b/alter_michigan_databases.sh @@ -1,7 +1,7 @@ #~/usr/bin/env bash set -e -DATABASE_DIR=database +DATABASE_DIR=. copy_databases () { db=$1 @@ -15,11 +15,11 @@ copy_databases () { do mv "$f" "${db}_test${f#${db}}" done - cd ../.. 
+ cd - } alter_yelp () { - for f in "$DATABASE_DIR/yelp_test"/*.sqlite + for f in `ls $DATABASE_DIR/yelp_test/*.sqlite` do echo "ALTER TABLE neighbourhood RENAME TO neighborhood" | sqlite3 "$f" echo "ALTER TABLE neighborhood RENAME COLUMN neighbourhood_name TO neighborhood_name" | sqlite3 "$f" @@ -27,14 +27,33 @@ alter_yelp () { } alter_imdb () { - for f in "$DATABASE_DIR/imdb_test"/*.sqlite + for f in `ls $DATABASE_DIR/imdb_test/*.sqlite` do echo "ALTER TABLE cast RENAME TO cast2" | sqlite3 "$f" done } +alter_academic () { + : +} + +alter_geo () { + : +} + +alter_scholar () { + : +} + +# geo is an exception in that we want to change the name from "geography" to "geo_test" +# the easiest way to achieve this is by copying "geography" to "geo" first +if [ ! -d $DATABASE_DIR/geo ] +then + cp -r $DATABASE_DIR/geography $DATABASE_DIR/geo + mv $DATABASE_DIR/geo/geography.sqlite $DATABASE_DIR/geo/geo.sqlite +fi -for DB in "imdb" "yelp" +for DB in imdb yelp academic geo scholar do echo $DB if [ ! 
-d "$DATABASE_DIR/${DB}_test" ] From 7ac3ae357b105e4a2c6f9a5cbf9db1cd6ea65841 Mon Sep 17 00:00:00 2001 From: Dzmitry Bahdanau Date: Fri, 19 Mar 2021 13:51:47 +0000 Subject: [PATCH 08/12] fix exec-only evaluation --- evaluation.py | 60 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/evaluation.py b/evaluation.py index 04b75b6..c368015 100644 --- a/evaluation.py +++ b/evaluation.py @@ -607,34 +607,37 @@ def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): self.db_paths[db_name] = db_path self.schemas[db_name] = Schema(get_schema(db_path)) - schema = self.schemas[db_name] - g_sql = get_sql(schema, gold) - hardness = self.eval_hardness(g_sql) if idx > 3: idx = "> 4" else: idx += 1 turn_id = "turn " + str(idx) + self.scores[turn_id]["count"] += 1 - self.scores[hardness]["count"] += 1 self.scores["all"]["count"] += 1 - try: - p_sql = get_sql(schema, predicted) - except: - # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql - p_sql = { - "except": None, - "from": {"conds": [], "table_units": []}, - "groupBy": [], - "having": [], - "intersect": None, - "limit": None, - "orderBy": [], - "select": [False, []], - "union": None, - "where": [], - } + if self.etype in ['all', 'match']: + schema = self.schemas[db_name] + g_sql = get_sql(schema, gold) + hardness = self.eval_hardness(g_sql) + self.scores[hardness]["count"] += 1 + + try: + p_sql = get_sql(schema, predicted) + except: + # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql + p_sql = { + "except": None, + "from": {"conds": [], "table_units": []}, + "groupBy": [], + "having": [], + "intersect": None, + "limit": None, + "orderBy": [], + "select": [False, []], + "union": None, + "where": [], + } if self.etype in ["all", "exec"]: exec_score = eval_exec_match( @@ -646,7 +649,8 @@ def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): 
progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint, ) if exec_score: - self.scores[hardness]["exec"] += 1 + if self.etype == 'all': + self.scores[hardness]["exec"] += 1 self.scores[turn_id]["exec"] += 1 self.scores["all"]["exec"] += 1 turn_scores["exec"].append(1) @@ -704,13 +708,19 @@ def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): self.scores["all"]["partial"][type_]["rec_count"] += 1 self.scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] - return { - "predictSQL": predicted, - "goldSQL": gold, + result = { + "predictSQL": predicted, + "goldSQL": gold, + } + if self.etype in ['all', 'match']: + result.extend({ "hardness": hardness, "exact": exact_score, "partial": partial_scores, - } + }) + if self.etype in ['all', 'exec']: + result['exec'] = exec_score + return result def finalize(self): scores = self.scores From 95b604496ff52cf9c075694f0fdd13b11288050d Mon Sep 17 00:00:00 2001 From: Dzmitry Bahdanau Date: Fri, 19 Mar 2021 13:52:03 +0000 Subject: [PATCH 09/12] greatly speed up query execution --- exec_eval.py | 83 +++++++++++++++++++++++++--------------------- exec_subprocess.py | 50 ---------------------------- 2 files changed, 46 insertions(+), 87 deletions(-) delete mode 100644 exec_subprocess.py diff --git a/exec_eval.py b/exec_eval.py index 914b7d5..490f3f7 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -1,16 +1,21 @@ import os +import re +import sqlite3 +import asyncio import threading from typing import Tuple, Any, List, Set from itertools import product from collections import defaultdict import tqdm import random -from .parse import get_all_preds_for_execution, remove_distinct import time import pickle as pkl import subprocess +import timeout_decorator from itertools import chain +from .parse import get_all_preds_for_execution, remove_distinct + threadLock = threading.Lock() TIMEOUT = 60 @@ -119,45 +124,49 @@ def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) - return 
False -def clean_tmp_f(f_prefix: str): - with threadLock: - for suffix in (".in", ".out"): - f_path = f_prefix + suffix - if os.path.exists(f_path): - os.unlink(f_path) +def replace_cur_year(query: str) -> str: + return re.sub( + "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE + ) -# we need a wrapper, because simple timeout will not stop the database connection -def exec_on_db( - sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT -) -> Tuple[str, Any]: - f_prefix = None - with threadLock: - while f_prefix is None or os.path.exists(f_prefix + ".in"): - process_id += str(time.time()) - process_id += str(random.randint(0, 10000000000)) - f_prefix = os.path.join(EXEC_TMP_DIR, process_id) - pkl.dump((sqlite_path, query), open(f_prefix + ".in", "wb")) +# get the database cursor for a sqlite database path +def get_cursor_from_path(sqlite_path: str): try: - exec_script_fpath = os.path.join(os.path.dirname(__file__), "exec_subprocess.py") - subprocess.call( - ["python3", exec_script_fpath, f_prefix], - timeout=timeout, - stderr=open(os.path.join(os.path.dirname(__file__), "tmp", "runerr.log"), "a"), - ) + if not os.path.exists(sqlite_path): + print("Openning a new connection %s" % sqlite_path) + connection = sqlite3.connect(sqlite_path) except Exception as e: - print(e) - clean_tmp_f(f_prefix) + print(sqlite_path) + raise e + connection.text_factory = lambda b: b.decode(errors="ignore") + cursor = connection.cursor() + return cursor + + +async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: + query = replace_cur_year(query) + cursor = get_cursor_from_path(sqlite_path) + try: + cursor.execute(query) + result = cursor.fetchall() + cursor.close() + cursor.connection.close() + return "result", result + except Exception as e: + cursor.close() + cursor.connection.close() return "exception", e - result_path = f_prefix + ".out" - returned_val = ("exception", TimeoutError) + +async def exec_on_db( + sqlite_path: 
str, query: str, process_id: str = "", timeout: int = TIMEOUT +) -> Tuple[str, Any]: try: - if os.path.exists(result_path): - returned_val = pkl.load(open(result_path, "rb")) - except: - pass - clean_tmp_f(f_prefix) - return returned_val + return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) + except asyncio.TimeoutError: + return ('exception', TimeoutError) + except Exception as e: + return ("exception", e) # postprocess the model predictions to avoid execution errors @@ -225,13 +234,13 @@ def eval_exec_match( ranger = db_paths for db_path in ranger: - g_flag, g_denotation = exec_on_db(db_path, g_str) - p_flag, p_denotation = exec_on_db(db_path, pred) + g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str)) + p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred)) # we should expect the gold to be succesfully executed on the database assert ( g_flag != "exception" - ), "gold query %s has error on database file %s" % (g_str, db_path) + ), f"gold query {g_str} has error {g_denotation} on database file {db_path}" # wrong if execution fails if p_flag == "exception": diff --git a/exec_subprocess.py b/exec_subprocess.py deleted file mode 100644 index 6bebeeb..0000000 --- a/exec_subprocess.py +++ /dev/null @@ -1,50 +0,0 @@ -import sys - -sys.path.append("./") -import os -import pickle as pkl -from typing import Tuple, Any -import sqlite3 -import re - - -def replace_cur_year(query: str) -> str: - return re.sub( - "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE - ) - - -# get the database cursor for a sqlite database path -def get_cursor_from_path(sqlite_path: str): - try: - if not os.path.exists(sqlite_path): - print("Openning a new connection %s" % sqlite_path) - connection = sqlite3.connect(sqlite_path) - except Exception as e: - print(sqlite_path) - raise e - connection.text_factory = lambda b: b.decode(errors="ignore") - cursor = connection.cursor() - return cursor - - -def exec_on_db_(sqlite_path: str, 
query: str) -> Tuple[str, Any]: - query = replace_cur_year(query) - cursor = get_cursor_from_path(sqlite_path) - try: - cursor.execute(query) - result = cursor.fetchall() - cursor.close() - cursor.connection.close() - return "result", result - except Exception as e: - cursor.close() - cursor.connection.close() - return "exception", e - - -f_prefix = sys.argv[1] -func_args = pkl.load(open(f_prefix + ".in", "rb")) -sqlite_path, query = func_args -result = exec_on_db_(sqlite_path, query) -pkl.dump(result, open(f_prefix + ".out", "wb")) From 4ab19e073caab34d2cabbbd5ad9c8571334001b9 Mon Sep 17 00:00:00 2001 From: Dzmitry Bahdanau Date: Fri, 19 Mar 2021 13:57:03 +0000 Subject: [PATCH 10/12] timeout_decorator was redundant --- exec_eval.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exec_eval.py b/exec_eval.py index 490f3f7..62f66dd 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -11,7 +11,6 @@ import time import pickle as pkl import subprocess -import timeout_decorator from itertools import chain from .parse import get_all_preds_for_execution, remove_distinct From 6455603b43d0728442462104bbc0e1dbead47fd8 Mon Sep 17 00:00:00 2001 From: Dzmitry Bahdanau Date: Fri, 19 Mar 2021 19:33:01 +0000 Subject: [PATCH 11/12] catch prediction parsing error --- exec_eval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/exec_eval.py b/exec_eval.py index 62f66dd..bcfc93d 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -194,7 +194,11 @@ def eval_exec_match( # e.g. 
removing spaces between ">" and "=" p_str, g_str = postprocess(p_str), postprocess(g_str) if not keep_distinct: - p_str = remove_distinct(p_str) + try: + # if sqlparse can't parse p_str, we should not even try to execute it + p_str = remove_distinct(p_str) + except Exception as e: + return 0 g_str = remove_distinct(g_str) # we decide whether two denotations are equivalent based on "bag semantics" From 61d71b8e902c0a5cac9ef81987fdcaebc27bfbc6 Mon Sep 17 00:00:00 2001 From: Torsten Scholak Date: Fri, 4 Jun 2021 18:42:27 +0000 Subject: [PATCH 12/12] extend -> update --- evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation.py b/evaluation.py index c368015..dfee517 100644 --- a/evaluation.py +++ b/evaluation.py @@ -713,7 +713,7 @@ def evaluate_one(self, db_name, gold, predicted, turn_scores, idx): "goldSQL": gold, } if self.etype in ['all', 'match']: - result.extend({ + result.update({ "hardness": hardness, "exact": exact_score, "partial": partial_scores,