Skip to content

Commit

Permalink
[sqlserver] Support for SQLServer Dialect (sodadata#564)
Browse files Browse the repository at this point in the history
Co-authored-by: Vijay Kiran <[email protected]>
  • Loading branch information
fakirAyoub and vijaykiran authored Dec 10, 2021
1 parent f30f3a8 commit 5b40926
Show file tree
Hide file tree
Showing 21 changed files with 243 additions and 120 deletions.
17 changes: 14 additions & 3 deletions core/sodasql/dataset_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List

import re
from sodasql.scan.validity import Validity
from sodasql.scan.warehouse import Warehouse
from deprecated import deprecated
Expand All @@ -36,9 +36,17 @@ def to_dict(self):
def to_json(self):
    # Alias for to_dict(); presumably the dict form is already
    # JSON-serializable — NOTE(review): confirm against the serializer callers.
    return self.to_dict()


class DatasetAnalyzer:

def _wrap_sqlserver_column_name(self, column_name):
special_characters_check = re.compile('[@_!#$%^&*()<>?/\|}{~:]')

if re.search(r"\s", column_name) or special_characters_check.search(column_name) is not None:
return f"[{column_name}]"
else:
return column_name


def analyze(self, warehouse: Warehouse, table_name: str):
dialect = warehouse.dialect
qualified_table_name = dialect.qualify_table_name(table_name)
Expand All @@ -50,7 +58,10 @@ def analyze(self, warehouse: Warehouse, table_name: str):
column_tuples = warehouse.sql_fetchall(sql) if len(
column_tuple_list) == 0 else column_tuple_list
for column_tuple in column_tuples:
column_name = column_tuple[0]
if warehouse.dialect.type == "sqlserver":
column_name = self._wrap_sqlserver_column_name(column_tuple[0])
else:
column_name = column_tuple[0]
source_type = column_tuple[1]

column_analysis_result = ColumnAnalysisResult(
Expand Down
17 changes: 9 additions & 8 deletions core/sodasql/scan/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def sql_drop_table(self, table_name):
def sql_expr_count_all(self) -> str:
    """Return the SQL aggregate expression that counts every row."""
    count_all_expression = 'COUNT(*)'
    return count_all_expression

def sql_expr_count_conditional(self, condition: str):
def sql_expr_count_conditional(self, condition: str, column: str):
return f'COUNT(CASE WHEN {condition} THEN 1 END)'

def sql_expr_conditional(self, condition: str, expr: str):
Expand All @@ -233,25 +233,25 @@ def sql_expr_count(self, expr):
def sql_expr_distinct(self, expr):
    """Wrap *expr* in a DISTINCT(...) SQL expression."""
    return 'DISTINCT({})'.format(expr)

def sql_expr_min(self, expr):
def sql_expr_min(self, expr, column: str):
return f'MIN({expr})'

def sql_expr_length(self, expr):
def sql_expr_length(self, expr, column: str):
return f'LENGTH({expr})'

def sql_expr_max(self, expr: str):
def sql_expr_max(self, expr: str, column: str):
return f'MAX({expr})'

def sql_expr_avg(self, expr: str):
def sql_expr_avg(self, expr: str, column: str):
return f'AVG({expr})'

def sql_expr_sum(self, expr: str):
def sql_expr_sum(self, expr: str, column: str):
return f'SUM({expr})'

def sql_expr_variance(self, expr: str):
def sql_expr_variance(self, expr: str, column: str):
return f'VARIANCE({expr})'

def sql_expr_stddev(self, expr: str):
def sql_expr_stddev(self, expr: str, column: str):
return f'STDDEV({expr})'

def sql_expr_regexp_like(self, expr: str, pattern: str):
Expand All @@ -264,6 +264,7 @@ def sql_select_with_limit(self, table_name, count):
return f'SELECT * FROM {table_name} LIMIT {count}'

def sql_expr_list(self, column: ColumnMetadata, values: List[str]) -> str:

if self.is_text(column.data_type):
sql_values = [self.literal_string(value) for value in values]
elif self.is_number(column.data_type):
Expand Down
5 changes: 4 additions & 1 deletion core/sodasql/scan/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def is_invalid_test(test):
if tablesample is not None:
sql += f" \nTABLESAMPLE {tablesample};"
else:
sql += f" \nLIMIT {limit};"
if self.scan.warehouse.dialect.type == 'sqlserver':
sql += f" ORDER BY 1 OFFSET 0 ROWS FETCH NEXT {limit} ROWS ONLY"
else:
sql += f" \nLIMIT {limit};"

sample_description = (f'{self.scan.scan_yml.table_name}.sample'
if sample_name == 'dataset' else f'{self.scan.scan_yml.table_name}.{column_name}.{sample_name}')
Expand Down
33 changes: 19 additions & 14 deletions core/sodasql/scan/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def __init__(self,

self.scan_result = ScanResult()
self.dialect = warehouse.dialect
self.qualified_table_name = self.dialect.qualify_table_name(scan_yml.table_name)
self.qualified_table_name = self.dialect.qualify_table_name(scan_yml.table_name) \
if self.dialect.type != 'sqlserver' else scan_yml.table_name
self.scan_reference = None
self.disable_sample_collection = False
self.column_metadatas: List[ColumnMetadata] = []
Expand Down Expand Up @@ -131,6 +132,7 @@ def execute(self) -> ScanResult:
return self.scan_result

def _process_cloud_custom_metrics(self):

if self.soda_server_client:
from sodasql.soda_server_client.monitor_metric_parser import MonitorMetricParser
from sodasql.soda_server_client.monitor_metric import MonitorMetricType
Expand Down Expand Up @@ -250,61 +252,61 @@ def _query_aggregations(self):
if scan_column.is_missing_enabled:
metric_indices['non_missing'] = len(measurements)
if scan_column.non_missing_condition:
fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_condition))
fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_condition, column_name, ))
else:
fields.append(dialect.sql_expr_count(scan_column.qualified_column_name))
measurements.append(Measurement(Metric.VALUES_COUNT, column_name))

if scan_column.is_valid_enabled:
metric_indices['valid'] = len(measurements)
if scan_column.non_missing_and_valid_condition:
fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition))
fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition, column_name))
else:
fields.append(dialect.sql_expr_count(scan_column.qualified_column_name))
measurements.append(Measurement(Metric.VALID_COUNT, column_name))

if scan_column.is_text:
length_expr = dialect.sql_expr_conditional(
scan_column.non_missing_and_valid_condition,
dialect.sql_expr_length(scan_column.qualified_column_name)) \
dialect.sql_expr_length(scan_column.qualified_column_name, column_name)) \
if scan_column.non_missing_and_valid_condition \
else dialect.sql_expr_length(scan_column.qualified_column_name)

if self.scan_yml.is_metric_enabled(Metric.AVG_LENGTH, column_name):
fields.append(dialect.sql_expr_avg(length_expr))
fields.append(dialect.sql_expr_avg(length_expr, column_name))
measurements.append(Measurement(Metric.AVG_LENGTH, column_name))

if self.scan_yml.is_metric_enabled(Metric.MIN_LENGTH, column_name):
fields.append(dialect.sql_expr_min(length_expr))
fields.append(dialect.sql_expr_min(length_expr, column_name))
measurements.append(Measurement(Metric.MIN_LENGTH, column_name))

if self.scan_yml.is_metric_enabled(Metric.MAX_LENGTH, column_name):
fields.append(dialect.sql_expr_max(length_expr))
fields.append(dialect.sql_expr_max(length_expr, column_name))
measurements.append(Measurement(Metric.MAX_LENGTH, column_name))

if scan_column.is_numeric:
if scan_column.is_metric_enabled(Metric.MIN):
fields.append(dialect.sql_expr_min(scan_column.numeric_expr))
fields.append(dialect.sql_expr_min(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.MIN, column_name))

if scan_column.is_metric_enabled(Metric.MAX):
fields.append(dialect.sql_expr_max(scan_column.numeric_expr))
fields.append(dialect.sql_expr_max(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.MAX, column_name))

if scan_column.is_metric_enabled(Metric.AVG):
fields.append(dialect.sql_expr_avg(scan_column.numeric_expr))
fields.append(dialect.sql_expr_avg(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.AVG, column_name))

if scan_column.is_metric_enabled(Metric.SUM):
fields.append(dialect.sql_expr_sum(scan_column.numeric_expr))
fields.append(dialect.sql_expr_sum(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.SUM, column_name))

if scan_column.is_metric_enabled(Metric.VARIANCE):
fields.append(dialect.sql_expr_variance(scan_column.numeric_expr))
fields.append(dialect.sql_expr_variance(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.VARIANCE, column_name))

if scan_column.is_metric_enabled(Metric.STDDEV):
fields.append(dialect.sql_expr_stddev(scan_column.numeric_expr))
fields.append(dialect.sql_expr_stddev(scan_column.numeric_expr, column_name))
measurements.append(Measurement(Metric.STDDEV, column_name))

if len(fields) > 0:
Expand Down Expand Up @@ -641,7 +643,10 @@ def _send_failed_rows_custom_metric(self,
failed_limit = 5
if self.scan_yml.samples_yml is not None:
failed_limit = self.scan_yml.samples_yml.failed_limit or 5
sql += f'\nLIMIT {failed_limit}'
if self.warehouse.dialect.type == 'sqlserver':
sql += f' ORDER BY 1 OFFSET 0 ROWS FETCH NEXT {failed_limit} ROWS ONLY'
else:
sql += f'\nLIMIT {failed_limit}'

stored_failed_rows, sample_columns, total_failed_rows = \
self.sampler.save_sample_to_local_file_with_limit(sql, temp_file, failed_limit)
Expand Down
12 changes: 6 additions & 6 deletions core/sodasql/soda_server_client/monitor_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,22 @@ def build_sql(self,
elif scan_column:
select_fields.append(scan_column.qualified_column_name)
if self.metric_type == MonitorMetricType.MISSING_VALUES_COUNT:
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.missing_condition))
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.missing_condition, scan_column.column_name))
elif self.metric_type == MonitorMetricType.MISSING_VALUES_PERCENTAGE:
metric_select_fields.append(dialect.sql_expr_count_all())
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.missing_condition))
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.missing_condition, scan_column.column_name))
elif self.metric_type == MonitorMetricType.VALID_VALUES_COUNT:
metric_select_fields.append(
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition))
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition, scan_column.column_name))
elif self.metric_type in [MonitorMetricType.VALID_VALUES_PERCENTAGE,
MonitorMetricType.INVALID_VALUES_COUNT,
MonitorMetricType.INVALID_VALUES_PERCENTAGE]:
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_condition))
metric_select_fields.append(dialect.sql_expr_count_conditional(scan_column.non_missing_condition, scan_column.column_name))
metric_select_fields.append(
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition))
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition, scan_column.column_name))
elif self.metric_type == MonitorMetricType.UNIQUENESS_PERCENTAGE:
metric_select_fields.append(
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition))
dialect.sql_expr_count_conditional(scan_column.non_missing_and_valid_condition, scan_column.column_name))
metric_select_fields.append(dialect.sql_expr_count(dialect.sql_expr_distinct(
dialect.sql_expr_conditional(scan_column.non_missing_and_valid_condition,
scan_column.qualified_column_name))))
Expand Down
4 changes: 2 additions & 2 deletions packages/athena/sodasql/dialects/athena_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def qualify_table_name(self, table_name: str) -> str:
def qualify_writable_table_name(self, table_name: str) -> str:
    """Qualify *table_name* with the configured database, backtick-quoted."""
    database_part = '`{}`'.format(self.database)
    table_part = '`{}`'.format(table_name)
    return database_part + '.' + table_part

def sql_expr_avg(self, expr: str):
def sql_expr_avg(self, expr: str, column_name):
return f"AVG(CAST({expr} as DECIMAL(38, 0)))"

def sql_expr_sum(self, expr: str):
def sql_expr_sum(self, expr: str, column_name):
return f"SUM(CAST({expr} as DECIMAL(38, 0)))"

def literal_date(self, date: date):
Expand Down
2 changes: 1 addition & 1 deletion packages/hive/sodasql/dialects/hive_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def qualify_writable_table_name(self, table_name: str) -> str:
def sql_expr_regexp_like(self, expr: str, pattern: str):
    """Build a Hive RLIKE predicate matching *expr* against *pattern*."""
    qualified_pattern = self.qualify_regex(pattern)
    return "cast({} as string) rlike '{}'".format(expr, qualified_pattern)

def sql_expr_stddev(self, expr: str):
def sql_expr_stddev(self, expr: str, column_name):
return f'STDDEV_POP({expr})'

def qualify_regex(self, regex) -> str:
Expand Down
4 changes: 2 additions & 2 deletions packages/redshift/sodasql/dialects/redshift_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ def is_time(self, column_type: str):
def qualify_regex(self, regex):
    """Prepare *regex* for embedding in SQL by escaping metacharacters.

    NOTE(review): delegates to escape_metacharacters (defined elsewhere);
    the exact escaping rules are not visible in this hunk.
    """
    return self.escape_metacharacters(regex)

def sql_expr_avg(self, expr: str):
def sql_expr_avg(self, expr: str, column_name):
return f"AVG({expr})"

def sql_expr_sum(self, expr: str):
def sql_expr_sum(self, expr: str, column_name):
return f"SUM({expr})"

def sql_expr_cast_text_to_number(self, quoted_column_name, validity_format):
Expand Down
2 changes: 1 addition & 1 deletion packages/spark/sodasql/dialects/spark_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def qualify_writable_table_name(self, table_name: str) -> str:
def sql_expr_regexp_like(self, expr: str, pattern: str):
    """Return a Spark SQL rlike predicate testing *expr* against *pattern*."""
    return "cast({e} as string) rlike '{p}'".format(
        e=expr, p=self.qualify_regex(pattern))

def sql_expr_stddev(self, expr: str):
def sql_expr_stddev(self, expr: str, column_name):
return f'STDDEV_POP({expr})'

def qualify_regex(self, regex) -> str:
Expand Down
Loading

0 comments on commit 5b40926

Please sign in to comment.