Skip to content

Commit 5641180

Browse files
authored
Allow pinot table aliases (#97)
1 parent bfdd768 commit 5641180

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

pinotdb/sqlalchemy.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ def visit_select(self, select, **kwargs):
2121
return super().visit_select(select, **kwargs)
2222

2323
def visit_column(self, column, result_map=None, **kwargs):
24-
# Pinot does not support table aliases
25-
if column.table is not None:
26-
column.table.named_with_column = False
2724
result_map = result_map or kwargs.pop("add_to_result_map", None)
2825
# This is a hack to modify the original column, but how do I clone it ?
2926
column.is_literal = True
@@ -61,7 +58,7 @@ def visit_REAL(self, type_, **kwargs):
6158
return "DOUBLE"
6259

6360
def visit_NUMERIC(self, type_, **kwargs):
64-
return "LONG"
61+
return "NUMERIC"
6562

6663
visit_DECIMAL = visit_NUMERIC
6764
visit_INTEGER = visit_NUMERIC
@@ -72,7 +69,7 @@ def visit_NUMERIC(self, type_, **kwargs):
7269
visit_DATE = visit_NUMERIC
7370

7471
def visit_CHAR(self, type_, **kwargs):
75-
return "STRING"
72+
return "VARCHAR"
7673

7774
visit_NCHAR = visit_CHAR
7875
visit_VARCHAR = visit_CHAR

tests/unit/test_sqlalchemy.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def test_can_select_table_directly(self):
344344

345345
self.assertEqual(
346346
str(compiler),
347-
'SELECT some_column \nFROM some_table',
347+
'SELECT some_table.some_column \nFROM some_table',
348348
)
349349

350350

@@ -357,44 +357,44 @@ def test_compiles_real(self):
357357
self.assertEqual(self.compiler.visit_REAL(None), 'DOUBLE')
358358

359359
def test_compiles_numeric(self):
360-
self.assertEqual(self.compiler.visit_NUMERIC(None), 'LONG')
360+
self.assertEqual(self.compiler.visit_NUMERIC(None), 'NUMERIC')
361361

362362
def test_compiles_decimal(self):
363-
self.assertEqual(self.compiler.visit_DECIMAL(None), 'LONG')
363+
self.assertEqual(self.compiler.visit_DECIMAL(None), 'NUMERIC')
364364

365365
def test_compiles_integer(self):
366-
self.assertEqual(self.compiler.visit_INTEGER(None), 'LONG')
366+
self.assertEqual(self.compiler.visit_INTEGER(None), 'NUMERIC')
367367

368368
def test_compiles_smallint(self):
369-
self.assertEqual(self.compiler.visit_SMALLINT(None), 'LONG')
369+
self.assertEqual(self.compiler.visit_SMALLINT(None), 'NUMERIC')
370370

371371
def test_compiles_bigint(self):
372-
self.assertEqual(self.compiler.visit_BIGINT(None), 'LONG')
372+
self.assertEqual(self.compiler.visit_BIGINT(None), 'NUMERIC')
373373

374374
# TODO: Check if this is correct (seems strange to have boolean as long).
375375
def test_compiles_boolean(self):
376-
self.assertEqual(self.compiler.visit_BOOLEAN(None), 'LONG')
376+
self.assertEqual(self.compiler.visit_BOOLEAN(None), 'NUMERIC')
377377

378378
def test_compiles_timestamp(self):
379-
self.assertEqual(self.compiler.visit_TIMESTAMP(None), 'LONG')
379+
self.assertEqual(self.compiler.visit_TIMESTAMP(None), 'NUMERIC')
380380

381381
def test_compiles_date(self):
382-
self.assertEqual(self.compiler.visit_DATE(None), 'LONG')
382+
self.assertEqual(self.compiler.visit_DATE(None), 'NUMERIC')
383383

384384
def test_compiles_char(self):
385-
self.assertEqual(self.compiler.visit_CHAR(None), 'STRING')
385+
self.assertEqual(self.compiler.visit_CHAR(None), 'VARCHAR')
386386

387387
def test_compiles_nchar(self):
388-
self.assertEqual(self.compiler.visit_NCHAR(None), 'STRING')
388+
self.assertEqual(self.compiler.visit_NCHAR(None), 'VARCHAR')
389389

390390
def test_compiles_varchar(self):
391-
self.assertEqual(self.compiler.visit_VARCHAR(None), 'STRING')
391+
self.assertEqual(self.compiler.visit_VARCHAR(None), 'VARCHAR')
392392

393393
def test_compiles_nvarchar(self):
394-
self.assertEqual(self.compiler.visit_NVARCHAR(None), 'STRING')
394+
self.assertEqual(self.compiler.visit_NVARCHAR(None), 'VARCHAR')
395395

396396
def test_compiles_text(self):
397-
self.assertEqual(self.compiler.visit_TEXT(None), 'STRING')
397+
self.assertEqual(self.compiler.visit_TEXT(None), 'VARCHAR')
398398

399399
def test_compiles_binary(self):
400400
self.assertEqual(self.compiler.visit_BINARY(None), 'BYTES')

0 commit comments

Comments
 (0)