1313from  keyword  import  iskeyword 
1414from  pprint  import  pformat 
1515from  textwrap  import  indent 
16- from  typing  import  Any , ClassVar 
16+ from  typing  import  Any , ClassVar ,  Literal ,  cast 
1717
1818import  inflect 
1919import  sqlalchemy 
3838    TypeDecorator ,
3939    UniqueConstraint ,
4040)
41- from  sqlalchemy .dialects .postgresql  import  JSONB 
41+ from  sqlalchemy .dialects .postgresql  import  DOMAIN ,  JSON ,  JSONB 
4242from  sqlalchemy .engine  import  Connection , Engine 
4343from  sqlalchemy .exc  import  CompileError 
4444from  sqlalchemy .sql .elements  import  TextClause 
5959    get_common_fk_constraints ,
6060    get_compiled_expression ,
6161    get_constraint_sort_key ,
62+     get_stdlib_module_names ,
6263    qualified_table_name ,
6364    render_callable ,
6465    uses_default_name ,
@@ -119,9 +120,7 @@ def generate(self) -> str:
119120@dataclass (eq = False ) 
120121class  TablesGenerator (CodeGenerator ):
121122    valid_options : ClassVar [set [str ]] =  {"noindexes" , "noconstraints" , "nocomments" }
122-     builtin_module_names : ClassVar [set [str ]] =  set (sys .builtin_module_names ) |  {
123-         "dataclasses" 
124-     }
123+     stdlib_module_names : ClassVar [set [str ]] =  get_stdlib_module_names ()
125124
126125    def  __init__ (
127126        self ,
@@ -222,12 +221,14 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
222221
223222        if  isinstance (column .type , ARRAY ):
224223            self .add_import (column .type .item_type .__class__ )
225-         elif  isinstance (column .type , JSONB ):
224+         elif  isinstance (column .type , ( JSONB ,  JSON ) ):
226225            if  (
227226                not  isinstance (column .type .astext_type , Text )
228227                or  column .type .astext_type .length  is  not None 
229228            ):
230229                self .add_import (column .type .astext_type )
230+         elif  isinstance (column .type , DOMAIN ):
231+             self .add_import (column .type .data_type .__class__ )
231232
232233        if  column .default :
233234            self .add_import (column .default )
@@ -274,7 +275,7 @@ def add_import(self, obj: Any) -> None:
274275
275276            if  type_ .__name__  in  dialect_pkg .__all__ :
276277                pkgname  =  dialect_pkgname 
277-         elif  type_ . __name__   in   dir (sqlalchemy ):
278+         elif  type_   is   getattr (sqlalchemy ,  type_ . __name__ ,  None ):
278279            pkgname  =  "sqlalchemy" 
279280        else :
280281            pkgname  =  type_ .__module__ 
@@ -298,21 +299,26 @@ def group_imports(self) -> list[list[str]]:
298299        stdlib_imports : list [str ] =  []
299300        thirdparty_imports : list [str ] =  []
300301
301-         for  package  in  sorted (self .imports ):
302-             imports  =  ", " .join (sorted (self .imports [package ]))
302+         def  get_collection (package : str ) ->  list [str ]:
303303            collection  =  thirdparty_imports 
304304            if  package  ==  "__future__" :
305305                collection  =  future_imports 
306-             elif  package  in  self .builtin_module_names :
306+             elif  package  in  self .stdlib_module_names :
307307                collection  =  stdlib_imports 
308308            elif  package  in  sys .modules :
309309                if  "site-packages"  not  in sys .modules [package ].__file__  or  "" ):
310310                    collection  =  stdlib_imports 
311+             return  collection 
312+ 
313+         for  package  in  sorted (self .imports ):
314+             imports  =  ", " .join (sorted (self .imports [package ]))
311315
316+             collection  =  get_collection (package )
312317            collection .append (f"from { package } { imports }  )
313318
314319        for  module  in  sorted (self .module_imports ):
315-             thirdparty_imports .append (f"import { module }  )
320+             collection  =  get_collection (module )
321+             collection .append (f"import { module }  )
316322
317323        return  [
318324            group 
@@ -375,7 +381,7 @@ def render_table(self, table: Table) -> str:
375381
376382            args .append (self .render_constraint (constraint ))
377383
378-         for  index  in  sorted (table .indexes , key = lambda  i : i .name ):
384+         for  index  in  sorted (table .indexes , key = lambda  i : cast ( str ,  i .name ) ):
379385            # One-column indexes should be rendered as index=True on columns 
380386            if  len (index .columns ) >  1  or  not  uses_default_name (index ):
381387                args .append (self .render_index (index ))
@@ -467,7 +473,7 @@ def render_column(
467473
468474        if  isinstance (column .server_default , DefaultClause ):
469475            kwargs ["server_default" ] =  render_callable (
470-                 "text" , repr (column .server_default .arg .text )
476+                 "text" , repr (cast ( TextClause ,  column .server_default .arg ) .text )
471477            )
472478        elif  isinstance (column .server_default , Computed ):
473479            expression  =  str (column .server_default .sqltext )
@@ -497,7 +503,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
497503        else :
498504            return  render_callable ("mapped_column" , * args , kwargs = kwargs )
499505
500-     def  render_column_type (self , coltype : object ) ->  str :
506+     def  render_column_type (self , coltype : TypeEngine [ Any ] ) ->  str :
501507        args  =  []
502508        kwargs : dict [str , Any ] =  {}
503509        sig  =  inspect .signature (coltype .__class__ .__init__ )
@@ -513,13 +519,30 @@ def render_column_type(self, coltype: object) -> str:
513519                continue 
514520
515521            value  =  getattr (coltype , param .name , missing )
522+ 
523+             if  isinstance (value , (JSONB , JSON )):
524+                 # Remove astext_type if it's the default 
525+                 if  (
526+                     isinstance (value .astext_type , Text )
527+                     and  value .astext_type .length  is  None 
528+                 ):
529+                     value .astext_type  =  None   # type: ignore[assignment] 
530+                 else :
531+                     self .add_import (Text )
532+ 
516533            default  =  defaults .get (param .name , missing )
534+             if  isinstance (value , TextClause ):
535+                 self .add_literal_import ("sqlalchemy" , "text" )
536+                 rendered_value  =  render_callable ("text" , repr (value .text ))
537+             else :
538+                 rendered_value  =  repr (value )
539+ 
517540            if  value  is  missing  or  value  ==  default :
518541                use_kwargs  =  True 
519542            elif  use_kwargs :
520-                 kwargs [param .name ] =  repr ( value ) 
543+                 kwargs [param .name ] =  rendered_value 
521544            else :
522-                 args .append (repr ( value ) )
545+                 args .append (rendered_value )
523546
524547        vararg  =  next (
525548            (
@@ -539,7 +562,7 @@ def render_column_type(self, coltype: object) -> str:
539562                if  (value  :=  getattr (coltype , colname )) is  not None :
540563                    kwargs [colname ] =  repr (value )
541564
542-         if  isinstance (coltype , JSONB ):
565+         if  isinstance (coltype , ( JSONB ,  JSON ) ):
543566            # Remove astext_type if it's the default 
544567            if  (
545568                isinstance (coltype .astext_type , Text )
@@ -1073,13 +1096,13 @@ def generate_relationship_name(
10731096                        preferred_name  =  column_names [0 ][:- 3 ]
10741097
10751098            if  "use_inflect"  in  self .options :
1099+                 inflected_name : str  |  Literal [False ]
10761100                if  relationship .type  in  (
10771101                    RelationshipType .ONE_TO_MANY ,
10781102                    RelationshipType .MANY_TO_MANY ,
10791103                ):
1080-                     inflected_name  =  self .inflect_engine .plural_noun (preferred_name )
1081-                     if  inflected_name :
1082-                         preferred_name  =  inflected_name 
1104+                     if  not  self .inflect_engine .singular_noun (preferred_name ):
1105+                         preferred_name  =  self .inflect_engine .plural_noun (preferred_name )
10831106                else :
10841107                    inflected_name  =  self .inflect_engine .singular_noun (preferred_name )
10851108                    if  inflected_name :
@@ -1168,7 +1191,7 @@ def render_table_args(self, table: Table) -> str:
11681191            args .append (self .render_constraint (constraint ))
11691192
11701193        # Render indexes 
1171-         for  index  in  sorted (table .indexes , key = lambda  i : i .name ):
1194+         for  index  in  sorted (table .indexes , key = lambda  i : cast ( str ,  i .name ) ):
11721195            if  len (index .columns ) >  1  or  not  uses_default_name (index ):
11731196                args .append (self .render_index (index ))
11741197
@@ -1194,10 +1217,7 @@ def render_table_args(self, table: Table) -> str:
11941217        else :
11951218            return  "" 
11961219
1197-     def  render_column_attribute (self , column_attr : ColumnAttribute ) ->  str :
1198-         column  =  column_attr .column 
1199-         rendered_column  =  self .render_column (column , column_attr .name  !=  column .name )
1200- 
1220+     def  render_column_python_type (self , column : Column [Any ]) ->  str :
12011221        def  get_type_qualifiers () ->  tuple [str , TypeEngine [Any ], str ]:
12021222            column_type  =  column .type 
12031223            pre : list [str ] =  []
@@ -1217,7 +1237,11 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
12171237            return  "" .join (pre ), column_type , "]"  *  post_size 
12181238
12191239        def  render_python_type (column_type : TypeEngine [Any ]) ->  str :
1220-             python_type  =  column_type .python_type 
1240+             if  isinstance (column_type , DOMAIN ):
1241+                 python_type  =  column_type .data_type .python_type 
1242+             else :
1243+                 python_type  =  column_type .python_type 
1244+ 
12211245            python_type_name  =  python_type .__name__ 
12221246            python_type_module  =  python_type .__module__ 
12231247            if  python_type_module  ==  "builtins" :
@@ -1232,7 +1256,14 @@ def render_python_type(column_type: TypeEngine[Any]) -> str:
12321256
12331257        pre , col_type , post  =  get_type_qualifiers ()
12341258        column_python_type  =  f"{ pre } { render_python_type (col_type )} { post }  
1235-         return  f"{ column_attr .name } { column_python_type } { rendered_column }  
1259+         return  column_python_type 
1260+ 
1261+     def  render_column_attribute (self , column_attr : ColumnAttribute ) ->  str :
1262+         column  =  column_attr .column 
1263+         rendered_column  =  self .render_column (column , column_attr .name  !=  column .name )
1264+         rendered_column_python_type  =  self .render_column_python_type (column )
1265+ 
1266+         return  f"{ column_attr .name } { rendered_column_python_type } { rendered_column }  
12361267
12371268    def  render_relationship (self , relationship : RelationshipAttribute ) ->  str :
12381269        def  render_column_attrs (column_attrs : list [ColumnAttribute ]) ->  str :
@@ -1422,15 +1453,6 @@ def collect_imports_for_model(self, model: Model) -> None:
14221453            if  model .relationships :
14231454                self .add_literal_import ("sqlmodel" , "Relationship" )
14241455
1425-     def  collect_imports_for_column (self , column : Column [Any ]) ->  None :
1426-         super ().collect_imports_for_column (column )
1427-         try :
1428-             python_type  =  column .type .python_type 
1429-         except  NotImplementedError :
1430-             self .add_literal_import ("typing" , "Any" )
1431-         else :
1432-             self .add_import (python_type )
1433- 
14341456    def  render_module_variables (self , models : list [Model ]) ->  str :
14351457        declarations : list [str ] =  []
14361458        if  any (not  isinstance (model , ModelClass ) for  model  in  models ):
@@ -1463,25 +1485,17 @@ def render_class_variables(self, model: ModelClass) -> str:
14631485
14641486    def  render_column_attribute (self , column_attr : ColumnAttribute ) ->  str :
14651487        column  =  column_attr .column 
1466-         try :
1467-             python_type  =  column .type .python_type 
1468-         except  NotImplementedError :
1469-             python_type_name  =  "Any" 
1470-         else :
1471-             python_type_name  =  python_type .__name__ 
1488+         rendered_column  =  self .render_column (column , True )
1489+         rendered_column_python_type  =  self .render_column_python_type (column )
14721490
14731491        kwargs : dict [str , Any ] =  {}
1474-         if  (
1475-             column .autoincrement  and  column .name  in  column .table .primary_key 
1476-         ) or  column .nullable :
1477-             self .add_literal_import ("typing" , "Optional" )
1492+         if  column .nullable :
14781493            kwargs ["default" ] =  None 
1479-             python_type_name  =  f"Optional[{ python_type_name }  
1480- 
1481-         rendered_column  =  self .render_column (column , True )
14821494        kwargs ["sa_column" ] =  f"{ rendered_column }  
1495+ 
14831496        rendered_field  =  render_callable ("Field" , kwargs = kwargs )
1484-         return  f"{ column_attr .name } { python_type_name } { rendered_field }  
1497+ 
1498+         return  f"{ column_attr .name } { rendered_column_python_type } { rendered_field }  
14851499
14861500    def  render_relationship (self , relationship : RelationshipAttribute ) ->  str :
14871501        rendered  =  super ().render_relationship (relationship ).partition (" = " )[2 ]
0 commit comments