Skip to content

Commit

Permalink
feat: Allow to compare queries in compare_table function (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero authored Jan 16, 2025
1 parent 412c1b7 commit 8c32375
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 33 deletions.
68 changes: 42 additions & 26 deletions sqlcompyre/analysis/table_comparison.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# Copyright (c) QuantCo 2024-2024
# Copyright (c) QuantCo 2024-2025
# SPDX-License-Identifier: BSD-3-Clause

import functools
import logging
from functools import cached_property
from typing import cast

import sqlalchemy as sa
import sqlalchemy.sql.functions as func
from sqlalchemy.engine import Engine
from sqlalchemy.sql import elements, expression, false, schema, select, selectable, true
from sqlalchemy.sql import elements, expression, false, select, selectable, true

from sqlcompyre.report import Report
from sqlcompyre.results import ColumnMatches, Counts, Names, RowMatches
Expand All @@ -26,8 +25,8 @@ class TableComparison:
def __init__(
self,
engine: Engine,
left_table: schema.Table,
right_table: schema.Table,
left_table: sa.FromClause,
right_table: sa.FromClause,
join_columns: list[str] | None,
column_name_mapping: dict[str, str] | None,
ignore_columns: list[str] | None,
Expand All @@ -54,8 +53,8 @@ def __init__(
infer_primary_keys: Whether to infer primary keys if none are available.
"""
self.engine = engine
self.left_table = cast(expression.Alias, left_table.alias("left"))
self.right_table = cast(expression.Alias, right_table.alias("right"))
self.left_table = left_table.alias("left")
self.right_table = right_table.alias("right")
self.column_name_mapping = _identity_column_mapping_if_needed(
left_table,
right_table,
Expand All @@ -79,8 +78,8 @@ def join_columns(self) -> list[str]:
"""The columns used for joining the two tables."""
pks = _join_columns_from_pk_if_needed(
self.engine,
cast(sa.Table, self.left_table.element),
cast(sa.Table, self.right_table.element),
self.left_table,
self.right_table,
self._user_join_columns,
ignore_casing=self.ignore_casing,
column_name_mapping=self.column_name_mapping,
Expand Down Expand Up @@ -323,9 +322,6 @@ def summary_report(self) -> Report:
Returns:
A report summarizing the comparison of the two tables.
"""
left_name = str(self.left_table.original)
right_name = str(self.right_table.original)

description = None
sections = {
"Column Names": self.column_names,
Expand All @@ -349,17 +345,35 @@ def summary_report(self) -> Report:
logging.warning(
"'%s' and '%s' cannot be matched (%s): dropping row and column matches "
"from the report",
left_name,
right_name,
self._left_table_name,
self._right_table_name,
exc,
)

return Report("tables", left_name, right_name, description, sections)
return Report(
"tables",
self._left_table_name,
self._right_table_name,
description,
sections,
)

# ---------------------------------------------------------------------------------------------
# UTILITY METHODS
# ---------------------------------------------------------------------------------------------

@property
def _left_table_name(self) -> str:
if isinstance(self.left_table, sa.Alias):
return str(self.left_table.element)
return "<left query>"

@property
def _right_table_name(self) -> str:
if isinstance(self.right_table, sa.Alias):
return str(self.right_table.element)
return "<right query>"

def _is_equal(
self, left_column: str, right_column: str
) -> elements.ColumnElement[bool]:
Expand Down Expand Up @@ -477,8 +491,8 @@ def __repr__(self):
def __str__(self):
return (
f"{self.__class__.__name__}("
f'left_table="{self.left_table.original}", '
f'right_table="{self.right_table.original}")'
f'left_table="{self._left_table_name}", '
f'right_table="{self._right_table_name}")'
)


Expand All @@ -489,8 +503,8 @@ def __str__(self):

def _join_columns_from_pk_if_needed(
engine: Engine,
left: sa.Table,
right: sa.Table,
left: sa.FromClause,
right: sa.FromClause,
join_columns: list[str],
ignore_casing: bool,
column_name_mapping: dict[str, str],
Expand All @@ -503,8 +517,8 @@ def _join_columns_from_pk_if_needed(
join_columns = [lowercase_map[c.lower()] for c in join_columns]

if not join_columns:
left_pks = {pk.name for pk in sa.inspect(left).primary_key}
right_pks = {pk.name for pk in sa.inspect(right).primary_key}
left_pks = {col.name for col in left.columns if col.primary_key}
right_pks = {col.name for col in right.columns if col.primary_key}
reverse_mapping = {v: k for k, v in column_name_mapping.items()}
if not (left_pks - set(column_name_mapping) | right_pks - set(reverse_mapping)):
# All primary keys can be matched
Expand Down Expand Up @@ -551,8 +565,8 @@ def _join_columns_from_pk_if_needed(

def _is_valid_primary_key_column(
engine: Engine,
left_table: sa.Table,
right_table: sa.Table,
left_table: sa.FromClause,
right_table: sa.FromClause,
left_column: str,
right_column: str,
) -> bool:
Expand All @@ -578,7 +592,9 @@ def _is_valid_primary_key_column(
return left_nulls == 0 and right_nulls == 0


def _is_valid_primary_key(engine: Engine, table: sa.Table, columns: list[str]) -> bool:
def _is_valid_primary_key(
engine: Engine, table: sa.FromClause, columns: list[str]
) -> bool:
with engine.connect() as conn:
result = conn.execute(
sa.select(*[table.c[c] for c in columns])
Expand All @@ -590,8 +606,8 @@ def _is_valid_primary_key(engine: Engine, table: sa.Table, columns: list[str]) -


def _identity_column_mapping_if_needed(
left: sa.schema.Table,
right: sa.schema.Table,
left: sa.FromClause,
right: sa.FromClause,
mapping: dict[str, str],
ignore_columns: list[str],
ignore_casing: bool,
Expand Down
14 changes: 7 additions & 7 deletions sqlcompyre/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2024-2024
# Copyright (c) QuantCo 2024-2025
# SPDX-License-Identifier: BSD-3-Clause

import sys
Expand Down Expand Up @@ -74,8 +74,8 @@ def inspect_table(engine: sa.Engine, table: sa.Table | str) -> QueryInspection:

def compare_tables(
engine: sa.Engine,
left: sa.Table | str,
right: sa.Table | str,
left: sa.Select | sa.FromClause | str,
right: sa.Select | sa.FromClause | str,
join_columns: list[str] | None = None,
ignore_columns: list[str] | None = None,
column_name_mapping: dict[str, str] | None = None,
Expand Down Expand Up @@ -118,8 +118,8 @@ def compare_tables(
A table comparison object that can be used to explore the differences in the tables.
"""
# Get the SQLAlchemy representation of the tables in the database
left_table: sa.Table
right_table: sa.Table
left_table: sa.FromClause
right_table: sa.FromClause
if isinstance(left, str) or isinstance(right, str):
meta = sa.MetaData()

Expand All @@ -134,9 +134,9 @@ def compare_tables(
right_table = meta.tables[right]

if not isinstance(left, str):
left_table = left
left_table = left.subquery() if isinstance(left, sa.Select) else left
if not isinstance(right, str):
right_table = right
right_table = right.subquery() if isinstance(right, sa.Select) else right

# Create a table comparison object
return TableComparison(
Expand Down
35 changes: 35 additions & 0 deletions tests/analysis/table_comparison/test_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) QuantCo 2024-2025
# SPDX-License-Identifier: BSD-3-Clause

import sqlalchemy as sa

import sqlcompyre as sc


def test_compare_queries_join_columns_inferred(
engine: sa.Engine, table_students: sa.Table
):
comparison = sc.compare_tables(
engine, sa.select(table_students), sa.select(table_students)
)
assert comparison.join_columns == ["id"]


def test_compare_queries_select(engine: sa.Engine, table_students: sa.Table):
comparison = sc.compare_tables(
engine,
sa.select(table_students).where(table_students.c["age"] >= 30),
sa.select(table_students).where(table_students.c["age"] >= 20),
)
assert comparison.row_counts.diff == 2
assert comparison.row_matches.n_joined_total == 2


def test_compare_queries_subquery(engine: sa.Engine, table_students: sa.Table):
comparison = sc.compare_tables(
engine,
sa.select(table_students).where(table_students.c["age"] >= 30).subquery(),
sa.select(table_students).where(table_students.c["age"] >= 20).subquery(),
)
assert comparison.row_counts.diff == 2
assert comparison.row_matches.n_joined_total == 2

0 comments on commit 8c32375

Please sign in to comment.