diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..60db9dd
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+database/*
+*/*.pyc
+tmp
+!database/readme.txt
+!tmp/readme.txt
diff --git a/alter_michigan_databases.sh b/alter_michigan_databases.sh
new file mode 100644
index 0000000..32aeabf
--- /dev/null
+++ b/alter_michigan_databases.sh
@@ -0,0 +1,66 @@
+#!/usr/bin/env bash
+set -e
+
+DATABASE_DIR=.
+
+copy_databases () {
+    db=$1
+    # Copy to *_test directory
+    altered=$DATABASE_DIR/${db}_test
+    cp -r "$DATABASE_DIR/$db" "$altered"
+
+    # Rename .sqlite files
+    cd "$altered"
+    for f in ${db}*.sqlite
+    do
+        mv "$f" "${db}_test${f#${db}}"
+    done
+    cd -
+}
+
+alter_yelp () {
+    for f in "$DATABASE_DIR"/yelp_test/*.sqlite
+    do
+        echo "ALTER TABLE neighbourhood RENAME TO neighborhood" | sqlite3 "$f"
+        echo "ALTER TABLE neighborhood RENAME COLUMN neighbourhood_name TO neighborhood_name" | sqlite3 "$f"
+    done
+}
+
+alter_imdb () {
+    for f in "$DATABASE_DIR"/imdb_test/*.sqlite
+    do
+        echo "ALTER TABLE cast RENAME TO cast2" | sqlite3 "$f"
+    done
+}
+
+alter_academic () {
+    :
+}
+
+alter_geo () {
+    :
+}
+
+alter_scholar () {
+    :
+}
+
+# geo is an exception in that we want to change the name from "geography" to "geo_test";
+# the easiest way to achieve this is to copy "geography" to "geo" first
+if [ ! -d "$DATABASE_DIR/geo" ]
+then
+    cp -r "$DATABASE_DIR/geography" "$DATABASE_DIR/geo"
+    mv "$DATABASE_DIR/geo/geography.sqlite" "$DATABASE_DIR/geo/geo.sqlite"
+fi
+
+for DB in imdb yelp academic geo scholar
+do
+    echo "$DB"
+    if [ ! -d "$DATABASE_DIR/${DB}_test" ]
+    then
+        copy_databases "$DB"
+        alter_"$DB"
+    else
+        echo "$DATABASE_DIR/${DB}_test already exists"
+    fi
+done
diff --git a/evaluate_classical.py b/evaluate_classical.py
index f5437e0..c1239ff 100644
--- a/evaluate_classical.py
+++ b/evaluate_classical.py
@@ -2,7 +2,7 @@ from typing import List, Dict, Any, Tuple
 import pickle as pkl
 import tqdm
 
-from exec_eval import exec_on_db, result_eq
+from .exec_eval import exec_on_db, result_eq
 import os
 from collections import defaultdict
 import time
diff --git a/evaluation.py b/evaluation.py
index 8e99534..dfee517 100644
--- a/evaluation.py
+++ b/evaluation.py
@@ -24,10 +24,24 @@
 import sqlite3
 import argparse
 
-from process_sql import get_schema, Schema, get_sql
-from exec_eval import eval_exec_match
+from .process_sql import get_schema, Schema, get_sql
+from .exec_eval import eval_exec_match
 
+LEVELS = ["easy", "medium", "hard", "extra", "all", "joint_all"]
+TURNS = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"]
+PARTIAL_TYPES = [
+    "select",
+    "select(no AGG)",
+    "where",
+    "where(no OP)",
+    "group(no Having)",
+    "group",
+    "order",
+    "and/or",
+    "IUEN",
+    "keywords",
+]
 # Flag to disable value evaluation
 DISABLE_VALUE = True
 # Flag to disable distinct in select evaluation
 DISABLE_DISTINCT = True
@@ -260,7 +274,8 @@ def eval_nested(pred, label):
     if label is not None:
         label_total += 1
     if pred is not None and label is not None:
-        cnt += Evaluator().eval_exact_match(pred, label)
+        partial_scores = Evaluator.eval_partial_match(pred, label)
+        cnt += Evaluator.eval_exact_match(pred, label, partial_scores)
     return label_total, pred_total, cnt
@@ -415,8 +430,42 @@ def count_others(sql):
 class Evaluator:
     """A simple evaluator"""
 
-    def __init__(self):
-        self.partial_scores = None
+    def __init__(
+        self,
+        db_dir,
+        kmaps,
+        etype,
+        plug_value,
+        keep_distinct,
+        progress_bar_for_each_datapoint
+    ):
+        self.db_dir = db_dir
+        self.kmaps = kmaps
+        self.etype = etype
+        self.plug_value = plug_value
+        self.keep_distinct = keep_distinct
+        self.progress_bar_for_each_datapoint = progress_bar_for_each_datapoint
+
+        self.db_paths = {}
+        self.schemas = {}
+
+        self.scores = {}
+
+        for turn in TURNS:
+            self.scores[turn] = {"count": 0, "exact": 0.0}
+            self.scores[turn]["exec"] = 0
+
+        for level in LEVELS:
+            self.scores[level] = {"count": 0, "partial": {}, "exact": 0.0}
+            self.scores[level]["exec"] = 0
+            for type_ in PARTIAL_TYPES:
+                self.scores[level]["partial"][type_] = {
+                    "acc": 0.0,
+                    "rec": 0.0,
+                    "f1": 0.0,
+                    "acc_count": 0,
+                    "rec_count": 0,
+                }
 
     def eval_hardness(self, sql):
         count_comp1_ = count_component1(sql)
@@ -438,10 +487,8 @@ def eval_hardness(self, sql):
         else:
             return "extra"
 
-    def eval_exact_match(self, pred, label):
-        partial_scores = self.eval_partial_match(pred, label)
-        self.partial_scores = partial_scores
-
+    @classmethod
+    def eval_exact_match(cls, pred, label, partial_scores):
         for key, score in partial_scores.items():
             if score["f1"] != 1:
                 return 0
@@ -452,7 +499,8 @@
             return label_tables == pred_tables
         return 1
 
-    def eval_partial_match(self, pred, label):
+    @classmethod
+    def eval_partial_match(cls, pred, label):
         res = {}
 
         label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
@@ -553,6 +601,179 @@
 
         return res
 
+    def evaluate_one(self, db_name, gold, predicted, turn_scores, idx):
+        if db_name not in self.db_paths:
+            db_path = os.path.join(self.db_dir, db_name, db_name + ".sqlite")
+            self.db_paths[db_name] = db_path
+            self.schemas[db_name] = Schema(get_schema(db_path))
+
+        if idx > 3:
+            idx = "> 4"
+        else:
+            idx += 1
+        turn_id = "turn " + str(idx)
+
+        self.scores[turn_id]["count"] += 1
+        self.scores["all"]["count"] += 1
+
+        if self.etype in ['all', 'match']:
+            schema = self.schemas[db_name]
+            g_sql = get_sql(schema, gold)
+            hardness = self.eval_hardness(g_sql)
+            self.scores[hardness]["count"] += 1
+
+            try:
+                p_sql = get_sql(schema, predicted)
+            except Exception:
+                # if the predicted SQL cannot be parsed, fall back to an empty SQL structure and score it against the gold SQL
+                p_sql = {
+                    "except": None,
+                    "from": {"conds": [], "table_units": []},
+                    "groupBy": [],
+                    "having": [],
+                    "intersect": None,
+                    "limit": None,
+                    "orderBy": [],
+                    "select": [False, []],
+                    "union": None,
+                    "where": [],
+                }
+
+        if self.etype in ["all", "exec"]:
+            exec_score = eval_exec_match(
+                db=self.db_paths[db_name],
+                p_str=predicted,
+                g_str=gold,
+                plug_value=self.plug_value,
+                keep_distinct=self.keep_distinct,
+                progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint,
+            )
+            if exec_score:
+                if self.etype == 'all':
+                    self.scores[hardness]["exec"] += 1
+                self.scores[turn_id]["exec"] += 1
+                self.scores["all"]["exec"] += 1
+                turn_scores["exec"].append(1)
+            else:
+                turn_scores["exec"].append(0)
+
+        if self.etype in ["all", "match"]:
+            # rebuild sql for value evaluation
+            kmap = self.kmaps[db_name]
+            g_valid_col_units = build_valid_col_units(
+                g_sql["from"]["table_units"], schema
+            )
+            g_sql = rebuild_sql_val(g_sql)
+            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
+            p_valid_col_units = build_valid_col_units(
+                p_sql["from"]["table_units"], schema
+            )
+            p_sql = rebuild_sql_val(p_sql)
+            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
+            partial_scores = self.eval_partial_match(p_sql, g_sql)
+            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
+            if exact_score == 0:
+                turn_scores["exact"].append(0)
+                print("{} pred: {}".format(hardness, predicted))
+                print("{} gold: {}".format(hardness, gold))
+                print("")
+
else: + turn_scores["exact"].append(1) + self.scores[turn_id]["exact"] += exact_score + self.scores[hardness]["exact"] += exact_score + self.scores["all"]["exact"] += exact_score + for type_ in PARTIAL_TYPES: + if partial_scores[type_]["pred_total"] > 0: + self.scores[hardness]["partial"][type_]["acc"] += partial_scores[ + type_ + ]["acc"] + self.scores[hardness]["partial"][type_]["acc_count"] += 1 + if partial_scores[type_]["label_total"] > 0: + self.scores[hardness]["partial"][type_]["rec"] += partial_scores[ + type_ + ]["rec"] + self.scores[hardness]["partial"][type_]["rec_count"] += 1 + self.scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][ + "f1" + ] + if partial_scores[type_]["pred_total"] > 0: + self.scores["all"]["partial"][type_]["acc"] += partial_scores[type_][ + "acc" + ] + self.scores["all"]["partial"][type_]["acc_count"] += 1 + if partial_scores[type_]["label_total"] > 0: + self.scores["all"]["partial"][type_]["rec"] += partial_scores[type_][ + "rec" + ] + self.scores["all"]["partial"][type_]["rec_count"] += 1 + self.scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] + + result = { + "predictSQL": predicted, + "goldSQL": gold, + } + if self.etype in ['all', 'match']: + result.update({ + "hardness": hardness, + "exact": exact_score, + "partial": partial_scores, + }) + if self.etype in ['all', 'exec']: + result['exec'] = exec_score + return result + + def finalize(self): + scores = self.scores + for turn in TURNS: + if scores[turn]["count"] == 0: + continue + if self.etype in ["all", "exec"]: + scores[turn]["exec"] /= scores[turn]["count"] + + if self.etype in ["all", "match"]: + scores[turn]["exact"] /= scores[turn]["count"] + + for level in LEVELS: + if scores[level]["count"] == 0: + continue + if self.etype in ["all", "exec"]: + scores[level]["exec"] /= scores[level]["count"] + + if self.etype in ["all", "match"]: + scores[level]["exact"] /= scores[level]["count"] + for type_ in PARTIAL_TYPES: + if scores[level]["partial"][type_]["acc_count"] == 0: + scores[level]["partial"][type_]["acc"] = 0 + else: + scores[level]["partial"][type_]["acc"] = ( + scores[level]["partial"][type_]["acc"] + / scores[level]["partial"][type_]["acc_count"] + * 1.0 + ) + if scores[level]["partial"][type_]["rec_count"] == 0: + scores[level]["partial"][type_]["rec"] = 0 + else: + scores[level]["partial"][type_]["rec"] = ( + scores[level]["partial"][type_]["rec"] + / scores[level]["partial"][type_]["rec_count"] + * 1.0 + ) + if ( + scores[level]["partial"][type_]["acc"] == 0 + and scores[level]["partial"][type_]["rec"] == 0 + ): + scores[level]["partial"][type_]["f1"] = 1 + else: + scores[level]["partial"][type_]["f1"] = ( + 2.0 + * scores[level]["partial"][type_]["acc"] + * scores[level]["partial"][type_]["rec"] + / ( + scores[level]["partial"][type_]["rec"] + + scores[level]["partial"][type_]["acc"] + ) + ) + def isValidSQL(sql, db): conn = sqlite3.connect(db) @@ -570,22 +791,11 @@ def print_formated_s(row_name, l, element_format): def print_scores(scores, etype, include_turn_acc=True): - turns = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"] + turns = TURNS levels = ["easy", "medium", "hard", "extra", "all"] if include_turn_acc: levels.append("joint_all") - partial_types = [ - "select", - "select(no AGG)", - "where", - "where(no OP)", - "group(no Having)", - "group", - "order", - "and/or", - "IUEN", - "keywords", - ] + partial_types = PARTIAL_TYPES print_formated_s("", levels, "{:20}") counts = [scores[level]["count"] for level in levels] @@ -684,217 +894,34 
@@ def evaluate( assert len(plist) == len(glist), "number of sessions must equal" - evaluator = Evaluator() - turns = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"] - levels = ["easy", "medium", "hard", "extra", "all", "joint_all"] - - partial_types = [ - "select", - "select(no AGG)", - "where", - "where(no OP)", - "group(no Having)", - "group", - "order", - "and/or", - "IUEN", - "keywords", - ] - entries = [] - scores = {} - - for turn in turns: - scores[turn] = {"count": 0, "exact": 0.0} - scores[turn]["exec"] = 0 - - for level in levels: - scores[level] = {"count": 0, "partial": {}, "exact": 0.0} - scores[level]["exec"] = 0 - for type_ in partial_types: - scores[level]["partial"][type_] = { - "acc": 0.0, - "rec": 0.0, - "f1": 0.0, - "acc_count": 0, - "rec_count": 0, - } + evaluator = Evaluator(db_dir, kmaps, etype, plug_value, keep_distinct, progress_bar_for_each_datapoint) + results = [] for i, (p, g) in enumerate(zip(plist, glist)): if (i + 1) % 10 == 0: print("Evaluating %dth prediction" % (i + 1)) - scores["joint_all"]["count"] += 1 + evaluator.scores["joint_all"]["count"] += 1 turn_scores = {"exec": [], "exact": []} for idx, pg in enumerate(zip(p, g)): p, g = pg p_str = p[0] p_str = p_str.replace("value", "1") - g_str, db = g - db_name = db - db = os.path.join(db_dir, db, db + ".sqlite") - schema = Schema(get_schema(db)) - g_sql = get_sql(schema, g_str) - hardness = evaluator.eval_hardness(g_sql) - if idx > 3: - idx = "> 4" - else: - idx += 1 - turn_id = "turn " + str(idx) - scores[turn_id]["count"] += 1 - scores[hardness]["count"] += 1 - scores["all"]["count"] += 1 - - try: - p_sql = get_sql(schema, p_str) - except: - # If p_sql is not valid, then we will use an empty sql to evaluate with the correct sql - p_sql = { - "except": None, - "from": {"conds": [], "table_units": []}, - "groupBy": [], - "having": [], - "intersect": None, - "limit": None, - "orderBy": [], - "select": [False, []], - "union": None, - "where": [], - } + g_str, db_name = g - if etype in ["all", "exec"]: - exec_score = eval_exec_match( - db=db, - p_str=p_str, - g_str=g_str, - plug_value=plug_value, - keep_distinct=keep_distinct, - progress_bar_for_each_datapoint=progress_bar_for_each_datapoint, - ) - if exec_score: - scores[hardness]["exec"] += 1 - scores[turn_id]["exec"] += 1 - scores["all"]["exec"] += 1 - turn_scores["exec"].append(1) - else: - turn_scores["exec"].append(0) - - if etype in ["all", "match"]: - # rebuild sql for value evaluation - kmap = kmaps[db_name] - g_valid_col_units = build_valid_col_units( - g_sql["from"]["table_units"], schema - ) - g_sql = rebuild_sql_val(g_sql) - g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap) - p_valid_col_units = build_valid_col_units( - p_sql["from"]["table_units"], schema - ) - p_sql = rebuild_sql_val(p_sql) - p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap) - exact_score = evaluator.eval_exact_match(p_sql, g_sql) - partial_scores = evaluator.partial_scores - if exact_score == 0: - turn_scores["exact"].append(0) - print("{} pred: {}".format(hardness, p_str)) - print("{} gold: {}".format(hardness, g_str)) - print("") - else: - turn_scores["exact"].append(1) - scores[turn_id]["exact"] += exact_score - scores[hardness]["exact"] += exact_score - scores["all"]["exact"] += exact_score - for type_ in partial_types: - if partial_scores[type_]["pred_total"] > 0: - scores[hardness]["partial"][type_]["acc"] += partial_scores[ - type_ - ]["acc"] - scores[hardness]["partial"][type_]["acc_count"] += 1 - if partial_scores[type_]["label_total"] > 0: - 
scores[hardness]["partial"][type_]["rec"] += partial_scores[ - type_ - ]["rec"] - scores[hardness]["partial"][type_]["rec_count"] += 1 - scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][ - "f1" - ] - if partial_scores[type_]["pred_total"] > 0: - scores["all"]["partial"][type_]["acc"] += partial_scores[type_][ - "acc" - ] - scores["all"]["partial"][type_]["acc_count"] += 1 - if partial_scores[type_]["label_total"] > 0: - scores["all"]["partial"][type_]["rec"] += partial_scores[type_][ - "rec" - ] - scores["all"]["partial"][type_]["rec_count"] += 1 - scores["all"]["partial"][type_]["f1"] += partial_scores[type_]["f1"] - - entries.append( - { - "predictSQL": p_str, - "goldSQL": g_str, - "hardness": hardness, - "exact": exact_score, - "partial": partial_scores, - } - ) + results.append(evaluator.evaluate_one(db_name, g_str, p_str, turn_scores, idx)) if all(v == 1 for v in turn_scores["exec"]): - scores["joint_all"]["exec"] += 1 + evaluator.scores["joint_all"]["exec"] += 1 if all(v == 1 for v in turn_scores["exact"]): - scores["joint_all"]["exact"] += 1 - - for turn in turns: - if scores[turn]["count"] == 0: - continue - if etype in ["all", "exec"]: - scores[turn]["exec"] /= scores[turn]["count"] - - if etype in ["all", "match"]: - scores[turn]["exact"] /= scores[turn]["count"] - - for level in levels: - if scores[level]["count"] == 0: - continue - if etype in ["all", "exec"]: - scores[level]["exec"] /= scores[level]["count"] - - if etype in ["all", "match"]: - scores[level]["exact"] /= scores[level]["count"] - for type_ in partial_types: - if scores[level]["partial"][type_]["acc_count"] == 0: - scores[level]["partial"][type_]["acc"] = 0 - else: - scores[level]["partial"][type_]["acc"] = ( - scores[level]["partial"][type_]["acc"] - / scores[level]["partial"][type_]["acc_count"] - * 1.0 - ) - if scores[level]["partial"][type_]["rec_count"] == 0: - scores[level]["partial"][type_]["rec"] = 0 - else: - scores[level]["partial"][type_]["rec"] = ( - scores[level]["partial"][type_]["rec"] - / scores[level]["partial"][type_]["rec_count"] - * 1.0 - ) - if ( - scores[level]["partial"][type_]["acc"] == 0 - and scores[level]["partial"][type_]["rec"] == 0 - ): - scores[level]["partial"][type_]["f1"] = 1 - else: - scores[level]["partial"][type_]["f1"] = ( - 2.0 - * scores[level]["partial"][type_]["acc"] - * scores[level]["partial"][type_]["rec"] - / ( - scores[level]["partial"][type_]["rec"] - + scores[level]["partial"][type_]["acc"] - ) - ) - - print_scores(scores, etype, include_turn_acc=include_turn_acc) + evaluator.scores["joint_all"]["exact"] += 1 + + evaluator.finalize() + print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc) + return { + "per_item": results, + "total_scores": evaluator.scores + } # Rebuild SQL functions for value evaluation diff --git a/exec_eval.py b/exec_eval.py index d266adb..bcfc93d 100644 --- a/exec_eval.py +++ b/exec_eval.py @@ -1,20 +1,24 @@ import os +import re +import sqlite3 +import asyncio import threading from typing import Tuple, Any, List, Set from itertools import product from collections import defaultdict import tqdm import random -from parse import get_all_preds_for_execution, remove_distinct import time import pickle as pkl import subprocess from itertools import chain +from .parse import get_all_preds_for_execution, remove_distinct + threadLock = threading.Lock() TIMEOUT = 60 -EXEC_TMP_DIR = "tmp/" +EXEC_TMP_DIR = os.path.join(os.path.dirname(__file__), "tmp") def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: @@ -119,44 
+123,49 @@ def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -
     return False
 
 
-def clean_tmp_f(f_prefix: str):
-    with threadLock:
-        for suffix in (".in", ".out"):
-            f_path = f_prefix + suffix
-            if os.path.exists(f_path):
-                os.unlink(f_path)
+def replace_cur_year(query: str) -> str:
+    return re.sub(
+        r"YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
+    )
 
 
-# we need a wrapper, because simple timeout will not stop the database connection
-def exec_on_db(
-    sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT
-) -> Tuple[str, Any]:
-    f_prefix = None
-    with threadLock:
-        while f_prefix is None or os.path.exists(f_prefix + ".in"):
-            process_id += str(time.time())
-            process_id += str(random.randint(0, 10000000000))
-            f_prefix = os.path.join(EXEC_TMP_DIR, process_id)
-    pkl.dump((sqlite_path, query), open(f_prefix + ".in", "wb"))
+# get the database cursor for a sqlite database path
+def get_cursor_from_path(sqlite_path: str):
     try:
-        subprocess.call(
-            ["python3", "exec_subprocess.py", f_prefix],
-            timeout=timeout,
-            stderr=open("runerr.log", "a"),
-        )
+        if not os.path.exists(sqlite_path):
+            print("Opening a new connection %s" % sqlite_path)
+        connection = sqlite3.connect(sqlite_path)
     except Exception as e:
-        print(e)
-        clean_tmp_f(f_prefix)
+        print(sqlite_path)
+        raise e
+    connection.text_factory = lambda b: b.decode(errors="ignore")
+    cursor = connection.cursor()
+    return cursor
+
+
+async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
+    query = replace_cur_year(query)
+    cursor = get_cursor_from_path(sqlite_path)
+    try:
+        cursor.execute(query)
+        result = cursor.fetchall()
+        cursor.close()
+        cursor.connection.close()
+        return "result", result
+    except Exception as e:
+        cursor.close()
+        cursor.connection.close()
         return "exception", e
 
-    result_path = f_prefix + ".out"
-    returned_val = ("exception", TimeoutError)
+
+async def exec_on_db(
+    sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT
+) -> Tuple[str, Any]:
     try:
-        if os.path.exists(result_path):
-            returned_val = pkl.load(open(result_path, "rb"))
-    except:
-        pass
-    clean_tmp_f(f_prefix)
-    return returned_val
+        return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout)
+    except asyncio.TimeoutError:
+        return ("exception", TimeoutError)
+    except Exception as e:
+        return ("exception", e)
 
 
 # postprocess the model predictions to avoid execution errors
@@ -185,7 +194,11 @@ def eval_exec_match(
     # e.g. removing spaces between ">" and "="
     p_str, g_str = postprocess(p_str), postprocess(g_str)
     if not keep_distinct:
-        p_str = remove_distinct(p_str)
+        try:
+            # if sqlparse can't parse p_str, we should not even try to execute it
+            p_str = remove_distinct(p_str)
+        except Exception:
+            return 0
     g_str = remove_distinct(g_str)
 
     # we decide whether two denotations are equivalent based on "bag semantics"
@@ -224,13 +237,13 @@
         ranger = db_paths
 
     for db_path in ranger:
-        g_flag, g_denotation = exec_on_db(db_path, g_str)
-        p_flag, p_denotation = exec_on_db(db_path, pred)
+        g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str))
+        p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred))
 
         # we should expect the gold to be succesfully executed on the database
         assert (
             g_flag != "exception"
-        ), "gold query %s has error on database file %s" % (g_str, db_path)
+        ), f"gold query {g_str} has error {g_denotation} on database file {db_path}"
 
         # wrong if execution fails
         if p_flag == "exception":
diff --git a/exec_subprocess.py b/exec_subprocess.py
deleted file mode 100644
index 6bebeeb..0000000
--- a/exec_subprocess.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import sys
-
-sys.path.append("./")
-import os
-import pickle as pkl
-from typing import Tuple, Any
-import sqlite3
-import re
-
-
-def replace_cur_year(query: str) -> str:
-    return re.sub(
-        "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE
-    )
-
-
-# get the database cursor for a sqlite database path
-def get_cursor_from_path(sqlite_path: str):
-    try:
-        if not os.path.exists(sqlite_path):
-            print("Openning a new connection %s" % sqlite_path)
-        connection = sqlite3.connect(sqlite_path)
-    except Exception as e:
-        print(sqlite_path)
-        raise e
-    connection.text_factory = lambda b: b.decode(errors="ignore")
-    cursor = connection.cursor()
-    return cursor
-
-
-def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]:
-    query = replace_cur_year(query)
-    cursor = get_cursor_from_path(sqlite_path)
-    try:
-        cursor.execute(query)
-        result = cursor.fetchall()
-        cursor.close()
-        cursor.connection.close()
-        return "result", result
-    except Exception as e:
-        cursor.close()
-        cursor.connection.close()
-        return "exception", e
-
-
-f_prefix = sys.argv[1]
-func_args = pkl.load(open(f_prefix + ".in", "rb"))
-sqlite_path, query = func_args
-result = exec_on_db_(sqlite_path, query)
-pkl.dump(result, open(f_prefix + ".out", "wb"))
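
Usage sketch (illustrative, not part of the patch): with this change, exec_on_db becomes an asyncio coroutine instead of a wrapper around the deleted exec_subprocess.py, and evaluate() now returns a dict with "per_item" results and "total_scores" in addition to printing the score table. A minimal example of driving the new exec_on_db directly; the import path, database path, and query below are placeholders and depend on how the package is laid out:

    import asyncio

    # import path is an assumption; the modules now use relative imports,
    # so they are expected to be imported as part of a package
    from exec_eval import exec_on_db

    # exec_on_db must be awaited (or run via asyncio.run); it returns
    # ("result", rows) on success and ("exception", err) on failure or timeout
    flag, denotation = asyncio.run(
        exec_on_db("database/geo/geo.sqlite", "SELECT count(*) FROM city")
    )
    if flag == "result":
        print(denotation)
    else:
        print("execution failed:", denotation)

The timeout is now enforced with asyncio.wait_for inside the same process, which is what allows exec_subprocess.py and its pickle-based temp-file handshake to be removed.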