
Commit a2ff536

chore: Apply non-functional refactoring and fix typos
1 parent 4483fa8 commit a2ff536

5 files changed (+120, -134 lines)

pyproject.toml

Lines changed: 16 additions & 8 deletions
@@ -80,14 +80,22 @@ target-version = "py38"
 
 [tool.ruff.lint]
 select = [
-    "F",    # Pyflakes
-    "W",    # pycodestyle warnings
-    "E",    # pycodestyle errors
-    "I",    # isort
-    "N",    # pep8-naming
-    "D",    # pydocsyle
-    "ICN",  # flake8-import-conventions
-    "RUF",  # ruff
+    "F",    # Pyflakes
+    "W",    # pycodestyle warnings
+    "E",    # pycodestyle errors
+    "I",    # isort
+    "N",    # pep8-naming
+    "D",    # pydocsyle
+    "UP",   # pyupgrade
+    "ICN",  # flake8-import-conventions
+    "RET",  # flake8-return
+    "SIM",  # flake8-simplify
+    "TCH",  # flake8-type-checking
+    "ERA",  # eradicate
+    "PGH",  # pygrep-hooks
+    "PL",   # Pylint
+    "PERF", # Perflint
+    "RUF",  # ruff
 ]
 
 [tool.ruff.lint.flake8-import-conventions]
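Note on the config change above: the expanded `select` list enables several additional Ruff rule families (pyupgrade, flake8-return, flake8-simplify, flake8-type-checking, eradicate, pygrep-hooks, Pylint, Perflint) on top of those already in use, and the Python rewrites in the files below appear to be the follow-up fixes and `noqa` suppressions for them. A minimal runnable sketch of the kind of if/else-to-expression rewrite that flake8-simplify encourages and that this commit adopts in sinks.py; the exact triggering rule code is not shown in the diff, and `key_properties` here is a hypothetical stand-in value:

    key_properties: list = []  # hypothetical value; the real one comes from the stream's key properties

    # Before: an if/else whose only job is to assign a boolean.
    if key_properties is None or key_properties == []:
        append_only = True
    else:
        append_only = False

    # After: assign the condition directly, as setup() in sinks.py now does.
    append_only = key_properties is None or key_properties == []
    assert append_only is True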

target_postgres/connector.py

Lines changed: 48 additions & 51 deletions
@@ -1,10 +1,13 @@
 """Handles Postgres interactions."""
 
+
 from __future__ import annotations
 
 import atexit
 import io
+import itertools
 import signal
+import sys
 import typing as t
 from contextlib import contextmanager
 from os import chmod, path
@@ -79,7 +82,7 @@ def __init__(self, config: dict) -> None:
             sqlalchemy_url=url.render_as_string(hide_password=False),
         )
 
-    def prepare_table(  # type: ignore[override]
+    def prepare_table(  # type: ignore[override]  # noqa: PLR0913
         self,
         full_table_name: str,
         schema: dict,
@@ -105,7 +108,7 @@ def prepare_table(  # type: ignore[override]
         meta = sa.MetaData(schema=schema_name)
         table: sa.Table
         if not self.table_exists(full_table_name=full_table_name):
-            table = self.create_empty_table(
+            return self.create_empty_table(
                 table_name=table_name,
                 meta=meta,
                 schema=schema,
@@ -114,7 +117,6 @@ def prepare_table(  # type: ignore[override]
                 as_temp_table=as_temp_table,
                 connection=connection,
             )
-            return table
         meta.reflect(connection, only=[table_name])
         table = meta.tables[
             full_table_name
@@ -161,19 +163,19 @@ def copy_table_structure(
         _, schema_name, table_name = self.parse_full_table_name(full_table_name)
         meta = sa.MetaData(schema=schema_name)
         new_table: sa.Table
-        columns = []
         if self.table_exists(full_table_name=full_table_name):
             raise RuntimeError("Table already exists")
-        for column in from_table.columns:
-            columns.append(column._copy())
+
+        columns = [column._copy() for column in from_table.columns]
+
         if as_temp_table:
             new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
             new_table.create(bind=connection)
             return new_table
-        else:
-            new_table = sa.Table(table_name, meta, *columns)
-            new_table.create(bind=connection)
-            return new_table
+
+        new_table = sa.Table(table_name, meta, *columns)
+        new_table.create(bind=connection)
+        return new_table
 
     @contextmanager
     def _connect(self) -> t.Iterator[sa.engine.Connection]:
@@ -184,18 +186,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
         """Drop table data."""
         table.drop(bind=connection)
 
-    def clone_table(
+    def clone_table(  # noqa: PLR0913
         self, new_table_name, table, metadata, connection, temp_table
     ) -> sa.Table:
         """Clone a table."""
-        new_columns = []
-        for column in table.columns:
-            new_columns.append(
-                sa.Column(
-                    column.name,
-                    column.type,
-                )
+        new_columns = [
+            sa.Column(
+                column.name,
+                column.type,
             )
+            for column in table.columns
+        ]
         if temp_table is True:
             new_table = sa.Table(
                 new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
@@ -275,9 +276,8 @@ def pick_individual_type(jsonschema_type: dict):
         if jsonschema_type.get("format") == "date-time":
             return TIMESTAMP()
         individual_type = th.to_sql_type(jsonschema_type)
-        if isinstance(individual_type, VARCHAR):
-            return TEXT()
-        return individual_type
+
+        return TEXT() if isinstance(individual_type, VARCHAR) else individual_type
 
     @staticmethod
     def pick_best_sql_type(sql_type_array: list):
@@ -304,13 +304,12 @@ def pick_best_sql_type(sql_type_array: list):
             NOTYPE,
         ]
 
-        for sql_type in precedence_order:
-            for obj in sql_type_array:
-                if isinstance(obj, sql_type):
-                    return obj
+        for sql_type, obj in itertools.product(precedence_order, sql_type_array):
+            if isinstance(obj, sql_type):
+                return obj
         return TEXT()
 
-    def create_empty_table(  # type: ignore[override]
+    def create_empty_table(  # type: ignore[override]  # noqa: PLR0913
         self,
         table_name: str,
         meta: sa.MetaData,
@@ -324,7 +323,7 @@ def create_empty_table(  # type: ignore[override]
 
         Args:
             table_name: the target table name.
-            meta: the SQLAchemy metadata object.
+            meta: the SQLAlchemy metadata object.
             schema: the JSON schema for the new table.
             connection: the database connection.
             primary_keys: list of key properties.
@@ -367,7 +366,7 @@ def create_empty_table(  # type: ignore[override]
         new_table.create(bind=connection)
         return new_table
 
-    def prepare_column(
+    def prepare_column(  # noqa: PLR0913
         self,
         full_table_name: str,
         column_name: str,
@@ -415,7 +414,7 @@ def prepare_column(
             column_object=column_object,
         )
 
-    def _create_empty_column(  # type: ignore[override]
+    def _create_empty_column(  # type: ignore[override]  # noqa: PLR0913
         self,
         schema_name: str,
         table_name: str,
@@ -480,7 +479,7 @@ def get_column_add_ddl(  # type: ignore[override]
             },
         )
 
-    def _adapt_column_type(  # type: ignore[override]
+    def _adapt_column_type(  # type: ignore[override]  # noqa: PLR0913
         self,
         schema_name: str,
         table_name: str,
@@ -523,7 +522,7 @@ def _adapt_column_type(  # type: ignore[override]
             return
 
         # Not the same type, generic type or compatible types
-        # calling merge_sql_types for assistnace
+        # calling merge_sql_types for assistance
         compatible_sql_type = self.merge_sql_types([current_type, sql_type])
 
         if str(compatible_sql_type) == str(current_type):
@@ -593,17 +592,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
         if config.get("sqlalchemy_url"):
             return cast(str, config["sqlalchemy_url"])
 
-        else:
-            sqlalchemy_url = URL.create(
-                drivername=config["dialect+driver"],
-                username=config["user"],
-                password=config["password"],
-                host=config["host"],
-                port=config["port"],
-                database=config["database"],
-                query=self.get_sqlalchemy_query(config),
-            )
-            return cast(str, sqlalchemy_url)
+        sqlalchemy_url = URL.create(
+            drivername=config["dialect+driver"],
+            username=config["user"],
+            password=config["password"],
+            host=config["host"],
+            port=config["port"],
+            database=config["database"],
+            query=self.get_sqlalchemy_query(config),
+        )
+        return cast(str, sqlalchemy_url)
 
     def get_sqlalchemy_query(self, config: dict) -> dict:
         """Get query values to be used for sqlalchemy URL creation.
@@ -619,7 +617,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
         # ssl_enable is for verifying the server's identity to the client.
         if config["ssl_enable"]:
             ssl_mode = config["ssl_mode"]
-            query.update({"sslmode": ssl_mode})
+            query["sslmode"] = ssl_mode
             query["sslrootcert"] = self.filepath_or_certificate(
                 value=config["ssl_certificate_authority"],
                 alternative_name=config["ssl_storage_directory"] + "/root.crt",
@@ -665,12 +663,11 @@ def filepath_or_certificate(
         """
         if path.isfile(value):
             return value
-        else:
-            with open(alternative_name, "wb") as alternative_file:
-                alternative_file.write(value.encode("utf-8"))
-            if restrict_permissions:
-                chmod(alternative_name, 0o600)
-            return alternative_name
+        with open(alternative_name, "wb") as alternative_file:
+            alternative_file.write(value.encode("utf-8"))
+        if restrict_permissions:
+            chmod(alternative_name, 0o600)
+        return alternative_name
 
     def guess_key_type(self, key_data: str) -> paramiko.PKey:
         """Guess the type of the private key.
@@ -695,7 +692,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
         ):
             try:
                 key = key_class.from_private_key(io.StringIO(key_data))  # type: ignore[attr-defined]
-            except paramiko.SSHException:
+            except paramiko.SSHException:  # noqa: PERF203
                 continue
             else:
                 return key
@@ -715,7 +712,7 @@ def catch_signal(self, signum, frame) -> None:
             signum: The signal number
             frame: The current stack frame
         """
-        exit(1)  # Calling this to be sure atexit is called, so clean_up gets called
+        sys.exit(1)  # Calling this to be sure atexit is called, so clean_up gets called
 
     def _get_column_type(  # type: ignore[override]
         self,
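One behavioural note on the `pick_best_sql_type` change above: the `itertools.product` rewrite preserves the original search order, because `product` varies its last argument fastest, so every candidate is still checked against each precedence entry before the next entry is tried. A standalone sketch with stand-in types (not the real SQLAlchemy classes) showing that the two forms pick the same element:

    import itertools

    precedence_order = [int, float, str]  # stand-ins for the real SQL type classes
    candidates = ["text", 3.5, 7]         # stand-ins for sql_type_array

    def nested(precedence, objs):
        # Original form: explicit nested loops.
        for kind in precedence:
            for obj in objs:
                if isinstance(obj, kind):
                    return obj
        return None

    def flattened(precedence, objs):
        # Refactored form: one loop over the cartesian product, same visit order.
        for kind, obj in itertools.product(precedence, objs):
            if isinstance(obj, kind):
                return obj
        return None

    assert nested(precedence_order, candidates) == flattened(precedence_order, candidates) == 7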

target_postgres/sinks.py

Lines changed: 21 additions & 32 deletions
@@ -47,10 +47,7 @@ def setup(self) -> None:
         This method is called on Sink creation, and creates the required Schema and
         Table entities in the target database.
         """
-        if self.key_properties is None or self.key_properties == []:
-            self.append_only = True
-        else:
-            self.append_only = False
+        self.append_only = self.key_properties is None or self.key_properties == []
         if self.schema_name:
             self.connector.prepare_schema(self.schema_name)
         with self.connector._connect() as connection, connection.begin():
@@ -109,14 +106,14 @@ def process_batch(self, context: dict) -> None:
 
     def generate_temp_table_name(self):
         """Uuid temp table name."""
-        # sa.exc.IdentifierError: Identifier
+        # sa.exc.IdentifierError: Identifier  # noqa: ERA001
         # 'temp_test_optional_attributes_388470e9_fbd0_47b7_a52f_d32a2ee3f5f6'
         # exceeds maximum length of 63 characters
         # Is hit if we have a long table name, there is no limit on Temporary tables
         # in postgres, used a guid just in case we are using the same session
         return f"{str(uuid.uuid4()).replace('-', '_')}"
 
-    def bulk_insert_records(  # type: ignore[override]
+    def bulk_insert_records(  # type: ignore[override]  # noqa: PLR0913
         self,
         table: sa.Table,
         schema: dict,
@@ -156,24 +153,24 @@ def bulk_insert_records(  # type: ignore[override]
         if self.append_only is False:
             insert_records: Dict[str, Dict] = {}  # pk : record
             for record in records:
-                insert_record = {}
-                for column in columns:
-                    insert_record[column.name] = record.get(column.name)
+                insert_record = {
+                    column.name: record.get(column.name) for column in columns
+                }
                 # No need to check for a KeyError here because the SDK already
-                # guaruntees that all key properties exist in the record.
+                # guarantees that all key properties exist in the record.
                 primary_key_value = "".join([str(record[key]) for key in primary_keys])
                 insert_records[primary_key_value] = insert_record
             data_to_insert = list(insert_records.values())
         else:
             for record in records:
-                insert_record = {}
-                for column in columns:
-                    insert_record[column.name] = record.get(column.name)
+                insert_record = {
+                    column.name: record.get(column.name) for column in columns
+                }
                 data_to_insert.append(insert_record)
         connection.execute(insert, data_to_insert)
         return True
 
-    def upsert(
+    def upsert(  # noqa: PLR0913
         self,
         from_table: sa.Table,
         to_table: sa.Table,
@@ -232,7 +229,7 @@ def upsert(
         # Update
         where_condition = join_condition
         update_columns = {}
-        for column_name in self.schema["properties"].keys():
+        for column_name in self.schema["properties"]:
             from_table_column: sa.Column = from_table.columns[column_name]
             to_table_column: sa.Column = to_table.columns[column_name]
             update_columns[to_table_column] = from_table_column
@@ -249,14 +246,13 @@ def column_representation(
         schema: dict,
     ) -> List[sa.Column]:
         """Return a sqlalchemy table representation for the current schema."""
-        columns: list[sa.Column] = []
-        for property_name, property_jsonschema in schema["properties"].items():
-            columns.append(
-                sa.Column(
-                    property_name,
-                    self.connector.to_sql_type(property_jsonschema),
-                )
+        columns: list[sa.Column] = [
+            sa.Column(
+                property_name,
+                self.connector.to_sql_type(property_jsonschema),
             )
+            for property_name, property_jsonschema in schema["properties"].items()
+        ]
         return columns
 
     def generate_insert_statement(
@@ -286,12 +282,12 @@ def schema_name(self) -> Optional[str]:
         """Return the schema name or `None` if using names with no schema part.
 
         Note that after the next SDK release (after 0.14.0) we can remove this
-        as it's already upstreamed.
+        as it's already up-streamed.
 
         Returns:
             The target schema name.
         """
-        # Look for a default_target_scheme in the configuraion fle
+        # Look for a default_target_scheme in the configuration file
        default_target_schema: str = self.config.get("default_target_schema", None)
        parts = self.stream_name.split("-")
 
@@ -302,14 +298,7 @@ def schema_name(self) -> Optional[str]:
         if default_target_schema:
             return default_target_schema
 
-        if len(parts) in {2, 3}:
-            # Stream name is a two-part or three-part identifier.
-            # Use the second-to-last part as the schema name.
-            stream_schema = self.conform_name(parts[-2], "schema")
-            return stream_schema
-
-        # Schema name not detected.
-        return None
+        return self.conform_name(parts[-2], "schema") if len(parts) in {2, 3} else None
 
     def activate_version(self, new_version: int) -> None:
         """Bump the active version of the target table.
