Trino dialect (sodadata#596)
* Adds trino dialect 

Co-authored-by: Bradley Scott Fought <[email protected]>
vijaykiran and scott-fought authored Dec 14, 2021
1 parent 016b3f7 commit 5916406
Showing 18 changed files with 237 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -139,3 +139,6 @@ reports/*
# Spark
.hive-metastore/
.spark-warehouse/

+ #trino
+ .trino
2 changes: 1 addition & 1 deletion core/sodasql/dataset_analyzer.py
@@ -56,7 +56,7 @@ def analyze(self, warehouse: Warehouse, table_name: str):
column_name=column_name, source_type=source_type)
analyze_results.append(column_analysis_result)

- qualified_column_name = dialect.qualify_column_name(column_name)
+ qualified_column_name = dialect.qualify_column_name(column_name, source_type)
select_with_limit_query = dialect.sql_select_with_limit(qualified_table_name, 1000)

if dialect.is_text(source_type):
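(The new source_type parameter lets a dialect adapt a column reference to the column's declared type; the Trino dialect added below uses it to CAST CHAR(n) columns to VARCHAR so that text metrics and missing/validity checks behave consistently.)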
27 changes: 15 additions & 12 deletions core/sodasql/scan/dialect.py
@@ -37,6 +37,7 @@
SNOWFLAKE = 'snowflake'
SQLSERVER = 'sqlserver'
SPARK = 'spark'
+ TRINO = 'trino'

ALL_WAREHOUSE_TYPES = [ATHENA,
BIGQUERY,
@@ -46,7 +47,8 @@
REDSHIFT,
SNOWFLAKE,
SQLSERVER,
- SPARK]
+ SPARK,
+ TRINO]

logger = logging.getLogger(__name__)

@@ -85,22 +87,24 @@ def create(cls, parser: Parser) -> Optional[Dialect]:
else:
if warehouse_type == ATHENA:
_warehouse_class = Dialect._import_class('sodasql.dialects.athena_dialect', 'AthenaDialect')
- if warehouse_type == BIGQUERY:
+ elif warehouse_type == BIGQUERY:
_warehouse_class = Dialect._import_class('sodasql.dialects.bigquery_dialect', 'BigQueryDialect')
- if warehouse_type == HIVE:
+ elif warehouse_type == HIVE:
_warehouse_class = Dialect._import_class('sodasql.dialects.hive_dialect', 'HiveDialect')
- if warehouse_type == POSTGRES:
+ elif warehouse_type == POSTGRES:
_warehouse_class = Dialect._import_class('sodasql.dialects.postgres_dialect', 'PostgresDialect')
- if warehouse_type == MYSQL:
+ elif warehouse_type == MYSQL:
_warehouse_class = Dialect._import_class('sodasql.dialects.mysql_dialect', 'MySQLDialect')
- if warehouse_type == REDSHIFT:
+ elif warehouse_type == REDSHIFT:
_warehouse_class = Dialect._import_class('sodasql.dialects.redshift_dialect', 'RedshiftDialect')
- if warehouse_type == SNOWFLAKE:
+ elif warehouse_type == SNOWFLAKE:
_warehouse_class = Dialect._import_class('sodasql.dialects.snowflake_dialect', 'SnowflakeDialect')
- if warehouse_type == SQLSERVER:
+ elif warehouse_type == SQLSERVER:
_warehouse_class = Dialect._import_class('sodasql.dialects.sqlserver_dialect', 'SQLServerDialect')
- if warehouse_type == SPARK:
+ elif warehouse_type == SPARK:
_warehouse_class = Dialect._import_class('sodasql.dialects.spark_dialect', 'SparkDialect')
+ elif warehouse_type == TRINO:
+ _warehouse_class = Dialect._import_class('sodasql.dialects.trino_dialect', 'TrinoDialect')
return _warehouse_class(parser)

@classmethod
@@ -315,7 +319,7 @@ def literal(self, o: object):
def qualify_table_name(self, table_name: str) -> str:
return table_name

- def qualify_column_name(self, column_name: str):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return column_name

def qualify_writable_table_name(self, table_name: str) -> str:
@@ -471,8 +475,7 @@ def is_connection_error(self, exception):
def is_authentication_error(self, exception):
return False

- def try_to_raise_soda_sql_exception(
- self, exception: Exception) -> Exception:
+ def try_to_raise_soda_sql_exception(self, exception: Exception) -> Exception:
if self.is_connection_error(exception):
raise WarehouseConnectionError(
warehouse_type=self.type,
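With the factory extended, a warehouse configuration along these lines would make Dialect.create return the new TrinoDialect. This is a minimal sketch: the keys are mirrored from TrinoDialect.default_connection_properties further down, and the catalog/schema values are placeholders.

type: trino
host: localhost
port: 443
http_scheme: https
catalog: your_catalog
schema: your_schema
username: env_var(TRINO_USERNAME)
password: env_var(TRINO_PASSWORD)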
6 changes: 3 additions & 3 deletions core/sodasql/scan/scan_column.py
@@ -36,7 +36,7 @@ def __init__(self, scan, column_metadata: ColumnMetadata):
self.scan_yml.get_scan_yaml_column(self.column_name)

dialect = self.scan.dialect
- self.qualified_column_name = dialect.qualify_column_name(self.column_name)
+ self.qualified_column_name = dialect.qualify_column_name(self.column_name, column_metadata.data_type)
self.is_text: bool = dialect.is_text(column_metadata.data_type)
self.is_number: bool = dialect.is_number(column_metadata.data_type)
self.is_time: bool = dialect.is_time(column_metadata.data_type)
@@ -124,7 +124,7 @@ def is_metric_enabled(self, metric: str):

@classmethod
def __get_missing_condition(cls, column_metadata: ColumnMetadata, missing: Missing, dialect: Dialect):
- qualified_column_name = dialect.qualify_column_name(column_metadata.name)
+ qualified_column_name = dialect.qualify_column_name(column_metadata.name, column_metadata.data_type)
validity_clauses = [f'{qualified_column_name} IS NULL']
if missing:
if missing.values:
@@ -140,7 +140,7 @@ def __get_missing_condition(cls, column_metadata: ColumnMetadata, missing: Missing, dialect: Dialect):
return " OR ".join(validity_clauses), len(validity_clauses) == 1

def __get_valid_condition(self, column_metadata: ColumnMetadata, validity: Validity, dialect: Dialect):
- qualified_column_name = dialect.qualify_column_name(column_metadata.name)
+ qualified_column_name = dialect.qualify_column_name(column_metadata.name, column_metadata.data_type)
if validity is None:
return '', True
validity_clauses = []
2 changes: 1 addition & 1 deletion packages/athena/sodasql/dialects/athena_dialect.py
@@ -131,7 +131,7 @@ def sql_columns_metadata_query(self, table_name: str):
f"WHERE table_name = '{table_name.lower()}' \n"
f" AND table_schema = '{self.database.lower()}';")

- def qualify_column_name(self, column_name):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return f'"{column_name}"'

def qualify_table_name(self, table_name: str) -> str:
2 changes: 1 addition & 1 deletion packages/mysql/sodasql/dialects/mysql_dialect.py
@@ -111,7 +111,7 @@ def is_time(self, column_type: str):
def qualify_table_name(self, table_name: str) -> str:
return f'{table_name}'

- def qualify_column_name(self, column_name: str):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return f'{column_name}'

def sql_expr_count_conditional(self, condition: str):
2 changes: 1 addition & 1 deletion packages/postgresql/sodasql/dialects/postgres_dialect.py
@@ -127,7 +127,7 @@ def qualify_table_name(self, table_name: str) -> str:
return f'"{self.schema}"."{table_name}"'
return f'"{table_name}"'

- def qualify_column_name(self, column_name: str):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return f'"{column_name}"'

def sql_expr_regexp_like(self, expr: str, pattern: str):
2 changes: 1 addition & 1 deletion packages/snowflake/sodasql/dialects/snowflake_dialect.py
@@ -169,7 +169,7 @@ def sql_columns_metadata_query(self, table_name: str) -> str:
def qualify_regex(self, regex) -> str:
return self.escape_metacharacters(regex)

- def qualify_column_name(self, column_name: str):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return f'{column_name}'

def qualify_table_name(self, table_name: str) -> str:
2 changes: 1 addition & 1 deletion packages/spark/sodasql/dialects/spark_dialect.py
@@ -329,7 +329,7 @@ def qualify_table_name(self, table_name: str) -> str:
qualified_table_name = f'{self.database}.{table_name}'
return qualified_table_name

- def qualify_column_name(self, column_name: str):
+ def qualify_column_name(self, column_name: str, source_type: str = None):
return f"`{column_name}`"

def qualify_writable_table_name(self, table_name: str) -> str:
24 changes: 24 additions & 0 deletions packages/trino/setup.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python
import sys
from setuptools import setup, find_namespace_packages

if sys.version_info < (3, 7):
    print('Error: Soda SQL requires at least Python 3.7')
    print('Error: Please upgrade your Python version to 3.7 or later')
    sys.exit(1)

package_name = "soda-sql-trino"
package_version = '2.1.0'
description = "Soda SQL Trino"

requires = [
    f'soda-sql-core=={package_version}',
    'trino>=0.305.0'
]

setup(
    name=package_name,
    version=package_version,
    install_requires=requires,
    packages=find_namespace_packages(include=["sodasql*"])
)
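Once published, the package would presumably be installed with pip install soda-sql-trino; it pins soda-sql-core to the same 2.1.0 version and pulls in the trino client at 0.305.0 or newer.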
122 changes: 122 additions & 0 deletions packages/trino/sodasql/dialects/trino_dialect.py
@@ -0,0 +1,122 @@
# (c) 2021 Walt Disney Parks and Resorts U.S., Inc.
# (c) 2021 Soda Data NV.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import logging
import trino

from sodasql.scan.dialect import Dialect, TRINO, KEY_WAREHOUSE_TYPE
from sodasql.scan.parser import Parser

logger = logging.getLogger(__name__)


class TrinoDialect(Dialect):

    def __init__(self, parser: Parser):
        super().__init__(TRINO)
        if parser:
            self.host = parser.get_str_optional_env('host', 'localhost')
            self.port = parser.get_str_optional_env('port', '443')
            self.http_scheme = parser.get_str_optional_env('http_scheme', 'https')
            self.username = parser.get_str_required_env('username')
            self.password = parser.get_credential('password')
            self.catalog = parser.get_str_required_env('catalog')
            self.schema = parser.get_str_required_env('schema')

    def default_connection_properties(self, params: dict):
        return {
            KEY_WAREHOUSE_TYPE: TRINO,
            'host': 'localhost',
            'port': '443',
            'http_scheme': 'https',
            'catalog': params.get('catalog', 'YOUR_CATALOG').lower(),
            'schema': params.get('schema', 'YOUR_DATABASE').lower(),
            'username': 'env_var(TRINO_USERNAME)',
            'password': 'env_var(TRINO_PASSWORD)',
        }

    def default_env_vars(self, params: dict):
        return {
            'TRINO_USERNAME': params.get('username', 'YOUR_TRINO_USERNAME_GOES_HERE'),
            'TRINO_PASSWORD': params.get('password', 'YOUR_TRINO_PASSWORD_GOES_HERE')
        }

    def create_connection(self):
        try:
            conn = trino.dbapi.connect(
                host=self.host,
                port=self.port,
                catalog=self.catalog,
                schema=self.schema,
                http_scheme=self.http_scheme,
                auth=trino.auth.BasicAuthentication(self.username, self.password)
            )
            return conn
        except Exception as e:
            self.try_to_raise_soda_sql_exception(e)

    def sql_tables_metadata_query(self, limit: int = 10, filter: str = None):
        sql = (f"SELECT table_name \n"
               f"FROM {self.catalog}.information_schema.tables \n"
               f"WHERE lower(table_schema) = '{self.schema}' \n"
               f"  AND lower(table_catalog) = '{self.catalog}'")
        return sql

    def sql_columns_metadata_query(self, table_name: str) -> str:
        sql = (f"SELECT column_name, data_type, is_nullable \n"
               f"FROM {self.catalog}.information_schema.columns \n"
               f"WHERE lower(table_name) = '{table_name.lower()}' \n"
               f"  AND lower(table_catalog) = '{self.catalog}' \n"
               f"  AND lower(table_schema) = '{self.schema}'")
        return sql

    def is_text(self, column_type: str):
        column_type_upper = column_type.upper()
        return (column_type_upper in ['VARCHAR', 'CHAR', 'VARBINARY', 'JSON']
                or re.match(r'^(VAR)?CHAR\([0-9]+\)$', column_type_upper))

    def is_number(self, column_type: str):
        column_type_upper = column_type.upper()
        return (column_type_upper in ['BOOLEAN',
                                      'INT', 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'BYTEINT',
                                      'DOUBLE', 'REAL', 'DECIMAL']
                or re.match(r'^DECIMAL\([0-9]+(,[0-9]+)?\)$', column_type_upper))

    def is_time(self, column_type: str):
        column_type_upper = column_type.upper()
        return (column_type_upper in ['DATE', 'TIME', 'TIMESTAMP',
                                      'TIME WITH TIME ZONE', 'INTERVAL YEAR TO MONTH', 'INTERVAL DAY TO SECOND']
                or re.match(r'^TIMESTAMP\([0-9]+\)$', column_type_upper))

    def qualify_table_name(self, table_name: str) -> str:
        return f'"{self.catalog}"."{self.schema}"."{table_name}"'

    def qualify_column_name(self, column_name: str, source_type: str = None):
        if source_type is not None and re.match(r'^CHAR\([0-9]+\)$', source_type.upper()):
            return f'CAST({column_name} AS VARCHAR)'
        return f'{column_name}'

    def is_connection_error(self, exception):
        logger.error(exception)
        if exception is None or exception.errno is None:
            return False
        return isinstance(exception, trino.exceptions.HttpError) or \
            isinstance(exception, trino.exceptions.Http503Error) or \
            isinstance(exception, trino.exceptions.TrinoError) or \
            isinstance(exception, trino.exceptions.TimeoutError)

    def is_authentication_error(self, exception):
        logger.error(exception)
        if exception is None or exception.errno is None:
            return False
        return True
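A minimal sketch of how the type-aware qualification above behaves, assuming the package is importable and that constructing the dialect with parser=None (which skips the connection-settings block) is acceptable for illustration:

from sodasql.dialects.trino_dialect import TrinoDialect

dialect = TrinoDialect(parser=None)  # no parser: connection fields stay unset, type helpers still work
print(dialect.qualify_column_name('name', 'CHAR(4)'))  # CAST(name AS VARCHAR)
print(dialect.qualify_column_name('name', 'VARCHAR'))  # name
print(bool(dialect.is_text('CHAR(4)')))                # True, via the regex branch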
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@
./packages/snowflake
./packages/sqlserver
./packages/spark
+ ./packages/trino
1 change: 1 addition & 0 deletions tests/common/sql_test_case.py
@@ -41,6 +41,7 @@
TARGET_HIVE = 'hive'
TARGET_MYSQL = 'mysql'
TARGET_SPARK = 'spark'
+ TARGET_TRINO = 'trino'


def equals_ignore_case(left, right):
7 changes: 5 additions & 2 deletions tests/common/warehouse_fixture.py
@@ -29,8 +29,8 @@ class WarehouseFixture:

@classmethod
def create(cls, target: str):
- from tests.common.sql_test_case import TARGET_SNOWFLAKE, TARGET_SQLSERVER, TARGET_POSTGRES, TARGET_REDSHIFT, TARGET_ATHENA, \
- TARGET_BIGQUERY, TARGET_HIVE, TARGET_MYSQL, TARGET_SPARK
+ from tests.common.sql_test_case import TARGET_SNOWFLAKE, TARGET_POSTGRES, TARGET_REDSHIFT, TARGET_ATHENA, \
+ TARGET_BIGQUERY, TARGET_HIVE, TARGET_MYSQL, TARGET_SPARK, TARGET_SQLSERVER, TARGET_TRINO
if target == TARGET_POSTGRES:
from tests.warehouses.postgres_fixture import PostgresFixture
pf = PostgresFixture(target)
@@ -65,6 +65,9 @@ def create(cls, target: str):
elif target == TARGET_SPARK:
from tests.warehouses.spark_fixture import SparkFixture
return SparkFixture(target)
+ elif target == TARGET_TRINO:
+ from tests.warehouses.trino_fixture import TrinoFixture
+ return TrinoFixture(target)
raise RuntimeError(f'Invalid target {target}')

def __init__(self, target: str) -> None:
9 changes: 9 additions & 0 deletions tests/trino_contaner/docker-compose.yml
@@ -0,0 +1,9 @@
version: "3.7"
services:
soda-sql-trino:
image: trinodb/trino
ports:
- "8080:8080"
volumes:
- ./.trino/:/data/trino
command: sh -c "sleep 15 && /usr/bin/trino --execute='create schema memory.sodasql; use memory.sodasql'"
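The service definition waits 15 seconds for the server to come up, then runs the Trino CLI to create the memory.sodasql schema that trino_cfg.yml below points at; presumably the container is started with docker-compose up from tests/trino_contaner/ before running the Trino tests.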
8 changes: 8 additions & 0 deletions tests/warehouses/trino_cfg.yml
@@ -0,0 +1,8 @@
type: trino
host: localhost
port: 8080
http_scheme: https
catalog: memory
schema: sodasql
#username:
#password:
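The username and password keys are left commented out here; in a real setup they would be filled in locally or, following the dialect's defaults, written as env_var(TRINO_USERNAME) / env_var(TRINO_PASSWORD) and backed by the TRINO_USERNAME and TRINO_PASSWORD environment variables that TrinoDialect.default_env_vars scaffolds.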
21 changes: 21 additions & 0 deletions tests/warehouses/trino_fixture.py
@@ -0,0 +1,21 @@
# Copyright 2020 Soda
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests.common.warehouse_fixture import WarehouseFixture


class TrinoFixture(WarehouseFixture):

    def create_database(self):
        pass

    def drop_database(self):
        pass
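Both hooks are left as no-ops, presumably because the docker-compose command above pre-creates the memory.sodasql schema, leaving the fixture nothing to set up or tear down per run.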