Skip to content

Commit 43079c0

Browse files
authored
Allow passing database for pinot queries (#89)
1 parent 741e6be commit 43079c0

11 files changed

+51
-22
lines changed

examples/pinot_async.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,8 @@
77

88
async def run_pinot_async_example():
99
async with connect_async(host='localhost', port=8000, path='/query/sql',
10-
scheme='http', verify_ssl=False, timeout=10.0) as conn:
10+
scheme='http', verify_ssl=False, timeout=10.0,
11+
extra_request_headers="Database=default") as conn:
1112
curs = await conn.execute("""
1213
SELECT count(*)
1314
FROM baseballStats
@@ -20,7 +21,7 @@ async def run_pinot_async_example():
2021
session = httpx.AsyncClient(verify=False)
2122
conn = connect_async(
2223
host='localhost', port=8000, path='/query/sql', scheme='http',
23-
verify_ssl=False, session=session)
24+
verify_ssl=False, session=session, extra_request_headers="Database=default")
2425

2526
# launch 10 requests in parallel spanning a limit/offset range
2627
reqs = []

examples/pinot_live.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
def run_pinot_live_example() -> None:
1010
# Query pinot.live with pinotdb connect
11-
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https")
11+
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https",
12+
extra_request_headers="Database=default")
1213
curs = conn.cursor()
1314
sql = "SELECT * FROM airlineStats LIMIT 5"
1415
print(f"Sending SQL to Pinot: {sql}")
@@ -21,7 +22,7 @@ def run_pinot_live_example() -> None:
2122
"pinot+https://pinot-broker.pinot.live:443/query/sql?controller=https://pinot-controller.pinot.live/"
2223
) # uses HTTP by default :(
2324

24-
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
25+
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
2526
print(f"\nSending Count(*) SQL to Pinot")
2627
query=select([func.count("*")], from_obj=airlineStats)
2728
print(engine.execute(query).scalar())

examples/pinot_quickstart_auth_zk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def run_pinot_quickstart_batch_example() -> None:
2020
scheme="http",
2121
username="admin",
2222
password="verysecret",
23+
extra_request_headers="Database=default",
2324
)
2425
curs = conn.cursor()
2526
tables = [
@@ -65,7 +66,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
6566
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
6667
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')
6768

68-
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
69+
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
6970
print(f"\nSending Count(*) SQL to Pinot")
7071
query = select([func.count("*")], from_obj=baseballStats)
7172
print(engine.execute(query).scalar())

examples/pinot_quickstart_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313

1414
def run_pinot_quickstart_batch_example() -> None:
15-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
15+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
16+
extra_request_headers="Database=default")
1617
curs = conn.cursor()
1718

1819
tables = [
@@ -52,7 +53,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
5253
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
5354
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')
5455

55-
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
56+
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
5657
print(f"\nSending Count(*) SQL to Pinot")
5758
query = select([func.count("*")], from_obj=baseballStats)
5859
print(engine.execute(query).scalar())

examples/pinot_quickstart_hybrid.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
## -d apachepinot/pinot:latest QuickStart -type hybrid
1212

1313
def run_pinot_quickstart_hybrid_example() -> None:
14-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
14+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
15+
extra_request_headers="Database=default")
1516
curs = conn.cursor()
1617
sql = "SELECT * FROM airlineStats LIMIT 5"
1718
print(f"Sending SQL to Pinot: {sql}")
@@ -53,7 +54,7 @@ def run_pinot_quickstart_hybrid_sqlalchemy_example() -> None:
5354
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
5455
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')
5556

56-
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
57+
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
5758
print(f"\nSending Count(*) SQL to Pinot")
5859
query=select([func.count("*")], from_obj=airlineStats)
5960
print(engine.execute(query).scalar())

examples/pinot_quickstart_json_batch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414

1515
def run_quickstart_json_batch_example() -> None:
16-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
16+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
17+
extra_request_headers="Database=default")
1718
curs = conn.cursor()
1819
sql = "SELECT * FROM githubEvents LIMIT 5"
1920
print(f"Sending SQL to Pinot: {sql}")
@@ -43,7 +44,7 @@ def run_quickstart_json_batch_sqlalchemy_example() -> None:
4344
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
4445
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')
4546

46-
githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True)
47+
githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True, schema="default")
4748
print(f"\nSending Count(*) SQL to Pinot\nResults:")
4849
query=select([func.count("*")], from_obj=githubEvents)
4950
print(engine.execute(query).scalar())

examples/pinot_quickstart_multi_stage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
## -d apachepinot/pinot:latest QuickStart -type MULTI_STAGE
1212

1313
def run_pinot_quickstart_multi_stage_example() -> None:
14-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True)
14+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True,
15+
extra_request_headers="Database=default")
1516
curs = conn.cursor()
1617

1718
sql = "SELECT a.playerID, a.runs, a.yearID, b.runs, b.yearID FROM baseballStats_OFFLINE AS a JOIN baseballStats_OFFLINE AS b ON a.playerID = b.playerID WHERE a.runs > 160 AND b.runs < 2 LIMIT 10"

examples/pinot_quickstart_timeout.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ def run_pinot_quickstart_timeout_example() -> None:
1111

1212
#Test 1 : Try without timeout. The request should succeed.
1313

14-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
14+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
15+
extra_request_headers="Database=default")
1516
curs = conn.cursor()
1617
sql = "SELECT * FROM airlineStats LIMIT 5"
1718
print(f"Sending SQL to Pinot: {sql}")
@@ -20,7 +21,8 @@ def run_pinot_quickstart_timeout_example() -> None:
2021

2122
#Test 2 : Try with timeout=None. The request should succeed.
2223

23-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None)
24+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None,
25+
extra_request_headers="Database=default")
2426
curs = conn.cursor()
2527
sql = "SELECT count(*) FROM airlineStats LIMIT 5"
2628
print(f"Sending SQL to Pinot: {sql}")
@@ -29,7 +31,8 @@ def run_pinot_quickstart_timeout_example() -> None:
2931

3032
#Test 3 : Try with a really small timeout. The query should raise an exception.
3133

32-
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001)
34+
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001,
35+
extra_request_headers="Database=default")
3336
curs = conn.cursor()
3437
sql = "SELECT AirlineID, sum(Cancelled) FROM airlineStats WHERE Year > 2010 GROUP BY AirlineID LIMIT 5"
3538
print(f"Sending SQL to Pinot: {sql}")

pinotdb/db.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def close(self):
162162
except exceptions.Error:
163163
pass # already closed
164164
# if we're managing the httpx session, attempt to close it
165-
if not self.is_session_external:
165+
if not self.is_session_external and self.session:
166166
self.session.close()
167167

168168
@check_closed
@@ -334,7 +334,8 @@ def __init__(
334334
for header in extra_request_headers.split(","):
335335
k, v = header.split("=", 1)
336336
extra_headers[k] = v
337-
337+
if 'database' in kwargs:
338+
extra_headers['database'] = kwargs['database']
338339
self.session.headers.update(extra_headers)
339340

340341
@check_closed

pinotdb/sqlalchemy.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ def __init__(
123123
)
124124

125125

126+
def extract_table_name(fqn):
127+
split = fqn.split(".", 2)
128+
return fqn if len(split) == 1 else split[1]
129+
130+
126131
class PinotDialect(default.DefaultDialect):
127132

128133
name = "pinot"
@@ -132,6 +137,7 @@ class PinotDialect(default.DefaultDialect):
132137
preparer = PinotIdentifierPareparer
133138
statement_compiler = PinotCompiler
134139
type_compiler = PinotTypeCompiler
140+
supports_schemas = False
135141
supports_statement_cache = False
136142
supports_alter = False
137143
supports_pk_autoincrement = False
@@ -154,6 +160,7 @@ def __init__(self, *args, **kwargs):
154160
self._password = None
155161
self._debug = False
156162
self._verify_ssl = True
163+
self._database = None
157164
self.update_from_kwargs(kwargs)
158165

159166
def update_from_kwargs(self, givenkw):
@@ -167,6 +174,8 @@ def update_from_kwargs(self, givenkw):
167174
kwargs["username"] = self._username = kwargs.pop("username")
168175
if "password" in kwargs:
169176
kwargs["password"] = self._password = kwargs.pop("password")
177+
if "database" in kwargs:
178+
kwargs["database"] = self._database = kwargs.pop("database")
170179
kwargs["debug"] = self._debug = bool(kwargs.get("debug", False))
171180
kwargs["verify_ssl"] = self._verify_ssl = (str(kwargs.get("verify_ssl", "true")).lower() in ['true'])
172181
logger.info(
@@ -206,7 +215,7 @@ def create_connect_args(self, url):
206215

207216
def get_metadata_from_controller(self, path):
208217
url = parse.urljoin(self._controller, path)
209-
r = requests.get(url, headers={"Accept": "application/json"}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
218+
r = requests.get(url, headers={"Accept": "application/json", "Database": self._database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
210219
try:
211220
result = r.json()
212221
except ValueError as e:
@@ -221,13 +230,20 @@ def get_metadata_from_controller(self, path):
221230
return result
222231

223232
def get_schema_names(self, connection, **kwargs):
224-
return ["default"]
233+
if self._database:
234+
return [self._database]
235+
else:
236+
return ['default']
225237

226238
def has_table(self, connection, table_name, schema=None):
227239
return table_name in self.get_table_names(connection, schema)
228240

229241
def get_table_names(self, connection, schema=None, **kwargs):
230-
return self.get_metadata_from_controller("/tables")["tables"]
242+
resp = self.get_metadata_from_controller("/tables")
243+
if 'tables' in resp:
244+
return list(map(extract_table_name, resp["tables"]))
245+
else:
246+
return []
231247

232248
def get_view_names(self, connection, schema=None, **kwargs):
233249
return []
@@ -296,7 +312,7 @@ def _check_unicode_returns(self, connection, additional_tests=None):
296312

297313
def _check_unicode_description(self, connection):
298314
return True
299-
315+
300316
# Fix for SQL Alchemy error
301317
def _json_deserializer(self, content: any):
302318
"""

0 commit comments

Comments
 (0)