Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/1.4.latest' into 1.5.latest
Browse files Browse the repository at this point in the history
  • Loading branch information
mwallace582 committed Jan 16, 2024
2 parents 49c9e85 + a74fe82 commit dcf64e4
Show file tree
Hide file tree
Showing 19 changed files with 166 additions and 152 deletions.
6 changes: 6 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Ran `black` on all files
99a125c82b846fd25aed432ed67a9ad982bbe0ad
# Ran `black` on all files
f0507baaa4279575831ad06452757666d25bd90b
# Ran `black` on all files
870bbbf3c55081c436da7acb67a28e05c7484a44
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
python -m pip install -r dev-requirements.txt
python -m pip --version
pre-commit --version
mypy --version
dbt --version
- name: Run pre-commit hooks
Expand Down
45 changes: 43 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
# For more on configuring pre-commit hooks (see https://pre-commit.com/)

# Force all unspecified python hooks to run python 3.8
default_language_version:
python: python3

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.4.0
hooks:
- id: check-yaml
args: [--unsafe]
- id: check-json
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
additional_dependencies: ['click~=8.1']
args:
- "--line-length=99"
- "--target-version=py38"
- id: black
alias: black-check
stages: [manual]
additional_dependencies: ['click~=8.1']
args:
- "--line-length=99"
- "--target-version=py38"
- "--check"
- "--diff"
- repo: https://github.com/pycqa/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
- id: flake8
alias: flake8-check
stages: [manual]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.1.1
hooks:
- id: mypy
# N.B.: Mypy is... a bit fragile.
#
# By using `language: system` we run this hook in the local
# environment instead of a pre-commit isolated one. This is needed
# to ensure mypy correctly parses the project.

# It may cause trouble in that it adds environmental variables out
# of our control to the mix. Unfortunately, there's nothing we can
# do about per pre-commit's author.
# See https://github.com/pre-commit/pre-commit/issues/730 for details.
args: [--show-error-codes, --ignore-missing-imports, --explicit-package-bases]
files: ^dbt/adapters/.*
language: system
- id: mypy
alias: mypy-check
stages: [manual]
args: [--show-error-codes, --pretty, --ignore-missing-imports, --explicit-package-bases]
files: ^dbt/adapters
language: system
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- Support dbt v1.4 ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))
- Support dbt v1.5 ([#145](https://github.com/dbeatty10/dbt-mysql/issues/145))
- Support connecting via UNIX sockets ([#164](https://github.com/dbeatty10/dbt-mysql/issues/164))
- Support Black & MyPy pre-commit hooks ([#138](https://github.com/dbeatty10/dbt-mysql/issues/138))

### Fixes
- Fix incremental composite keys ([#144](https://github.com/dbeatty10/dbt-mysql/issues/144))
Expand All @@ -13,7 +14,7 @@
- [@lpezet](https://github.com/lpezet) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))
- [@moszutij](https://github.com/moszutij) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146), [#144](https://github.com/dbeatty10/dbt-mysql/issues/144))
- [@wesen](https://github.com/wesen) ([#146](https://github.com/dbeatty10/dbt-mysql/pull/146))
- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162), [#163](https://github.com/dbeatty10/dbt-mysql/pull/163), [#164](https://github.com/dbeatty10/dbt-mysql/issues/164))
- [@mwallace582](https://github.com/mwallace582) ([#162](https://github.com/dbeatty10/dbt-mysql/pull/162), [#163](https://github.com/dbeatty10/dbt-mysql/pull/163), [#164](https://github.com/dbeatty10/dbt-mysql/issues/164), [#138](https://github.com/dbeatty10/dbt-mysql/issues/138))


## dbt-mysql 1.1.0 (Feb 5, 2023)
Expand Down
19 changes: 16 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,26 @@ dev-uninstall: ## Uninstalls all packages while maintaining the virtual environm
pip freeze | grep -v "^-e" | cut -d "@" -f1 | xargs pip uninstall -y
pip uninstall -y dbt-mysql

.PHONY: mypy
mypy: ## Runs mypy against staged changes for static type checking.
@\
pre-commit run --hook-stage manual mypy-check | grep -v "INFO"

.PHONY: flake8
flake8: ## Runs flake8 against staged changes to enforce style guide.
@\
pre-commit run --hook-stage manual flake8-check | grep -v "INFO"

.PHONY: black
black: ## Runs black against staged changes to enforce style guide.
@\
pre-commit run --hook-stage manual black-check -v | grep -v "INFO"

.PHONY: lint
lint: ## Runs flake8 code checks against staged changes.
lint: ## Runs flake8 and mypy code checks against staged changes.
@\
pre-commit run flake8-check --hook-stage manual | grep -v "INFO";
pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \
pre-commit run mypy-check --hook-stage manual | grep -v "INFO"

.PHONY: linecheck
linecheck: ## Checks for all Python lines 100 characters or more
Expand All @@ -35,7 +46,9 @@ unit: ## Runs unit tests with py38.
test: ## Runs unit tests with py38 and code checks against staged changes.
@\
tox -p -e py38; \
pre-commit run flake8-check --hook-stage manual | grep -v "INFO";
pre-commit run black-check --hook-stage manual | grep -v "INFO"; \
pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \
pre-commit run mypy-check --hook-stage manual | grep -v "INFO"

.PHONY: integration
integration: ## Runs mysql integration tests with py38.
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/mariadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


Plugin = AdapterPlugin(
adapter=MariaDBAdapter,
adapter=MariaDBAdapter, # type: ignore[arg-type]
credentials=MariaDBCredentials,
include_path=mariadb.PACKAGE_PATH,
)
15 changes: 5 additions & 10 deletions dbt/adapters/mariadb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

@dataclass(init=False)
class MariaDBCredentials(Credentials):
server: Optional[str] = None
server: str = ""
unix_socket: Optional[str] = None
port: Optional[int] = None
database: Optional[str] = None
schema: str
database: str = ""
schema: str = ""
username: Optional[str] = None
password: Optional[str] = None
charset: Optional[str] = None
Expand Down Expand Up @@ -102,7 +102,6 @@ def open(cls, connection):
connection.handle = mysql.connector.connect(**kwargs)
connection.state = "open"
except mysql.connector.Error:

try:
logger.debug(
"Failed connection without supplying the `database`. "
Expand All @@ -115,10 +114,8 @@ def open(cls, connection):
connection.handle = mysql.connector.connect(**kwargs)
connection.state = "open"
except mysql.connector.Error as e:

logger.debug(
"Got an error when attempting to open a MariaDB "
"connection: '{}'".format(e)
"Got an error when attempting to open a MariaDB " "connection: '{}'".format(e)
)

connection.handle = None
Expand Down Expand Up @@ -175,9 +172,7 @@ def get_response(cls, cursor) -> AdapterResponse:
# the mysql-connector-python driver.
# So just return the default value.
return AdapterResponse(
_message="{} {}".format(code, num_rows),
rows_affected=num_rows,
code=code
_message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code
)

@classmethod
Expand Down
58 changes: 23 additions & 35 deletions dbt/adapters/mariadb/impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from concurrent.futures import Future
from dataclasses import asdict
from typing import Optional, List, Dict, Any, Iterable
from typing import Optional, List, Dict, Any, Iterable, Tuple
import agate

import dbt
Expand All @@ -14,6 +14,7 @@
from dbt.adapters.base import BaseRelation
from dbt.contracts.graph.nodes import ConstraintType
from dbt.adapters.base.impl import ConstraintSupport
from dbt.contracts.graph.manifest import Manifest
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
from dbt.events import AdapterLogger
from dbt.utils import executor
Expand Down Expand Up @@ -47,21 +48,19 @@ def date_function(cls):
return "current_date()"

@classmethod
def convert_datetime_type(cls, agate_table: agate.Table,
col_idx: int) -> str:
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
return "timestamp"

@classmethod
def quote(cls, identifier: str) -> str:
return "`{}`".format(identifier)

def list_relations_without_caching(
self, schema_relation: MariaDBRelation
def list_relations_without_caching( # type: ignore[override]
self, schema_relation: MariaDBRelation # type: ignore[override]
) -> List[MariaDBRelation]:
kwargs = {"schema_relation": schema_relation}
try:
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME,
kwargs=kwargs)
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
except dbt.exceptions.DbtRuntimeError as e:
errmsg = getattr(e, "msg", "")
if f"MariaDB database '{schema_relation}' not found" in errmsg:
Expand All @@ -80,21 +79,16 @@ def list_relations_without_caching(
f"got {len(row)} values, expected 4"
)
_, name, _schema, relation_type = row
relation = self.Relation.create(
schema=_schema, identifier=name, type=relation_type
)
relation = self.Relation.create(schema=_schema, identifier=name, type=relation_type)
relations.append(relation)

return relations

def get_columns_in_relation(self,
relation: Relation) -> List[MariaDBColumn]:
def get_columns_in_relation(self, relation: MariaDBRelation) -> List[MariaDBColumn]:
rows: List[agate.Row] = super().get_columns_in_relation(relation)
return self.parse_show_columns(relation, rows)

def _get_columns_for_catalog(
self, relation: MariaDBRelation
) -> Iterable[Dict[str, Any]]:
def _get_columns_for_catalog(self, relation: MariaDBRelation) -> Iterable[Dict[str, Any]]:
columns = self.get_columns_in_relation(relation)

for column in columns:
Expand All @@ -106,15 +100,15 @@ def _get_columns_for_catalog(
yield as_dict

def get_relation(
self, database: str, schema: str, identifier: str
self, database: Optional[str], schema: str, identifier: str
) -> Optional[BaseRelation]:
if not self.Relation.get_default_include_policy().database:
database = None

return super().get_relation(database, schema, identifier)

def parse_show_columns(
self, relation: Relation, raw_rows: List[agate.Row]
self, relation: MariaDBRelation, raw_rows: List[agate.Row]
) -> List[MariaDBColumn]:
return [
MariaDBColumn(
Expand All @@ -131,12 +125,12 @@ def parse_show_columns(
for idx, column in enumerate(raw_rows)
]

def get_catalog(self, manifest):
def get_catalog(self, manifest: Manifest) -> Tuple[agate.Table, List[Exception]]:
schema_map = self._get_catalog_schemas(manifest)

if len(schema_map) > 1:
raise dbt.exceptions.CompilationError(
f"Expected only one database in get_catalog, found "
f"{list(schema_map)}"
f"Expected only one database in get_catalog, found " f"{list(schema_map)}"
)

with executor(self.config) as tpe:
Expand Down Expand Up @@ -164,8 +158,7 @@ def _get_one_catalog(
) -> agate.Table:
if len(schemas) != 1:
raise dbt.exceptions.CompilationError(
f"Expected only one schema in mariadb _get_one_catalog, found "
f"{schemas}"
f"Expected only one schema in mariadb _get_one_catalog, found " f"{schemas}"
)

database = information_schema.database
Expand All @@ -174,14 +167,12 @@ def _get_one_catalog(
columns: List[Dict[str, Any]] = []
for relation in self.list_relations(database, schema):
logger.debug("Getting table schema for relation {}", str(relation))
columns.extend(self._get_columns_for_catalog(relation))
columns.extend(self._get_columns_for_catalog(relation)) # type: ignore[arg-type]
return agate.Table.from_object(columns,
column_types=DEFAULT_TYPE_TESTER)

def check_schema_exists(self, database, schema):
results = self.execute_macro(
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
)
results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})

exists = True if schema in [row[0] for row in results] else False
return exists
Expand All @@ -199,9 +190,7 @@ def update_column_sql(
clause += f" where {where_clause}"
return clause

def timestamp_add_sql(
self, add_to: str, number: int = 1, interval: str = "hour"
) -> str:
def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
# for backwards compatibility, we're compelled to set some sort of
# default. A lot of searching has lead me to believe that the
# '+ interval' syntax used in postgres/redshift is relatively common
Expand All @@ -225,9 +214,10 @@ def string_add_sql(

def get_rows_different_sql(
self,
relation_a: MariaDBRelation,
relation_b: MariaDBRelation,
relation_a: MariaDBRelation, # type: ignore[override]
relation_b: MariaDBRelation, # type: ignore[override]
column_names: Optional[List[str]] = None,
except_operator: str = "", # Required to match BaseRelation.get_rows_different_sql()
) -> str:
# This method only really exists for test reasons
names: List[str]
Expand All @@ -241,12 +231,10 @@ def get_rows_different_sql(
alias_b = "B"
columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
join_condition = " AND ".join(
[f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
)
join_condition = " AND ".join([f"{alias_a}.{name} = {alias_b}.{name}" for name in names])
first_column = names[0]

# There is no EXCEPT or MINUS operator, so we need to simulate it
# MariaDB doesn't have an EXCEPT or MINUS operator, so we need to simulate it
COLUMNS_EQUAL_SQL = """
SELECT
row_count_diff.difference as row_count_difference,
Expand Down
6 changes: 2 additions & 4 deletions dbt/adapters/mariadb/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ class MariaDBIncludePolicy(Policy):

@dataclass(frozen=True, eq=False, repr=False)
class MariaDBRelation(BaseRelation):
quote_policy: MariaDBQuotePolicy = field(
default_factory=lambda: MariaDBQuotePolicy())
include_policy: MariaDBIncludePolicy = field(
default_factory=lambda: MariaDBIncludePolicy())
quote_policy: MariaDBQuotePolicy = field(default_factory=lambda: MariaDBQuotePolicy())
include_policy: MariaDBIncludePolicy = field(default_factory=lambda: MariaDBIncludePolicy())
quote_character: str = "`"

def __post_init__(self):
Expand Down
4 changes: 3 additions & 1 deletion dbt/adapters/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@


Plugin = AdapterPlugin(
adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql.PACKAGE_PATH
adapter=MySQLAdapter, # type: ignore[arg-type]
credentials=MySQLCredentials,
include_path=mysql.PACKAGE_PATH,
)
Loading

0 comments on commit dcf64e4

Please sign in to comment.