1
1
"""Handles Postgres interactions."""
2
2
3
+
3
4
from __future__ import annotations
4
5
5
6
import atexit
6
7
import io
8
+ import itertools
7
9
import signal
10
+ import sys
8
11
import typing as t
9
12
from contextlib import contextmanager
10
13
from os import chmod , path
@@ -79,7 +82,7 @@ def __init__(self, config: dict) -> None:
79
82
sqlalchemy_url = url .render_as_string (hide_password = False ),
80
83
)
81
84
82
- def prepare_table ( # type: ignore[override]
85
+ def prepare_table ( # type: ignore[override] # noqa: PLR0913
83
86
self ,
84
87
full_table_name : str ,
85
88
schema : dict ,
@@ -105,7 +108,7 @@ def prepare_table( # type: ignore[override]
105
108
meta = sa .MetaData (schema = schema_name )
106
109
table : sa .Table
107
110
if not self .table_exists (full_table_name = full_table_name ):
108
- table = self .create_empty_table (
111
+ return self .create_empty_table (
109
112
table_name = table_name ,
110
113
meta = meta ,
111
114
schema = schema ,
@@ -114,7 +117,6 @@ def prepare_table( # type: ignore[override]
114
117
as_temp_table = as_temp_table ,
115
118
connection = connection ,
116
119
)
117
- return table
118
120
meta .reflect (connection , only = [table_name ])
119
121
table = meta .tables [
120
122
full_table_name
@@ -161,19 +163,19 @@ def copy_table_structure(
161
163
_ , schema_name , table_name = self .parse_full_table_name (full_table_name )
162
164
meta = sa .MetaData (schema = schema_name )
163
165
new_table : sa .Table
164
- columns = []
165
166
if self .table_exists (full_table_name = full_table_name ):
166
167
raise RuntimeError ("Table already exists" )
167
- for column in from_table .columns :
168
- columns .append (column ._copy ())
168
+
169
+ columns = [column ._copy () for column in from_table .columns ]
170
+
169
171
if as_temp_table :
170
172
new_table = sa .Table (table_name , meta , * columns , prefixes = ["TEMPORARY" ])
171
173
new_table .create (bind = connection )
172
174
return new_table
173
- else :
174
- new_table = sa .Table (table_name , meta , * columns )
175
- new_table .create (bind = connection )
176
- return new_table
175
+
176
+ new_table = sa .Table (table_name , meta , * columns )
177
+ new_table .create (bind = connection )
178
+ return new_table
177
179
178
180
@contextmanager
179
181
def _connect (self ) -> t .Iterator [sa .engine .Connection ]:
@@ -184,18 +186,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
184
186
"""Drop table data."""
185
187
table .drop (bind = connection )
186
188
187
- def clone_table (
189
+ def clone_table ( # noqa: PLR0913
188
190
self , new_table_name , table , metadata , connection , temp_table
189
191
) -> sa .Table :
190
192
"""Clone a table."""
191
- new_columns = []
192
- for column in table .columns :
193
- new_columns .append (
194
- sa .Column (
195
- column .name ,
196
- column .type ,
197
- )
193
+ new_columns = [
194
+ sa .Column (
195
+ column .name ,
196
+ column .type ,
198
197
)
198
+ for column in table .columns
199
+ ]
199
200
if temp_table is True :
200
201
new_table = sa .Table (
201
202
new_table_name , metadata , * new_columns , prefixes = ["TEMPORARY" ]
@@ -275,9 +276,8 @@ def pick_individual_type(jsonschema_type: dict):
275
276
if jsonschema_type .get ("format" ) == "date-time" :
276
277
return TIMESTAMP ()
277
278
individual_type = th .to_sql_type (jsonschema_type )
278
- if isinstance (individual_type , VARCHAR ):
279
- return TEXT ()
280
- return individual_type
279
+
280
+ return TEXT () if isinstance (individual_type , VARCHAR ) else individual_type
281
281
282
282
@staticmethod
283
283
def pick_best_sql_type (sql_type_array : list ):
@@ -304,13 +304,12 @@ def pick_best_sql_type(sql_type_array: list):
304
304
NOTYPE ,
305
305
]
306
306
307
- for sql_type in precedence_order :
308
- for obj in sql_type_array :
309
- if isinstance (obj , sql_type ):
310
- return obj
307
+ for sql_type , obj in itertools .product (precedence_order , sql_type_array ):
308
+ if isinstance (obj , sql_type ):
309
+ return obj
311
310
return TEXT ()
312
311
313
- def create_empty_table ( # type: ignore[override]
312
+ def create_empty_table ( # type: ignore[override] # noqa: PLR0913
314
313
self ,
315
314
table_name : str ,
316
315
meta : sa .MetaData ,
@@ -324,7 +323,7 @@ def create_empty_table( # type: ignore[override]
324
323
325
324
Args:
326
325
table_name: the target table name.
327
- meta: the SQLAchemy metadata object.
326
+ meta: the SQLAlchemy metadata object.
328
327
schema: the JSON schema for the new table.
329
328
connection: the database connection.
330
329
primary_keys: list of key properties.
@@ -367,7 +366,7 @@ def create_empty_table( # type: ignore[override]
367
366
new_table .create (bind = connection )
368
367
return new_table
369
368
370
- def prepare_column (
369
+ def prepare_column ( # noqa: PLR0913
371
370
self ,
372
371
full_table_name : str ,
373
372
column_name : str ,
@@ -415,7 +414,7 @@ def prepare_column(
415
414
column_object = column_object ,
416
415
)
417
416
418
- def _create_empty_column ( # type: ignore[override]
417
+ def _create_empty_column ( # type: ignore[override] # noqa: PLR0913
419
418
self ,
420
419
schema_name : str ,
421
420
table_name : str ,
@@ -480,7 +479,7 @@ def get_column_add_ddl( # type: ignore[override]
480
479
},
481
480
)
482
481
483
- def _adapt_column_type ( # type: ignore[override]
482
+ def _adapt_column_type ( # type: ignore[override] # noqa: PLR0913
484
483
self ,
485
484
schema_name : str ,
486
485
table_name : str ,
@@ -523,7 +522,7 @@ def _adapt_column_type( # type: ignore[override]
523
522
return
524
523
525
524
# Not the same type, generic type or compatible types
526
- # calling merge_sql_types for assistnace
525
+ # calling merge_sql_types for assistance
527
526
compatible_sql_type = self .merge_sql_types ([current_type , sql_type ])
528
527
529
528
if str (compatible_sql_type ) == str (current_type ):
@@ -593,17 +592,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
593
592
if config .get ("sqlalchemy_url" ):
594
593
return cast (str , config ["sqlalchemy_url" ])
595
594
596
- else :
597
- sqlalchemy_url = URL .create (
598
- drivername = config ["dialect+driver" ],
599
- username = config ["user" ],
600
- password = config ["password" ],
601
- host = config ["host" ],
602
- port = config ["port" ],
603
- database = config ["database" ],
604
- query = self .get_sqlalchemy_query (config ),
605
- )
606
- return cast (str , sqlalchemy_url )
595
+ sqlalchemy_url = URL .create (
596
+ drivername = config ["dialect+driver" ],
597
+ username = config ["user" ],
598
+ password = config ["password" ],
599
+ host = config ["host" ],
600
+ port = config ["port" ],
601
+ database = config ["database" ],
602
+ query = self .get_sqlalchemy_query (config ),
603
+ )
604
+ return cast (str , sqlalchemy_url )
607
605
608
606
def get_sqlalchemy_query (self , config : dict ) -> dict :
609
607
"""Get query values to be used for sqlalchemy URL creation.
@@ -619,7 +617,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
619
617
# ssl_enable is for verifying the server's identity to the client.
620
618
if config ["ssl_enable" ]:
621
619
ssl_mode = config ["ssl_mode" ]
622
- query . update ({ "sslmode" : ssl_mode })
620
+ query [ "sslmode" ] = ssl_mode
623
621
query ["sslrootcert" ] = self .filepath_or_certificate (
624
622
value = config ["ssl_certificate_authority" ],
625
623
alternative_name = config ["ssl_storage_directory" ] + "/root.crt" ,
@@ -665,12 +663,11 @@ def filepath_or_certificate(
665
663
"""
666
664
if path .isfile (value ):
667
665
return value
668
- else :
669
- with open (alternative_name , "wb" ) as alternative_file :
670
- alternative_file .write (value .encode ("utf-8" ))
671
- if restrict_permissions :
672
- chmod (alternative_name , 0o600 )
673
- return alternative_name
666
+ with open (alternative_name , "wb" ) as alternative_file :
667
+ alternative_file .write (value .encode ("utf-8" ))
668
+ if restrict_permissions :
669
+ chmod (alternative_name , 0o600 )
670
+ return alternative_name
674
671
675
672
def guess_key_type (self , key_data : str ) -> paramiko .PKey :
676
673
"""Guess the type of the private key.
@@ -695,7 +692,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
695
692
):
696
693
try :
697
694
key = key_class .from_private_key (io .StringIO (key_data )) # type: ignore[attr-defined]
698
- except paramiko .SSHException :
695
+ except paramiko .SSHException : # noqa: PERF203
699
696
continue
700
697
else :
701
698
return key
@@ -715,7 +712,7 @@ def catch_signal(self, signum, frame) -> None:
715
712
signum: The signal number
716
713
frame: The current stack frame
717
714
"""
718
- exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
715
+ sys . exit (1 ) # Calling this to be sure atexit is called, so clean_up gets called
719
716
720
717
def _get_column_type ( # type: ignore[override]
721
718
self ,
0 commit comments