Skip to content

Commit 54e4ea0

Browse files
authored
Sqlalchemy 2.0 (#124)
* Wrap text in sa.text.
* Adapt to new select.
* Change from connect to begin.
* Adapt another select statement.
* Add changelog entry.
1 parent 1cccec5 commit 54e4ea0

File tree

4 files changed

+85
-87
lines changed

4 files changed

+85
-87
lines changed

CHANGELOG.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ Changelog
88
=========
99

1010

11+
1.6.0 - 2022.04.12
12+
------------------
13+
14+
**Other changes**
15+
16+
- Ensure compatibility with ``sqlalchemy`` >= 2.0.
17+
18+
1119
1.5.0 - 2022.03.14
1220
------------------
1321

@@ -51,7 +59,7 @@ Changelog
5159
- Implemented in-database regex matching for some dialects via ``computation_in_db`` parameter in :meth:`~datajudge.WithinRequirement.add_varchar_regex_constraint`.
5260
- Added support for BigQuery backends.
5361

54-
**Bug fix:**
62+
**Bug fix**
5563

5664
- Snowflake-sqlalchemy version 1.4.0 introduced an unexpected change in behaviour. This problem is resolved by pinning it to the previous version, 1.3.4.
5765

src/datajudge/db_access.py

Lines changed: 57 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __init__(self, query_string: str, name: str, columns: list[str] = None):
278278
self.clause = subquery
279279
else:
280280
wrapped_query = f"({query_string}) as t"
281-
self.clause = sa.select(["*"]).select_from(sa.text(wrapped_query)).alias()
281+
self.clause = sa.select("*").select_from(sa.text(wrapped_query)).alias()
282282

283283
def __str__(self) -> str:
284284
return self.name
@@ -308,10 +308,10 @@ def get_selection(self, engine: sa.engine.Engine):
308308
clause = self.data_source.get_clause(engine)
309309
if self.columns:
310310
selection = sa.select(
311-
[clause.c[column_name] for column_name in self.get_columns(engine)]
311+
*[clause.c[column_name] for column_name in self.get_columns(engine)]
312312
)
313313
else:
314-
selection = sa.select([clause])
314+
selection = sa.select(clause)
315315
if self.condition is not None:
316316
text = str(self.condition)
317317
if is_snowflake(engine):
@@ -386,7 +386,7 @@ def get_date_span(engine, ref, date_column_name):
386386
column = subquery.c[date_column_name]
387387
if is_postgresql(engine):
388388
selection = sa.select(
389-
[
389+
*[
390390
sa.sql.extract(
391391
"day",
392392
(
@@ -398,7 +398,7 @@ def get_date_span(engine, ref, date_column_name):
398398
)
399399
elif is_mssql(engine) or is_snowflake(engine):
400400
selection = sa.select(
401-
[
401+
*[
402402
sa.func.datediff(
403403
sa.text("day"),
404404
sa.func.min(column),
@@ -408,7 +408,7 @@ def get_date_span(engine, ref, date_column_name):
408408
)
409409
elif is_bigquery(engine):
410410
selection = sa.select(
411-
[
411+
*[
412412
sa.func.date_diff(
413413
sa.func.max(column),
414414
sa.func.min(column),
@@ -418,7 +418,7 @@ def get_date_span(engine, ref, date_column_name):
418418
)
419419
elif is_impala(engine):
420420
selection = sa.select(
421-
[
421+
*[
422422
sa.func.datediff(
423423
sa.func.to_date(sa.func.max(column)),
424424
sa.func.to_date(sa.func.min(column)),
@@ -427,7 +427,7 @@ def get_date_span(engine, ref, date_column_name):
427427
)
428428
elif is_db2(engine):
429429
selection = sa.select(
430-
[
430+
*[
431431
sa.func.days_between(
432432
sa.func.max(column),
433433
sa.func.min(column),
@@ -504,17 +504,17 @@ def get_date_overlaps_nd(
504504

505505
join_condition = sa.and_(*key_conditions, violation_condition)
506506
violation_selection = sa.select(
507-
table_key_columns
508-
+ [
507+
*table_key_columns,
508+
*[
509509
table.c[start_column]
510510
for table in [table1, table2]
511511
for start_column in start_columns
512-
]
513-
+ [
512+
],
513+
*[
514514
table.c[end_column]
515515
for table in [table1, table2]
516516
for end_column in end_columns
517-
]
517+
],
518518
).select_from(table1.join(table2, join_condition))
519519

520520
# Note, Kevin, 21/12/09
@@ -539,7 +539,7 @@ def get_date_overlaps_nd(
539539
if key_columns
540540
else violation_subquery.columns
541541
)
542-
violation_subquery = sa.select(keys, group_by=keys).subquery()
542+
violation_subquery = sa.select(*keys).group_by(*keys).subquery()
543543

544544
n_violations_selection = sa.select(sa.func.count()).select_from(violation_subquery)
545545

@@ -624,13 +624,13 @@ def get_date_gaps(
624624
)
625625

626626
start_table = (
627-
sa.select([*raw_start_table.columns, start_rank_column])
627+
sa.select(*raw_start_table.columns, start_rank_column)
628628
.where(start_not_in_other_interval_condition)
629629
.subquery()
630630
)
631631

632632
end_table = (
633-
sa.select([*raw_end_table.columns, end_rank_column])
633+
sa.select(*raw_end_table.columns, end_rank_column)
634634
.where(end_not_in_other_interval_condition)
635635
.subquery()
636636
)
@@ -697,20 +697,18 @@ def get_date_gaps(
697697
)
698698

699699
violation_selection = sa.select(
700-
[
701-
*get_table_columns(start_table, key_columns),
702-
start_table.c[start_column],
703-
end_table.c[end_column],
704-
]
700+
*get_table_columns(start_table, key_columns),
701+
start_table.c[start_column],
702+
end_table.c[end_column],
705703
).select_from(start_table.join(end_table, join_condition))
706704

707705
violation_subquery = violation_selection.subquery()
708706

709707
keys = get_table_columns(violation_subquery, key_columns)
710708

711-
grouped_violation_subquery = sa.select(keys, group_by=keys).subquery()
709+
grouped_violation_subquery = sa.select(*keys).group_by(*keys).subquery()
712710

713-
n_violations_selection = sa.select([sa.func.count()]).select_from(
711+
n_violations_selection = sa.select(sa.func.count()).select_from(
714712
grouped_violation_subquery
715713
)
716714

@@ -726,9 +724,7 @@ def get_row_count(engine, ref, row_limit: int = None):
726724
if row_limit:
727725
subquery = subquery.limit(row_limit)
728726
subquery = subquery.alias()
729-
selection = sa.select([sa.cast(sa.func.count(), sa.BigInteger)]).select_from(
730-
subquery
731-
)
727+
selection = sa.select(sa.cast(sa.func.count(), sa.BigInteger)).select_from(subquery)
732728
result = engine.connect().execute(selection).scalar()
733729
return result, [selection]
734730

@@ -748,11 +744,11 @@ def get_column(
748744
column = subquery.c[ref.get_column(engine)]
749745

750746
if not aggregate_operator:
751-
selection = sa.select([column])
747+
selection = sa.select(column)
752748
result = engine.connect().execute(selection).scalars().all()
753749

754750
else:
755-
selection = sa.select([aggregate_operator(column)])
751+
selection = sa.select(aggregate_operator(column))
756752
result = engine.connect().execute(selection).scalar()
757753

758754
return result, [selection]
@@ -784,18 +780,16 @@ def get_percentile(engine, ref, percentage):
784780
column = ref.get_selection(engine).subquery().c[column_name]
785781
subquery = (
786782
sa.select(
787-
[
788-
column,
789-
sa.func.row_number().over(order_by=column).label(row_num),
790-
sa.func.count().over(partition_by=None).label(row_count),
791-
]
783+
column,
784+
sa.func.row_number().over(order_by=column).label(row_num),
785+
sa.func.count().over(partition_by=None).label(row_count),
792786
)
793787
.where(column.is_not(None))
794788
.subquery()
795789
)
796790

797791
constrained_selection = (
798-
sa.select(subquery.columns)
792+
sa.select(*subquery.columns)
799793
.where(subquery.c[row_num] * 100.0 / subquery.c[row_count] <= percentage)
800794
.subquery()
801795
)
@@ -850,7 +844,7 @@ def get_uniques(
850844
return Counter({}), []
851845
selection = ref.get_selection(engine).alias()
852846
columns = [selection.c[column_name] for column_name in ref.get_columns(engine)]
853-
selection = sa.select([*columns, sa.func.count()], group_by=columns)
847+
selection = sa.select(*columns, sa.func.count()).group_by(*columns)
854848

855849
def _scalar_accessor(row):
856850
return row[0]
@@ -875,7 +869,7 @@ def _tuple_accessor(row):
875869
def get_unique_count(engine, ref):
876870
selection = ref.get_selection(engine)
877871
subquery = selection.distinct().alias()
878-
selection = sa.select([sa.func.count()]).select_from(subquery)
872+
selection = sa.select(sa.func.count()).select_from(subquery)
879873
result = engine.connect().execute(selection).scalar()
880874
return result, [selection]
881875

@@ -884,16 +878,16 @@ def get_unique_count_union(engine, ref, ref2):
884878
selection1 = ref.get_selection(engine)
885879
selection2 = ref2.get_selection(engine)
886880
subquery = sa.sql.union(selection1, selection2).alias().select().distinct().alias()
887-
selection = sa.select([sa.func.count()]).select_from(subquery)
881+
selection = sa.select(sa.func.count()).select_from(subquery)
888882
result = engine.connect().execute(selection).scalar()
889883
return result, [selection]
890884

891885

892886
def get_missing_fraction(engine, ref):
893887
selection = ref.get_selection(engine).subquery()
894-
n_rows_total_selection = sa.select([sa.func.count()]).select_from(selection)
888+
n_rows_total_selection = sa.select(sa.func.count()).select_from(selection)
895889
n_rows_missing_selection = (
896-
sa.select([sa.func.count()])
890+
sa.select(sa.func.count())
897891
.select_from(selection)
898892
.where(selection.c[ref.get_column(engine)].is_(None))
899893
)
@@ -947,7 +941,7 @@ def get_row_difference_count(engine, ref, ref2):
947941
subquery = (
948942
sa.sql.except_(selection1, selection2).alias().select().distinct().alias()
949943
)
950-
selection = sa.select([sa.func.count()]).select_from(subquery)
944+
selection = sa.select(sa.func.count()).select_from(subquery)
951945
result = engine.connect().execute(selection).scalar()
952946
return result, [selection]
953947

@@ -982,9 +976,9 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare):
982976
]
983977
)
984978

985-
avg_match_column = sa.func.avg(sa.case([(compare, 0.0)], else_=1.0))
979+
avg_match_column = sa.func.avg(sa.case((compare, 0.0), else_=1.0))
986980

987-
selection_difference = sa.select([avg_match_column]).select_from(
981+
selection_difference = sa.select(avg_match_column).select_from(
988982
subselection1.join(subselection2, match)
989983
)
990984
selection_n_rows = sa.select(sa.func.count()).select_from(
@@ -998,14 +992,14 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare):
998992
def get_duplicate_sample(engine, ref):
999993
initial_selection = ref.get_selection(engine).alias()
1000994
aggregate_subquery = (
1001-
sa.select([initial_selection, sa.func.count().label("n_copies")])
995+
sa.select(initial_selection, sa.func.count().label("n_copies"))
1002996
.select_from(initial_selection)
1003997
.group_by(*initial_selection.columns)
1004998
.alias()
1005999
)
10061000
duplicate_selection = (
10071001
sa.select(
1008-
[
1002+
*[
10091003
column
10101004
for column in aggregate_subquery.columns
10111005
if column.key != "n_copies"
@@ -1027,8 +1021,8 @@ def column_array_agg_query(
10271021
raise ValueError("There must be a column to group by")
10281022
group_columns = [clause.c[column] for column in ref.get_columns(engine)]
10291023
agg_column = clause.c[aggregation_column]
1030-
selection = sa.select(
1031-
[*group_columns, sa.func.array_agg(agg_column)], group_by=[*group_columns]
1024+
selection = sa.select(*group_columns, sa.func.array_agg(agg_column)).group_by(
1025+
*group_columns
10321026
)
10331027
return [selection]
10341028

@@ -1064,19 +1058,15 @@ def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str)
10641058

10651059
# Step 1: Calculate the CDF over the value column.
10661060
cdf_selection = sa.select(
1067-
[
1068-
selection.c[col].label(value_label),
1069-
sa.func.cume_dist().over(order_by=col).label(cdf_label),
1070-
]
1061+
selection.c[col].label(value_label),
1062+
sa.func.cume_dist().over(order_by=col).label(cdf_label),
10711063
).subquery()
10721064

10731065
# Step 2: Aggregate rows s.t. every value occurs only once.
10741066
grouped_cdf_selection = (
10751067
sa.select(
1076-
[
1077-
cdf_selection.c[value_label],
1078-
sa.func.max(cdf_selection.c[cdf_label]).label(cdf_label),
1079-
]
1068+
cdf_selection.c[value_label],
1069+
sa.func.max(cdf_selection.c[cdf_label]).label(cdf_label),
10801070
)
10811071
.group_by(cdf_selection.c[value_label])
10821072
.subquery()
@@ -1136,13 +1126,11 @@ def _cdf_index_column(table, value_label, cdf_label, group_label):
11361126
# In other words, we point rows to their most recent present value - per sample. This is necessary
11371127
# due to the nature of the full outer join.
11381128
indexed_cross_cdf = sa.select(
1139-
[
1140-
cross_cdf.c[value_label],
1141-
_cdf_index_column(cross_cdf, value_label, cdf_label1, group_label1),
1142-
cross_cdf.c[cdf_label1],
1143-
_cdf_index_column(cross_cdf, value_label, cdf_label2, group_label2),
1144-
cross_cdf.c[cdf_label2],
1145-
]
1129+
cross_cdf.c[value_label],
1130+
_cdf_index_column(cross_cdf, value_label, cdf_label1, group_label1),
1131+
cross_cdf.c[cdf_label1],
1132+
_cdf_index_column(cross_cdf, value_label, cdf_label2, group_label2),
1133+
cross_cdf.c[cdf_label2],
11461134
).subquery()
11471135

11481136
def _forward_filled_cdf_column(table, cdf_label, value_label, group_label):
@@ -1160,15 +1148,13 @@ def _forward_filled_cdf_column(table, cdf_label, value_label, group_label):
11601148
)
11611149

11621150
filled_cross_cdf = sa.select(
1163-
[
1164-
indexed_cross_cdf.c[value_label],
1165-
_forward_filled_cdf_column(
1166-
indexed_cross_cdf, cdf_label1, value_label, group_label1
1167-
),
1168-
_forward_filled_cdf_column(
1169-
indexed_cross_cdf, cdf_label2, value_label, group_label2
1170-
),
1171-
]
1151+
indexed_cross_cdf.c[value_label],
1152+
_forward_filled_cdf_column(
1153+
indexed_cross_cdf, cdf_label1, value_label, group_label1
1154+
),
1155+
_forward_filled_cdf_column(
1156+
indexed_cross_cdf, cdf_label2, value_label, group_label2
1157+
),
11721158
)
11731159
return filled_cross_cdf, cdf_label1, cdf_label2
11741160

@@ -1219,7 +1205,7 @@ def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples):
12191205
violation_selection = sa.select(subquery.c[column]).where(
12201206
sa.not_(subquery.c[column].regexp_match(regex))
12211207
)
1222-
n_violations_selection = sa.select([sa.func.count()]).select_from(
1208+
n_violations_selection = sa.select(sa.func.count()).select_from(
12231209
violation_selection.subquery()
12241210
)
12251211

0 commit comments

Comments
 (0)