@@ -123,6 +123,11 @@ def __init__(
123123 )
124124
125125
126+ def extract_table_name (fqn ):
127+ split = fqn .split ("." , 2 )
128+ return fqn if len (split ) == 1 else split [1 ]
129+
130+
126131class PinotDialect (default .DefaultDialect ):
127132
128133 name = "pinot"
@@ -132,6 +137,7 @@ class PinotDialect(default.DefaultDialect):
132137 preparer = PinotIdentifierPareparer
133138 statement_compiler = PinotCompiler
134139 type_compiler = PinotTypeCompiler
140+ supports_schemas = False
135141 supports_statement_cache = False
136142 supports_alter = False
137143 supports_pk_autoincrement = False
@@ -154,6 +160,7 @@ def __init__(self, *args, **kwargs):
154160 self ._password = None
155161 self ._debug = False
156162 self ._verify_ssl = True
163+ self ._database = None
157164 self .update_from_kwargs (kwargs )
158165
159166 def update_from_kwargs (self , givenkw ):
@@ -167,6 +174,8 @@ def update_from_kwargs(self, givenkw):
167174 kwargs ["username" ] = self ._username = kwargs .pop ("username" )
168175 if "password" in kwargs :
169176 kwargs ["password" ] = self ._password = kwargs .pop ("password" )
177+ if "database" in kwargs :
178+ kwargs ["database" ] = self ._database = kwargs .pop ("database" )
170179 kwargs ["debug" ] = self ._debug = bool (kwargs .get ("debug" , False ))
171180 kwargs ["verify_ssl" ] = self ._verify_ssl = (str (kwargs .get ("verify_ssl" , "true" )).lower () in ['true' ])
172181 logger .info (
@@ -206,7 +215,7 @@ def create_connect_args(self, url):
206215
207216 def get_metadata_from_controller (self , path ):
208217 url = parse .urljoin (self ._controller , path )
209- r = requests .get (url , headers = {"Accept" : "application/json" }, verify = self ._verify_ssl , auth = HTTPBasicAuth (self ._username , self ._password ))
218+ r = requests .get (url , headers = {"Accept" : "application/json" , "Database" : self . _database }, verify = self ._verify_ssl , auth = HTTPBasicAuth (self ._username , self ._password ))
210219 try :
211220 result = r .json ()
212221 except ValueError as e :
@@ -221,13 +230,20 @@ def get_metadata_from_controller(self, path):
221230 return result
222231
223232 def get_schema_names (self , connection , ** kwargs ):
224- return ["default" ]
233+ if self ._database :
234+ return [self ._database ]
235+ else :
236+ return ['default' ]
225237
226238 def has_table (self , connection , table_name , schema = None ):
227239 return table_name in self .get_table_names (connection , schema )
228240
229241 def get_table_names (self , connection , schema = None , ** kwargs ):
230- return self .get_metadata_from_controller ("/tables" )["tables" ]
242+ resp = self .get_metadata_from_controller ("/tables" )
243+ if 'tables' in resp :
244+ return list (map (extract_table_name , resp ["tables" ]))
245+ else :
246+ return []
231247
232248 def get_view_names (self , connection , schema = None , ** kwargs ):
233249 return []
@@ -296,7 +312,7 @@ def _check_unicode_returns(self, connection, additional_tests=None):
296312
297313 def _check_unicode_description (self , connection ):
298314 return True
299-
315+
300316 # Fix for SQL Alchemy error
301317 def _json_deserializer (self , content : any ):
302318 """
0 commit comments