Skip to content

Commit b23d20c

Browse files
authored
Merge branch 'master' into master
2 parents 20b2283 + a4616f7 commit b23d20c

File tree

12 files changed

+372
-77
lines changed

12 files changed

+372
-77
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
dist
1414
build
1515
venv*
16+
docker-compose.yaml

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,19 @@ repos:
1616
- id: trailing-whitespace
1717

1818
- repo: https://github.com/astral-sh/ruff-pre-commit
19-
rev: v0.11.11
19+
rev: v0.12.5
2020
hooks:
2121
- id: ruff
2222
args: [--fix, --show-fixes]
2323
- id: ruff-format
2424

2525
- repo: https://github.com/pre-commit/mirrors-mypy
26-
rev: v1.15.0
26+
rev: v1.17.0
2727
hooks:
2828
- id: mypy
2929
additional_dependencies:
3030
- pytest
31-
- "sqlalchemy[mypy] < 2.0"
31+
- "SQLAlchemy >= 2.0.29"
3232

3333
- repo: https://github.com/pre-commit/pygrep-hooks
3434
rev: v1.10.0

CHANGES.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@ Version history
77
the array elements
88
- Added support for specifying engine arguments via ``--engine-arg``
99
(PR by @LajosCseppento)
10+
- Fixed incorrect package name used in ``importlib.metadata.version`` for
11+
``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz)
12+
- Prevent double pluralization (PR by @dkratzert)
13+
- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon)
14+
- Temporarily restrict SQLAlchemy version to 2.0.41 (PR by @sheinbergon)
15+
- Fixes ``add_import`` behavior when adding imports from sqlalchemy and overall better
16+
alignment of import behavior(s) across generators
1017

1118
**3.0.0**
1219

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ name = "sqlacodegen"
1010
description = "Automatic model code generator for SQLAlchemy"
1111
readme = "README.rst"
1212
authors = [{name = "Alex Grönholm", email = "[email protected]"}]
13+
maintainers = [{name = "Idan Sheinberg", email = "[email protected]"}]
1314
keywords = ["sqlalchemy"]
14-
license = {text = "MIT"}
15+
license = "MIT"
1516
classifiers = [
1617
"Development Status :: 5 - Production/Stable",
1718
"Intended Audience :: Developers",
18-
"License :: OSI Approved :: MIT License",
1919
"Environment :: Console",
2020
"Topic :: Database",
2121
"Topic :: Software Development :: Code Generators",
@@ -29,9 +29,10 @@ classifiers = [
2929
]
3030
requires-python = ">=3.9"
3131
dependencies = [
32-
"SQLAlchemy >= 2.0.23",
32+
"SQLAlchemy >= 2.0.29,<2.0.42",
3333
"inflect >= 4.0.0",
3434
"importlib_metadata; python_version < '3.10'",
35+
"stdlib-list; python_version < '3.10'"
3536
]
3637
dynamic = ["version"]
3738

@@ -80,6 +81,7 @@ extend-select = [
8081

8182
[tool.mypy]
8283
strict = true
84+
disable_error_code = "no-untyped-call"
8385

8486
[tool.pytest.ini_options]
8587
addopts = "-rsfE --tb=short"

src/sqlacodegen/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def main() -> None:
105105
return
106106

107107
if citext:
108-
print(f"Using sqlalchemy-citext {version('citext')}")
108+
print(f"Using sqlalchemy-citext {version('sqlalchemy-citext')}")
109109

110110
if geoalchemy2:
111111
print(f"Using geoalchemy2 {version('geoalchemy2')}")

src/sqlacodegen/generators.py

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from keyword import iskeyword
1414
from pprint import pformat
1515
from textwrap import indent
16-
from typing import Any, ClassVar
16+
from typing import Any, ClassVar, Literal, cast
1717

1818
import inflect
1919
import sqlalchemy
@@ -38,7 +38,7 @@
3838
TypeDecorator,
3939
UniqueConstraint,
4040
)
41-
from sqlalchemy.dialects.postgresql import JSONB
41+
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
4242
from sqlalchemy.engine import Connection, Engine
4343
from sqlalchemy.exc import CompileError
4444
from sqlalchemy.sql.elements import TextClause
@@ -59,6 +59,7 @@
5959
get_common_fk_constraints,
6060
get_compiled_expression,
6161
get_constraint_sort_key,
62+
get_stdlib_module_names,
6263
qualified_table_name,
6364
render_callable,
6465
uses_default_name,
@@ -119,9 +120,7 @@ def generate(self) -> str:
119120
@dataclass(eq=False)
120121
class TablesGenerator(CodeGenerator):
121122
valid_options: ClassVar[set[str]] = {"noindexes", "noconstraints", "nocomments"}
122-
builtin_module_names: ClassVar[set[str]] = set(sys.builtin_module_names) | {
123-
"dataclasses"
124-
}
123+
stdlib_module_names: ClassVar[set[str]] = get_stdlib_module_names()
125124

126125
def __init__(
127126
self,
@@ -222,12 +221,14 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
222221

223222
if isinstance(column.type, ARRAY):
224223
self.add_import(column.type.item_type.__class__)
225-
elif isinstance(column.type, JSONB):
224+
elif isinstance(column.type, (JSONB, JSON)):
226225
if (
227226
not isinstance(column.type.astext_type, Text)
228227
or column.type.astext_type.length is not None
229228
):
230229
self.add_import(column.type.astext_type)
230+
elif isinstance(column.type, DOMAIN):
231+
self.add_import(column.type.data_type.__class__)
231232

232233
if column.default:
233234
self.add_import(column.default)
@@ -274,7 +275,7 @@ def add_import(self, obj: Any) -> None:
274275

275276
if type_.__name__ in dialect_pkg.__all__:
276277
pkgname = dialect_pkgname
277-
elif type_.__name__ in dir(sqlalchemy):
278+
elif type_ is getattr(sqlalchemy, type_.__name__, None):
278279
pkgname = "sqlalchemy"
279280
else:
280281
pkgname = type_.__module__
@@ -298,21 +299,26 @@ def group_imports(self) -> list[list[str]]:
298299
stdlib_imports: list[str] = []
299300
thirdparty_imports: list[str] = []
300301

301-
for package in sorted(self.imports):
302-
imports = ", ".join(sorted(self.imports[package]))
302+
def get_collection(package: str) -> list[str]:
303303
collection = thirdparty_imports
304304
if package == "__future__":
305305
collection = future_imports
306-
elif package in self.builtin_module_names:
306+
elif package in self.stdlib_module_names:
307307
collection = stdlib_imports
308308
elif package in sys.modules:
309309
if "site-packages" not in (sys.modules[package].__file__ or ""):
310310
collection = stdlib_imports
311+
return collection
312+
313+
for package in sorted(self.imports):
314+
imports = ", ".join(sorted(self.imports[package]))
311315

316+
collection = get_collection(package)
312317
collection.append(f"from {package} import {imports}")
313318

314319
for module in sorted(self.module_imports):
315-
thirdparty_imports.append(f"import {module}")
320+
collection = get_collection(module)
321+
collection.append(f"import {module}")
316322

317323
return [
318324
group
@@ -375,7 +381,7 @@ def render_table(self, table: Table) -> str:
375381

376382
args.append(self.render_constraint(constraint))
377383

378-
for index in sorted(table.indexes, key=lambda i: i.name):
384+
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
379385
# One-column indexes should be rendered as index=True on columns
380386
if len(index.columns) > 1 or not uses_default_name(index):
381387
args.append(self.render_index(index))
@@ -467,7 +473,7 @@ def render_column(
467473

468474
if isinstance(column.server_default, DefaultClause):
469475
kwargs["server_default"] = render_callable(
470-
"text", repr(column.server_default.arg.text)
476+
"text", repr(cast(TextClause, column.server_default.arg).text)
471477
)
472478
elif isinstance(column.server_default, Computed):
473479
expression = str(column.server_default.sqltext)
@@ -497,7 +503,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
497503
else:
498504
return render_callable("mapped_column", *args, kwargs=kwargs)
499505

500-
def render_column_type(self, coltype: object) -> str:
506+
def render_column_type(self, coltype: TypeEngine[Any]) -> str:
501507
args = []
502508
kwargs: dict[str, Any] = {}
503509
sig = inspect.signature(coltype.__class__.__init__)
@@ -513,13 +519,30 @@ def render_column_type(self, coltype: object) -> str:
513519
continue
514520

515521
value = getattr(coltype, param.name, missing)
522+
523+
if isinstance(value, (JSONB, JSON)):
524+
# Remove astext_type if it's the default
525+
if (
526+
isinstance(value.astext_type, Text)
527+
and value.astext_type.length is None
528+
):
529+
value.astext_type = None # type: ignore[assignment]
530+
else:
531+
self.add_import(Text)
532+
516533
default = defaults.get(param.name, missing)
534+
if isinstance(value, TextClause):
535+
self.add_literal_import("sqlalchemy", "text")
536+
rendered_value = render_callable("text", repr(value.text))
537+
else:
538+
rendered_value = repr(value)
539+
517540
if value is missing or value == default:
518541
use_kwargs = True
519542
elif use_kwargs:
520-
kwargs[param.name] = repr(value)
543+
kwargs[param.name] = rendered_value
521544
else:
522-
args.append(repr(value))
545+
args.append(rendered_value)
523546

524547
vararg = next(
525548
(
@@ -539,7 +562,7 @@ def render_column_type(self, coltype: object) -> str:
539562
if (value := getattr(coltype, colname)) is not None:
540563
kwargs[colname] = repr(value)
541564

542-
if isinstance(coltype, JSONB):
565+
if isinstance(coltype, (JSONB, JSON)):
543566
# Remove astext_type if it's the default
544567
if (
545568
isinstance(coltype.astext_type, Text)
@@ -1073,13 +1096,13 @@ def generate_relationship_name(
10731096
preferred_name = column_names[0][:-3]
10741097

10751098
if "use_inflect" in self.options:
1099+
inflected_name: str | Literal[False]
10761100
if relationship.type in (
10771101
RelationshipType.ONE_TO_MANY,
10781102
RelationshipType.MANY_TO_MANY,
10791103
):
1080-
inflected_name = self.inflect_engine.plural_noun(preferred_name)
1081-
if inflected_name:
1082-
preferred_name = inflected_name
1104+
if not self.inflect_engine.singular_noun(preferred_name):
1105+
preferred_name = self.inflect_engine.plural_noun(preferred_name)
10831106
else:
10841107
inflected_name = self.inflect_engine.singular_noun(preferred_name)
10851108
if inflected_name:
@@ -1168,7 +1191,7 @@ def render_table_args(self, table: Table) -> str:
11681191
args.append(self.render_constraint(constraint))
11691192

11701193
# Render indexes
1171-
for index in sorted(table.indexes, key=lambda i: i.name):
1194+
for index in sorted(table.indexes, key=lambda i: cast(str, i.name)):
11721195
if len(index.columns) > 1 or not uses_default_name(index):
11731196
args.append(self.render_index(index))
11741197

@@ -1194,10 +1217,7 @@ def render_table_args(self, table: Table) -> str:
11941217
else:
11951218
return ""
11961219

1197-
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
1198-
column = column_attr.column
1199-
rendered_column = self.render_column(column, column_attr.name != column.name)
1200-
1220+
def render_column_python_type(self, column: Column[Any]) -> str:
12011221
def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
12021222
column_type = column.type
12031223
pre: list[str] = []
@@ -1217,7 +1237,11 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
12171237
return "".join(pre), column_type, "]" * post_size
12181238

12191239
def render_python_type(column_type: TypeEngine[Any]) -> str:
1220-
python_type = column_type.python_type
1240+
if isinstance(column_type, DOMAIN):
1241+
python_type = column_type.data_type.python_type
1242+
else:
1243+
python_type = column_type.python_type
1244+
12211245
python_type_name = python_type.__name__
12221246
python_type_module = python_type.__module__
12231247
if python_type_module == "builtins":
@@ -1232,7 +1256,14 @@ def render_python_type(column_type: TypeEngine[Any]) -> str:
12321256

12331257
pre, col_type, post = get_type_qualifiers()
12341258
column_python_type = f"{pre}{render_python_type(col_type)}{post}"
1235-
return f"{column_attr.name}: Mapped[{column_python_type}] = {rendered_column}"
1259+
return column_python_type
1260+
1261+
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
1262+
column = column_attr.column
1263+
rendered_column = self.render_column(column, column_attr.name != column.name)
1264+
rendered_column_python_type = self.render_column_python_type(column)
1265+
1266+
return f"{column_attr.name}: Mapped[{rendered_column_python_type}] = {rendered_column}"
12361267

12371268
def render_relationship(self, relationship: RelationshipAttribute) -> str:
12381269
def render_column_attrs(column_attrs: list[ColumnAttribute]) -> str:
@@ -1422,15 +1453,6 @@ def collect_imports_for_model(self, model: Model) -> None:
14221453
if model.relationships:
14231454
self.add_literal_import("sqlmodel", "Relationship")
14241455

1425-
def collect_imports_for_column(self, column: Column[Any]) -> None:
1426-
super().collect_imports_for_column(column)
1427-
try:
1428-
python_type = column.type.python_type
1429-
except NotImplementedError:
1430-
self.add_literal_import("typing", "Any")
1431-
else:
1432-
self.add_import(python_type)
1433-
14341456
def render_module_variables(self, models: list[Model]) -> str:
14351457
declarations: list[str] = []
14361458
if any(not isinstance(model, ModelClass) for model in models):
@@ -1463,25 +1485,17 @@ def render_class_variables(self, model: ModelClass) -> str:
14631485

14641486
def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
14651487
column = column_attr.column
1466-
try:
1467-
python_type = column.type.python_type
1468-
except NotImplementedError:
1469-
python_type_name = "Any"
1470-
else:
1471-
python_type_name = python_type.__name__
1488+
rendered_column = self.render_column(column, True)
1489+
rendered_column_python_type = self.render_column_python_type(column)
14721490

14731491
kwargs: dict[str, Any] = {}
1474-
if (
1475-
column.autoincrement and column.name in column.table.primary_key
1476-
) or column.nullable:
1477-
self.add_literal_import("typing", "Optional")
1492+
if column.nullable:
14781493
kwargs["default"] = None
1479-
python_type_name = f"Optional[{python_type_name}]"
1480-
1481-
rendered_column = self.render_column(column, True)
14821494
kwargs["sa_column"] = f"{rendered_column}"
1495+
14831496
rendered_field = render_callable("Field", kwargs=kwargs)
1484-
return f"{column_attr.name}: {python_type_name} = {rendered_field}"
1497+
1498+
return f"{column_attr.name}: {rendered_column_python_type} = {rendered_field}"
14851499

14861500
def render_relationship(self, relationship: RelationshipAttribute) -> str:
14871501
rendered = super().render_relationship(relationship).partition(" = ")[2]

0 commit comments

Comments
 (0)