8
8
import sqlalchemy as sa
9
9
from impala .dbapi import connect
10
10
11
- from datajudge .db_access import apply_patches , is_bigquery , is_impala , is_mssql
11
+ from datajudge .db_access import apply_patches , is_bigquery , is_db2 , is_impala , is_mssql
12
12
13
13
TEST_DB_NAME = "tempdb"
14
14
SCHEMA = "dbo" # 'dbo' is the standard schema in mssql
@@ -30,6 +30,8 @@ def conn_creator():
30
30
31
31
if backend == "postgres" :
32
32
connection_string = f"postgresql://datajudge:datajudge@{ address } :5432/datajudge"
33
+ if backend == "db2" :
34
+ connection_string = f"db2+ibm_db://db2inst1:password@{ address } :50000/testdb"
33
35
elif "mssql" in backend :
34
36
connection_string = (
35
37
f"mssql+pyodbc://sa:datajudge-123@{ address } :1433/{ TEST_DB_NAME } "
@@ -56,6 +58,12 @@ def conn_creator():
56
58
return engine
57
59
58
60
61
+ def _string_column (engine ):
62
+ if is_db2 (engine ):
63
+ return sa .String (40 )
64
+ return sa .String ()
65
+
66
+
59
67
@pytest .fixture (scope = "module" )
60
68
def engine (backend ):
61
69
engine = get_engine (backend )
@@ -111,7 +119,7 @@ def mix_table1(engine, metadata):
111
119
table_name = "mix_table1"
112
120
columns = [
113
121
sa .Column ("col_int" , sa .Integer ()),
114
- sa .Column ("col_varchar" , sa . String ( )),
122
+ sa .Column ("col_varchar" , _string_column ( engine )),
115
123
sa .Column ("col_date" , sa .DateTime ()),
116
124
]
117
125
data = [
@@ -131,7 +139,7 @@ def mix_table2(engine, metadata):
131
139
table_name = "mix_table2"
132
140
columns = [
133
141
sa .Column ("col_int" , sa .Integer ()),
134
- sa .Column ("col_varchar" , sa . String ( )),
142
+ sa .Column ("col_varchar" , _string_column ( engine )),
135
143
sa .Column ("col_date" , sa .DateTime ()),
136
144
]
137
145
data = [
@@ -152,7 +160,7 @@ def mix_table2_pk(engine, metadata):
152
160
table_name = "mix_table2_pk"
153
161
columns = [
154
162
sa .Column ("col_int" , sa .Integer (), primary_key = True ),
155
- sa .Column ("col_varchar" , sa . String ( )),
163
+ sa .Column ("col_varchar" , _string_column ( engine )),
156
164
sa .Column ("col_date" , sa .DateTime ()),
157
165
]
158
166
data = [
@@ -477,7 +485,7 @@ def unique_table1(engine, metadata):
477
485
table_name = "unique_table1"
478
486
columns = [
479
487
sa .Column ("col_int" , sa .Integer ()),
480
- sa .Column ("col_varchar" , sa . String ( )),
488
+ sa .Column ("col_varchar" , _string_column ( engine )),
481
489
]
482
490
data = [{"col_int" : i // 2 , "col_varchar" : f"hi{ i // 3 } " } for i in range (60 )]
483
491
data += [
@@ -493,7 +501,7 @@ def unique_table2(engine, metadata):
493
501
table_name = "unique_table2"
494
502
columns = [
495
503
sa .Column ("col_int" , sa .Integer ()),
496
- sa .Column ("col_varchar" , sa . String ( )),
504
+ sa .Column ("col_varchar" , _string_column ( engine )),
497
505
]
498
506
data = [{"col_int" : i // 2 , "col_varchar" : f"hi{ i // 3 } " } for i in range (40 )]
499
507
_handle_table (engine , metadata , table_name , columns , data )
@@ -503,7 +511,7 @@ def unique_table2(engine, metadata):
503
511
@pytest .fixture (scope = "module" )
504
512
def nested_table (engine , metadata ):
505
513
table_name = "nested_table"
506
- columns = [sa .Column ("nested_varchar" , sa . String ( ))]
514
+ columns = [sa .Column ("nested_varchar" , _string_column ( engine ))]
507
515
data = [
508
516
{"nested_varchar" : "ABC#1," },
509
517
{"nested_varchar" : "ABC#1,DEF#2," },
@@ -517,7 +525,7 @@ def nested_table(engine, metadata):
517
525
def varchar_table1 (engine , metadata ):
518
526
table_name = "varchar_table1"
519
527
columns = [
520
- sa .Column ("col_varchar" , sa . String ( )),
528
+ sa .Column ("col_varchar" , _string_column ( engine )),
521
529
]
522
530
data = [{"col_varchar" : "qq" * i } for i in range (1 , 10 )]
523
531
data .append ({"col_varchar" : None })
@@ -529,7 +537,7 @@ def varchar_table1(engine, metadata):
529
537
def varchar_table2 (engine , metadata ):
530
538
table_name = "varchar_table2"
531
539
columns = [
532
- sa .Column ("col_varchar" , sa . String ( )),
540
+ sa .Column ("col_varchar" , _string_column ( engine )),
533
541
]
534
542
data = [{"col_varchar" : "qq" * i } for i in range (2 , 11 )]
535
543
_handle_table (engine , metadata , table_name , columns , data )
@@ -540,7 +548,7 @@ def varchar_table2(engine, metadata):
540
548
def varchar_table_real (engine , metadata ):
541
549
table_name = "varchar_table_real"
542
550
columns = [
543
- sa .Column ("col_varchar" , sa . String ( )),
551
+ sa .Column ("col_varchar" , _string_column ( engine )),
544
552
]
545
553
data = [
546
554
{"col_varchar" : val }
@@ -754,6 +762,10 @@ def capitalization_table(engine, metadata):
754
762
str_datatype = "STRING"
755
763
# Impala supports primary keys but uses a different grammar.
756
764
primary_key = ""
765
+ elif is_db2 (engine ):
766
+ str_datatype = "VARCHAR(20)"
767
+ # Primary key needs to be non-nullable.
768
+ primary_key = ""
757
769
else :
758
770
str_datatype = "TEXT"
759
771
with engine .connect () as connection :
@@ -796,7 +808,15 @@ def pytest_addoption(parser):
796
808
parser .addoption (
797
809
"--backend" ,
798
810
choices = (
799
- ("mssql" , "mssql-freetds" , "postgres" , "snowflake" , "bigquery" , "impala" )
811
+ (
812
+ "mssql" ,
813
+ "mssql-freetds" ,
814
+ "postgres" ,
815
+ "snowflake" ,
816
+ "bigquery" ,
817
+ "impala" ,
818
+ "db2" ,
819
+ )
800
820
),
801
821
help = "which database backend to use to run the integration tests" ,
802
822
)
0 commit comments