@@ -278,7 +278,7 @@ def __init__(self, query_string: str, name: str, columns: list[str] = None):
            self.clause = subquery
        else:
            wrapped_query = f"({query_string}) as t"
-            self.clause = sa.select(["*"]).select_from(sa.text(wrapped_query)).alias()
+            self.clause = sa.select("*").select_from(sa.text(wrapped_query)).alias()

    def __str__(self) -> str:
        return self.name
@@ -308,10 +308,10 @@ def get_selection(self, engine: sa.engine.Engine):
        clause = self.data_source.get_clause(engine)
        if self.columns:
            selection = sa.select(
-                [clause.c[column_name] for column_name in self.get_columns(engine)]
+                *[clause.c[column_name] for column_name in self.get_columns(engine)]
            )
        else:
-            selection = sa.select([clause])
+            selection = sa.select(clause)
        if self.condition is not None:
            text = str(self.condition)
            if is_snowflake(engine):
@@ -386,7 +386,7 @@ def get_date_span(engine, ref, date_column_name):
    column = subquery.c[date_column_name]
    if is_postgresql(engine):
        selection = sa.select(
-            [
+            *[
                sa.sql.extract(
                    "day",
                    (
@@ -398,7 +398,7 @@ def get_date_span(engine, ref, date_column_name):
        )
    elif is_mssql(engine) or is_snowflake(engine):
        selection = sa.select(
-            [
+            *[
                sa.func.datediff(
                    sa.text("day"),
                    sa.func.min(column),
@@ -408,7 +408,7 @@ def get_date_span(engine, ref, date_column_name):
        )
    elif is_bigquery(engine):
        selection = sa.select(
-            [
+            *[
                sa.func.date_diff(
                    sa.func.max(column),
                    sa.func.min(column),
@@ -418,7 +418,7 @@ def get_date_span(engine, ref, date_column_name):
        )
    elif is_impala(engine):
        selection = sa.select(
-            [
+            *[
                sa.func.datediff(
                    sa.func.to_date(sa.func.max(column)),
                    sa.func.to_date(sa.func.min(column)),
@@ -427,7 +427,7 @@ def get_date_span(engine, ref, date_column_name):
        )
    elif is_db2(engine):
        selection = sa.select(
-            [
+            *[
                sa.func.days_between(
                    sa.func.max(column),
                    sa.func.min(column),
@@ -504,17 +504,17 @@ def get_date_overlaps_nd(

    join_condition = sa.and_(*key_conditions, violation_condition)
    violation_selection = sa.select(
-        table_key_columns
-        + [
+        *table_key_columns,
+        *[
            table.c[start_column]
            for table in [table1, table2]
            for start_column in start_columns
-        ]
-        + [
+        ],
+        *[
            table.c[end_column]
            for table in [table1, table2]
            for end_column in end_columns
-        ]
+        ],
    ).select_from(table1.join(table2, join_condition))

    # Note, Kevin, 21/12/09
@@ -539,7 +539,7 @@ def get_date_overlaps_nd(
        if key_columns
        else violation_subquery.columns
    )
-    violation_subquery = sa.select(keys, group_by=keys).subquery()
+    violation_subquery = sa.select(*keys).group_by(*keys).subquery()

    n_violations_selection = sa.select(sa.func.count()).select_from(violation_subquery)

@@ -624,13 +624,13 @@ def get_date_gaps(
    )

    start_table = (
-        sa.select([*raw_start_table.columns, start_rank_column])
+        sa.select(*raw_start_table.columns, start_rank_column)
        .where(start_not_in_other_interval_condition)
        .subquery()
    )

    end_table = (
-        sa.select([*raw_end_table.columns, end_rank_column])
+        sa.select(*raw_end_table.columns, end_rank_column)
        .where(end_not_in_other_interval_condition)
        .subquery()
    )
@@ -697,20 +697,18 @@ def get_date_gaps(
    )

    violation_selection = sa.select(
-        [
-            *get_table_columns(start_table, key_columns),
-            start_table.c[start_column],
-            end_table.c[end_column],
-        ]
+        *get_table_columns(start_table, key_columns),
+        start_table.c[start_column],
+        end_table.c[end_column],
    ).select_from(start_table.join(end_table, join_condition))

    violation_subquery = violation_selection.subquery()

    keys = get_table_columns(violation_subquery, key_columns)

-    grouped_violation_subquery = sa.select(keys, group_by=keys).subquery()
+    grouped_violation_subquery = sa.select(*keys).group_by(*keys).subquery()

-    n_violations_selection = sa.select([sa.func.count()]).select_from(
+    n_violations_selection = sa.select(sa.func.count()).select_from(
        grouped_violation_subquery
    )

@@ -726,9 +724,7 @@ def get_row_count(engine, ref, row_limit: int = None):
    if row_limit:
        subquery = subquery.limit(row_limit)
    subquery = subquery.alias()
-    selection = sa.select([sa.cast(sa.func.count(), sa.BigInteger)]).select_from(
-        subquery
-    )
+    selection = sa.select(sa.cast(sa.func.count(), sa.BigInteger)).select_from(subquery)
    result = engine.connect().execute(selection).scalar()
    return result, [selection]

@@ -748,11 +744,11 @@ def get_column(
    column = subquery.c[ref.get_column(engine)]

    if not aggregate_operator:
-        selection = sa.select([column])
+        selection = sa.select(column)
        result = engine.connect().execute(selection).scalars().all()

    else:
-        selection = sa.select([aggregate_operator(column)])
+        selection = sa.select(aggregate_operator(column))
        result = engine.connect().execute(selection).scalar()

    return result, [selection]
@@ -784,18 +780,16 @@ def get_percentile(engine, ref, percentage):
    column = ref.get_selection(engine).subquery().c[column_name]
    subquery = (
        sa.select(
-            [
-                column,
-                sa.func.row_number().over(order_by=column).label(row_num),
-                sa.func.count().over(partition_by=None).label(row_count),
-            ]
+            column,
+            sa.func.row_number().over(order_by=column).label(row_num),
+            sa.func.count().over(partition_by=None).label(row_count),
        )
        .where(column.is_not(None))
        .subquery()
    )

    constrained_selection = (
-        sa.select(subquery.columns)
+        sa.select(*subquery.columns)
        .where(subquery.c[row_num] * 100.0 / subquery.c[row_count] <= percentage)
        .subquery()
    )
@@ -850,7 +844,7 @@ def get_uniques(
        return Counter({}), []
    selection = ref.get_selection(engine).alias()
    columns = [selection.c[column_name] for column_name in ref.get_columns(engine)]
-    selection = sa.select([*columns, sa.func.count()], group_by=columns)
+    selection = sa.select(*columns, sa.func.count()).group_by(*columns)

    def _scalar_accessor(row):
        return row[0]
@@ -875,7 +869,7 @@ def _tuple_accessor(row):
def get_unique_count(engine, ref):
    selection = ref.get_selection(engine)
    subquery = selection.distinct().alias()
-    selection = sa.select([sa.func.count()]).select_from(subquery)
+    selection = sa.select(sa.func.count()).select_from(subquery)
    result = engine.connect().execute(selection).scalar()
    return result, [selection]

@@ -884,16 +878,16 @@ def get_unique_count_union(engine, ref, ref2):
    selection1 = ref.get_selection(engine)
    selection2 = ref2.get_selection(engine)
    subquery = sa.sql.union(selection1, selection2).alias().select().distinct().alias()
-    selection = sa.select([sa.func.count()]).select_from(subquery)
+    selection = sa.select(sa.func.count()).select_from(subquery)
    result = engine.connect().execute(selection).scalar()
    return result, [selection]


def get_missing_fraction(engine, ref):
    selection = ref.get_selection(engine).subquery()
-    n_rows_total_selection = sa.select([sa.func.count()]).select_from(selection)
+    n_rows_total_selection = sa.select(sa.func.count()).select_from(selection)
    n_rows_missing_selection = (
-        sa.select([sa.func.count()])
+        sa.select(sa.func.count())
        .select_from(selection)
        .where(selection.c[ref.get_column(engine)].is_(None))
    )
@@ -947,7 +941,7 @@ def get_row_difference_count(engine, ref, ref2):
    subquery = (
        sa.sql.except_(selection1, selection2).alias().select().distinct().alias()
    )
-    selection = sa.select([sa.func.count()]).select_from(subquery)
+    selection = sa.select(sa.func.count()).select_from(subquery)
    result = engine.connect().execute(selection).scalar()
    return result, [selection]

@@ -982,9 +976,9 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare):
        ]
    )

-    avg_match_column = sa.func.avg(sa.case([(compare, 0.0)], else_=1.0))
+    avg_match_column = sa.func.avg(sa.case((compare, 0.0), else_=1.0))

-    selection_difference = sa.select([avg_match_column]).select_from(
+    selection_difference = sa.select(avg_match_column).select_from(
        subselection1.join(subselection2, match)
    )
    selection_n_rows = sa.select(sa.func.count()).select_from(
@@ -998,14 +992,14 @@ def get_duplicate_sample(engine, ref):
def get_duplicate_sample(engine, ref):
    initial_selection = ref.get_selection(engine).alias()
    aggregate_subquery = (
-        sa.select([initial_selection, sa.func.count().label("n_copies")])
+        sa.select(initial_selection, sa.func.count().label("n_copies"))
        .select_from(initial_selection)
        .group_by(*initial_selection.columns)
        .alias()
    )
    duplicate_selection = (
        sa.select(
-            [
+            *[
                column
                for column in aggregate_subquery.columns
                if column.key != "n_copies"
@@ -1027,8 +1021,8 @@ def column_array_agg_query(
        raise ValueError("There must be a column to group by")
    group_columns = [clause.c[column] for column in ref.get_columns(engine)]
    agg_column = clause.c[aggregation_column]
-    selection = sa.select(
-        [*group_columns, sa.func.array_agg(agg_column)], group_by=[*group_columns]
+    selection = sa.select(*group_columns, sa.func.array_agg(agg_column)).group_by(
+        *group_columns
    )
    return [selection]

@@ -1064,19 +1058,15 @@ def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str)

    # Step 1: Calculate the CDF over the value column.
    cdf_selection = sa.select(
-        [
-            selection.c[col].label(value_label),
-            sa.func.cume_dist().over(order_by=col).label(cdf_label),
-        ]
+        selection.c[col].label(value_label),
+        sa.func.cume_dist().over(order_by=col).label(cdf_label),
    ).subquery()

    # Step 2: Aggregate rows s.t. every value occurs only once.
    grouped_cdf_selection = (
        sa.select(
-            [
-                cdf_selection.c[value_label],
-                sa.func.max(cdf_selection.c[cdf_label]).label(cdf_label),
-            ]
+            cdf_selection.c[value_label],
+            sa.func.max(cdf_selection.c[cdf_label]).label(cdf_label),
        )
        .group_by(cdf_selection.c[value_label])
        .subquery()
1136
1126
# In other words, we point rows to their most recent present value - per sample. This is necessary
1137
1127
# Due to the nature of the full outer join.
1138
1128
indexed_cross_cdf = sa .select (
1139
- [
1140
- cross_cdf .c [value_label ],
1141
- _cdf_index_column (cross_cdf , value_label , cdf_label1 , group_label1 ),
1142
- cross_cdf .c [cdf_label1 ],
1143
- _cdf_index_column (cross_cdf , value_label , cdf_label2 , group_label2 ),
1144
- cross_cdf .c [cdf_label2 ],
1145
- ]
1129
+ cross_cdf .c [value_label ],
1130
+ _cdf_index_column (cross_cdf , value_label , cdf_label1 , group_label1 ),
1131
+ cross_cdf .c [cdf_label1 ],
1132
+ _cdf_index_column (cross_cdf , value_label , cdf_label2 , group_label2 ),
1133
+ cross_cdf .c [cdf_label2 ],
1146
1134
).subquery ()
1147
1135
1148
1136
def _forward_filled_cdf_column (table , cdf_label , value_label , group_label ):
@@ -1160,15 +1148,13 @@ def _forward_filled_cdf_column(table, cdf_label, value_label, group_label):
        )

    filled_cross_cdf = sa.select(
-        [
-            indexed_cross_cdf.c[value_label],
-            _forward_filled_cdf_column(
-                indexed_cross_cdf, cdf_label1, value_label, group_label1
-            ),
-            _forward_filled_cdf_column(
-                indexed_cross_cdf, cdf_label2, value_label, group_label2
-            ),
-        ]
+        indexed_cross_cdf.c[value_label],
+        _forward_filled_cdf_column(
+            indexed_cross_cdf, cdf_label1, value_label, group_label1
+        ),
+        _forward_filled_cdf_column(
+            indexed_cross_cdf, cdf_label2, value_label, group_label2
+        ),
    )
    return filled_cross_cdf, cdf_label1, cdf_label2

@@ -1219,7 +1205,7 @@ def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples):
    violation_selection = sa.select(subquery.c[column]).where(
        sa.not_(subquery.c[column].regexp_match(regex))
    )
-    n_violations_selection = sa.select([sa.func.count()]).select_from(
+    n_violations_selection = sa.select(sa.func.count()).select_from(
        violation_selection.subquery()
    )

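Every hunk above applies the same SQLAlchemy 1.4/2.0 calling convention: select() now takes column expressions as positional arguments rather than a single list, the legacy group_by= keyword becomes the generative .group_by() method, and case() takes its when-tuples positionally instead of inside a list. A minimal before/after sketch of that convention; the table t here is hypothetical and not part of this codebase:

import sqlalchemy as sa

metadata = sa.MetaData()
# Hypothetical table, used only to illustrate the call-style change.
t = sa.Table("t", metadata, sa.Column("a", sa.Integer), sa.Column("b", sa.Integer))

# Legacy 1.x style, removed in SQLAlchemy 2.0:
#   sa.select([t.c.a, sa.func.count()], group_by=[t.c.a])
# 1.4/2.0 style: positional column arguments plus the generative group_by() method.
stmt = sa.select(t.c.a, sa.func.count()).group_by(t.c.a)

# case() follows the same shift: when-tuples are passed positionally, not as a list.
avg_mismatch = sa.func.avg(sa.case((t.c.b > 0, 0.0), else_=1.0))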