chore: Harmonize imports of sqlalchemy module, use sa where applicable #10

Merged: 2 commits, Dec 21, 2023
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -218,6 +218,12 @@ extend-ignore = [
extend-exclude = [
]

[tool.ruff.lint.flake8-import-conventions]
banned-from = ["typing"]

[tool.ruff.lint.flake8-import-conventions.extend-aliases]
typing = "t"

[tool.ruff.per-file-ignores]
"*/tests/*" = [
"S101", # Allow use of `assert`, and `print`.
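
As an aside, a hedged sketch of the import style this Ruff configuration enforces (example usage invented; assumes flake8-import-conventions semantics): `banned-from` rejects `from typing import ...`, and the alias table expects the module to be imported as `t`.

import typing as t                  # conforms to the configured alias
maybe_int: t.Optional[int] = None   # members are reached through the alias

from typing import Optional         # would be flagged: "typing" is listed in banned-from
import typing                       # would be flagged: the alias "t" is expected
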
79 changes: 32 additions & 47 deletions target_cratedb/connector.py
@@ -5,25 +5,10 @@
from builtins import issubclass
from datetime import datetime

import sqlalchemy
import sqlalchemy as sa
from crate.client.sqlalchemy.types import ObjectType, ObjectTypeImpl, _ObjectArray
Comment on lines -8 to 9 (Contributor Author):
That we need to import both ObjectType and ObjectTypeImpl here, and use them correspondingly, clearly indicates that something is not optimal; it should be addressed through improvements to the type definitions of the CrateDB SQLAlchemy dialect in crate-python.

from singer_sdk import typing as th
from singer_sdk.helpers._typing import is_array_type, is_boolean_type, is_integer_type, is_number_type, is_object_type
from sqlalchemy.types import (
ARRAY,
BIGINT,
BOOLEAN,
DATE,
DATETIME,
DECIMAL,
FLOAT,
INTEGER,
TEXT,
TIME,
TIMESTAMP,
VARCHAR,
)
from target_postgres.connector import NOTYPE, PostgresConnector

from target_cratedb.sqlalchemy.patch import polyfill_refresh_after_dml_engine
@@ -39,7 +24,7 @@ class CrateDBConnector(PostgresConnector):
allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported.
allow_temp_tables: bool = False # Whether temp tables are supported.

def create_engine(self) -> sqlalchemy.Engine:
def create_engine(self) -> sa.Engine:
"""
Create an SQLAlchemy engine object.

@@ -50,7 +35,7 @@ def create_engine(self) -> sqlalchemy.Engine:
return engine

@staticmethod
def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine:
def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
"""Return a JSON Schema representation of the provided type.

Note: Needs to be patched to invoke other static methods on `CrateDBConnector`.
@@ -112,7 +97,7 @@ def pick_individual_type(jsonschema_type: dict):
if "null" in jsonschema_type["type"]:
return None
if "integer" in jsonschema_type["type"]:
return BIGINT()
return sa.BIGINT()
if "object" in jsonschema_type["type"]:
return ObjectType
if "array" in jsonschema_type["type"]:
@@ -157,16 +142,16 @@ def pick_individual_type(jsonschema_type: dict):
# Discover/translate inner types.
inner_type = resolve_array_inner_type(jsonschema_type)
if inner_type is not None:
return ARRAY(inner_type)
return sa.ARRAY(inner_type)

# When type discovery fails, assume `TEXT`.
return ARRAY(TEXT())
return sa.ARRAY(sa.TEXT())

if jsonschema_type.get("format") == "date-time":
return TIMESTAMP()
return sa.TIMESTAMP()
individual_type = th.to_sql_type(jsonschema_type)
if isinstance(individual_type, VARCHAR):
return TEXT()
if isinstance(individual_type, sa.VARCHAR):
return sa.TEXT()
return individual_type
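
A hedged usage sketch of the mapping above (inputs invented for illustration; parts of the function are truncated in this diff, so this covers only the visible branches):

CrateDBConnector.pick_individual_type({"type": ["integer"]})                             # sa.BIGINT()
CrateDBConnector.pick_individual_type({"type": ["string"], "format": "date-time"})      # sa.TIMESTAMP()
CrateDBConnector.pick_individual_type({"type": ["array"], "items": {"type": ["boolean"]}})  # sa.ARRAY(sa.BOOLEAN())
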

@staticmethod
@@ -182,18 +167,18 @@ def pick_best_sql_type(sql_type_array: list):
An instance of the best SQL type class based on defined precedence order.
"""
precedence_order = [
TEXT,
TIMESTAMP,
DATETIME,
DATE,
TIME,
DECIMAL,
FLOAT,
BIGINT,
INTEGER,
BOOLEAN,
sa.TEXT,
sa.TIMESTAMP,
sa.DATETIME,
sa.DATE,
sa.TIME,
sa.DECIMAL,
sa.FLOAT,
sa.BIGINT,
sa.INTEGER,
sa.BOOLEAN,
NOTYPE,
ARRAY,
sa.ARRAY,
FloatVector,
ObjectTypeImpl,
]
@@ -202,12 +187,12 @@ def pick_best_sql_type(sql_type_array: list):
for obj in sql_type_array:
if isinstance(obj, sql_type):
return obj
return TEXT()
return sa.TEXT()
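
A hedged sketch of the precedence rule above (candidate list invented): the first entry in precedence_order that an instance matches wins, so TEXT beats the other candidates here.

candidates = [sa.BIGINT(), sa.BOOLEAN(), sa.TEXT()]
best = CrateDBConnector.pick_best_sql_type(candidates)   # returns the sa.TEXT() instance
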

def _sort_types(
self,
sql_types: t.Iterable[sqlalchemy.types.TypeEngine],
) -> list[sqlalchemy.types.TypeEngine]:
sql_types: t.Iterable[sa.types.TypeEngine],
) -> list[sa.types.TypeEngine]:
"""Return the input types sorted from most to least compatible.

Note: Needs to be patched to supply handlers for `_ObjectArray` and `NOTYPE`.
@@ -227,7 +212,7 @@ def _sort_types(
"""

def _get_type_sort_key(
sql_type: sqlalchemy.types.TypeEngine,
sql_type: sa.types.TypeEngine,
) -> tuple[int, int]:
# return rank, with higher numbers ranking first

@@ -257,10 +242,10 @@ def _get_type_sort_key(
def copy_table_structure(
self,
full_table_name: str,
from_table: sqlalchemy.Table,
connection: sqlalchemy.engine.Connection,
from_table: sa.Table,
connection: sa.engine.Connection,
as_temp_table: bool = False,
) -> sqlalchemy.Table:
) -> sa.Table:
"""Copy table structure.

Note: Needs to be patched to prevent `Primary key columns cannot be nullable` errors.
@@ -275,17 +260,17 @@ def copy_table_structure(
The new table object.
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData(schema=schema_name)
meta = sa.MetaData(schema=schema_name)
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")
column: sqlalchemy.Column
column: sa.Column
for column in from_table.columns:
# CrateDB: Prevent `Primary key columns cannot be nullable` errors.
if column.primary_key and column.nullable:
column.nullable = False
columns.append(column._copy())
new_table = sqlalchemy.Table(table_name, meta, *columns)
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table
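
A hedged usage sketch (names invented; assumes an open connection from self.connector._connect() as used elsewhere in this PR): a source column that is both primary key and nullable is tightened before the copy, avoiding CrateDB's "Primary key columns cannot be nullable" error.

src = sa.Table("src", sa.MetaData(), sa.Column("id", sa.BIGINT, primary_key=True, nullable=True))
copy = connector.copy_table_structure("doc.src_copy", from_table=src, connection=connection)
assert copy.columns["id"].nullable is False   # flipped by the loop above
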

@@ -299,11 +284,11 @@ def prepare_schema(self, schema_name: str) -> None:
def resolve_array_inner_type(jsonschema_type: dict) -> t.Union[sa.types.TypeEngine, None]:
if "items" in jsonschema_type:
if is_boolean_type(jsonschema_type["items"]):
return BOOLEAN()
return sa.BOOLEAN()
if is_number_type(jsonschema_type["items"]):
return FLOAT()
return sa.FLOAT()
if is_integer_type(jsonschema_type["items"]):
return BIGINT()
return sa.BIGINT()
if is_object_type(jsonschema_type["items"]):
return ObjectType()
if is_array_type(jsonschema_type["items"]):
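
Since resolve_array_inner_type is truncated above, a hedged sketch of how its visible branches resolve (inputs invented):

resolve_array_inner_type({"items": {"type": ["boolean"]}})   # sa.BOOLEAN()
resolve_array_inner_type({"items": {"type": ["number"]}})    # sa.FLOAT()
resolve_array_inner_type({"items": {"type": ["integer"]}})   # sa.BIGINT()
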
55 changes: 27 additions & 28 deletions target_cratedb/sinks.py
@@ -3,9 +3,8 @@
import time
from typing import List, Optional, Union

import sqlalchemy
import sqlalchemy as sa
from pendulum import now
from sqlalchemy import Column, Executable, MetaData, Table, bindparam, insert, select, update
from target_postgres.sinks import PostgresSink

from target_cratedb.connector import CrateDBConnector
@@ -116,7 +115,7 @@ def process_batch(self, context: dict) -> None:
# Use one connection so we do this all in a single transaction
with self.connector._connect() as connection, connection.begin():
# Check structure of table
table: sqlalchemy.Table = self.connector.prepare_table(
table: sa.Table = self.connector.prepare_table(
full_table_name=self.full_table_name,
schema=self.schema,
primary_keys=self.key_properties,
@@ -134,7 +133,7 @@ def process_batch(self, context: dict) -> None:
# FIXME: Upserts do not work yet.
"""
# Create a temp table (Creates from the table above)
temp_table: sqlalchemy.Table = self.connector.copy_table_structure(
temp_table: sa.Table = self.connector.copy_table_structure(
full_table_name=self.temp_table_name,
from_table=table,
as_temp_table=True,
@@ -162,11 +161,11 @@ def process_batch(self, context: dict) -> None:

def upsertX(
self,
from_table: sqlalchemy.Table,
to_table: sqlalchemy.Table,
from_table: sa.Table,
to_table: sa.Table,
schema: dict,
join_keys: List[Column],
connection: sqlalchemy.engine.Connection,
join_keys: List[sa.Column],
connection: sa.engine.Connection,
) -> Optional[int]:
"""Merge upsert data from one table to another.

@@ -185,45 +184,45 @@ def upsertX(

if self.append_only is True:
# Insert
select_stmt = select(from_table.columns).select_from(from_table)
select_stmt = sa.select(from_table.columns).select_from(from_table)
insert_stmt = to_table.insert().from_select(names=list(from_table.columns), select=select_stmt)
connection.execute(insert_stmt)
else:
join_predicates = []
for key in join_keys:
from_table_key: sqlalchemy.Column = from_table.columns[key] # type: ignore[call-overload]
to_table_key: sqlalchemy.Column = to_table.columns[key] # type: ignore[call-overload]
from_table_key: sa.Column = from_table.columns[key] # type: ignore[call-overload]
to_table_key: sa.Column = to_table.columns[key] # type: ignore[call-overload]
join_predicates.append(from_table_key == to_table_key) # type: ignore[call-overload]

join_condition = sqlalchemy.and_(*join_predicates)
join_condition = sa.and_(*join_predicates)

where_predicates = []
for key in join_keys:
to_table_key: sqlalchemy.Column = to_table.columns[key] # type: ignore[call-overload,no-redef]
to_table_key: sa.Column = to_table.columns[key] # type: ignore[call-overload,no-redef]
where_predicates.append(to_table_key.is_(None))
where_condition = sqlalchemy.and_(*where_predicates)
where_condition = sa.and_(*where_predicates)

select_stmt = (
select(from_table.columns)
sa.select(from_table.columns)
.select_from(from_table.outerjoin(to_table, join_condition))
.where(where_condition)
)
insert_stmt = insert(to_table).from_select(names=list(from_table.columns), select=select_stmt)
insert_stmt = sa.insert(to_table).from_select(names=list(from_table.columns), select=select_stmt)

connection.execute(insert_stmt)

# Update
where_condition = join_condition
update_columns = {}
for column_name in self.schema["properties"].keys():
from_table_column: sqlalchemy.Column = from_table.columns[column_name]
to_table_column: sqlalchemy.Column = to_table.columns[column_name]
from_table_column: sa.Column = from_table.columns[column_name]
to_table_column: sa.Column = to_table.columns[column_name]
# Prevent: `Updating a primary key is not supported`
if to_table_column.primary_key:
continue
update_columns[to_table_column] = from_table_column

update_stmt = update(to_table).where(where_condition).values(update_columns)
update_stmt = sa.update(to_table).where(where_condition).values(update_columns)
connection.execute(update_stmt)

return None
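
For orientation, a hedged sketch of the statements the non-append-only branch above aims to emit (illustrative SQL in comments; exact rendering depends on the dialect, and the FIXME above notes that upserts do not work yet):

# INSERT INTO to_table (...)
#   SELECT from_table.* FROM from_table
#   LEFT OUTER JOIN to_table ON from_table.key = to_table.key
#   WHERE to_table.key IS NULL
#
# UPDATE to_table SET <non-primary-key columns> = from_table.<columns>
#   WHERE from_table.key = to_table.key
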
Expand Down Expand Up @@ -264,7 +263,7 @@ def activate_version(self, new_version: int) -> None:
self.logger.info("Hard delete: %s", self.config.get("hard_delete"))
if self.config["hard_delete"] is True:
connection.execute(
sqlalchemy.text(
sa.text(
f'DELETE FROM "{self.schema_name}"."{self.table_name}" ' # noqa: S608
f'WHERE "{self.version_column_name}" <= {new_version} '
f'OR "{self.version_column_name}" IS NULL'
@@ -284,24 +283,24 @@ def activate_version(self, new_version: int) -> None:
connection=connection,
)
# Need to deal with the case where data doesn't exist for the version column
query = sqlalchemy.text(
query = sa.text(
f'UPDATE "{self.schema_name}"."{self.table_name}"\n'
f'SET "{self.soft_delete_column_name}" = :deletedate \n'
f'WHERE "{self.version_column_name}" < :version '
f'OR "{self.version_column_name}" IS NULL \n'
f' AND "{self.soft_delete_column_name}" IS NULL\n'
)
query = query.bindparams(
bindparam("deletedate", value=deleted_at, type_=datetime_type),
bindparam("version", value=new_version, type_=integer_type),
sa.bindparam("deletedate", value=deleted_at, type_=datetime_type),
sa.bindparam("version", value=new_version, type_=integer_type),
)
connection.execute(query)
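
As a hedged minimal example of the sa.text() + bindparams() pattern used above (table and values invented):

stmt = sa.text('UPDATE "doc"."t" SET "deleted_at" = :deletedate WHERE "version" < :version')
stmt = stmt.bindparams(
    sa.bindparam("deletedate", value=now(), type_=sa.TIMESTAMP()),
    sa.bindparam("version", value=42, type_=sa.BIGINT()),
)
connection.execute(stmt)
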

def generate_insert_statement(
self,
full_table_name: str,
columns: List[Column],
) -> Union[str, Executable]:
columns: List[sa.Column],
) -> Union[str, sa.sql.Executable]:
"""Generate an insert statement for the given records.

Args:
@@ -312,6 +311,6 @@ def generate_insert_statement(
An insert statement.
"""
# FIXME:
metadata = MetaData(schema=self.schema_name)
table = Table(full_table_name, metadata, *columns)
return insert(table)
metadata = sa.MetaData(schema=self.schema_name)
table = sa.Table(full_table_name, metadata, *columns)
return sa.insert(table)
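
A hedged usage sketch (column list invented): because the statement is built from a fresh MetaData, it can be executed directly with record dictionaries.

stmt = sink.generate_insert_statement("my_table", columns=[sa.Column("id", sa.BIGINT)])
connection.execute(stmt, [{"id": 1}, {"id": 2}])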