Skip to content

Commit 0494727

Browse files
authored
chore(tool): add Postgres data source for the local run script (#1186)
1 parent c0f888d commit 0494727

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

ibis-server/tools/query_local_run.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import base64
1515
import json
1616
import os
17-
from app.model import MySqlConnectionInfo
17+
from app.model import MySqlConnectionInfo, PostgresConnectionInfo
1818
import sqlglot
1919
import sys
2020

@@ -30,7 +30,7 @@
3030
sql = sys.stdin.read()
3131

3232

33-
load_dotenv()
33+
load_dotenv(override=True)
3434
manifest_json_path = os.getenv("WREN_MANIFEST_JSON_PATH")
3535
function_list_path = os.getenv("REMOTE_FUNCTION_LIST_PATH")
3636
connection_info_path = os.getenv("CONNECTION_INFO_PATH")
@@ -59,28 +59,28 @@
5959

6060
print("### Starting the session context ###")
6161
print("#")
62-
session_context = SessionContext(encoded_str, function_list_path + f"{data_source}.csv")
62+
session_context = SessionContext(encoded_str, function_list_path + f"/{data_source}.csv")
6363
planned_sql = session_context.transform_sql(sql)
6464
print("# Planned SQL:\n", planned_sql)
6565

6666
# Transpile the planned SQL
67-
dialect_sql = sqlglot.transpile(planned_sql, read="trino", write=data_source)[0]
67+
dialect_sql = sqlglot.transpile(planned_sql, read=None, write=data_source)[0]
6868
print("# Dialect SQL:\n", dialect_sql)
6969
print("#")
7070

7171
if data_source == "bigquery":
7272
connection_info = BigQueryConnectionInfo.model_validate_json(json.dumps(connection_info))
7373
connection = DataSourceExtension.get_bigquery_connection(connection_info)
74-
df = connection.sql(dialect_sql).limit(10).to_pandas()
75-
print("### Result ###")
76-
print("")
77-
print(df)
7874
elif data_source == "mysql":
7975
connection_info = MySqlConnectionInfo.model_validate_json(json.dumps(connection_info))
8076
connection = DataSourceExtension.get_mysql_connection(connection_info)
81-
df = connection.sql(dialect_sql).limit(10).to_pandas()
82-
print("### Result ###")
83-
print("")
84-
print(df)
77+
elif data_source == "postgres":
78+
connection_info = PostgresConnectionInfo.model_validate_json(json.dumps(connection_info))
79+
connection = DataSourceExtension.get_postgres_connection(connection_info)
8580
else:
86-
print("Unsupported data source:", data_source)
81+
raise Exception("Unsupported data source:", data_source)
82+
83+
df = connection.sql(dialect_sql).limit(10).to_pandas()
84+
print("### Result ###")
85+
print("")
86+
print(df)

0 commit comments

Comments
 (0)