@@ -132,16 +132,14 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
132
132
name , schema = schema , columns = True , placeholder = "%s"
133
133
)
134
134
135
- with self .begin () as cursor :
135
+ con = self .con
136
+ with con .cursor () as cursor , con .transaction ():
136
137
cursor .execute (create_stmt_sql )
137
138
cursor .executemany (sql , data )
138
139
139
140
@contextlib .contextmanager
140
- def begin (self , * , name : str = "" , withhold : bool = False ):
141
- with (
142
- (con := self .con ).transaction (),
143
- con .cursor (name = name , withhold = withhold ) as cursor ,
144
- ):
141
+ def begin (self ):
142
+ with (con := self .con ).cursor () as cursor , con .transaction ():
145
143
yield cursor
146
144
147
145
def _fetch_from_cursor (
@@ -278,9 +276,11 @@ def _post_connect(self) -> None:
278
276
import psycopg .types
279
277
import psycopg .types .hstore
280
278
279
+ con = self .con
280
+
281
281
try :
282
282
# try to load hstore
283
- with self . begin () as cursor :
283
+ with con . cursor () as cursor , con . transaction () :
284
284
cursor .execute ("CREATE EXTENSION IF NOT EXISTS hstore" )
285
285
psycopg .types .hstore .register_hstore (
286
286
psycopg .types .TypeInfo .fetch (self .con , "hstore" ), self .con
@@ -290,7 +290,7 @@ def _post_connect(self) -> None:
290
290
except TypeError :
291
291
pass
292
292
293
- with self . begin () as cursor :
293
+ with con . cursor () as cursor , con . transaction () :
294
294
cursor .execute ("SET TIMEZONE = UTC" )
295
295
296
296
@property
@@ -300,9 +300,11 @@ def _session_temp_db(self) -> str | None:
300
300
# Before that temp table is created, this will return `None`
301
301
# After a temp table is created, it will return `pg_temp_N` where N is
302
302
# some integer
303
- res = self .raw_sql (
304
- "select nspname from pg_namespace where oid = pg_my_temp_schema()"
305
- ).fetchone ()
303
+ con = self .con
304
+ with con .cursor () as cursor , con .transaction ():
305
+ res = cursor .execute (
306
+ "SELECT nspname FROM pg_namespace WHERE oid = pg_my_temp_schema()"
307
+ ).fetchone ()
306
308
if res is not None :
307
309
return res [0 ]
308
310
return res
@@ -358,8 +360,9 @@ def list_tables(
358
360
.sql (self .dialect )
359
361
)
360
362
361
- with self ._safe_raw_sql (sql ) as cur :
362
- out = cur .fetchall ()
363
+ con = self .con
364
+ with con .cursor () as cursor , con .transaction ():
365
+ out = cursor .execute (sql ).fetchall ()
363
366
364
367
# Include temporary tables only if no database has been explicitly specified
365
368
# to avoid temp tables showing up in all calls to `list_tables`
@@ -381,8 +384,9 @@ def _fetch_temp_tables(self):
381
384
.sql (self .dialect )
382
385
)
383
386
384
- with self ._safe_raw_sql (sql ) as cur :
385
- out = cur .fetchall ()
387
+ con = self .con
388
+ with con .cursor () as cursor , con .transaction ():
389
+ out = cursor .execute (sql ).fetchall ()
386
390
387
391
return out
388
392
@@ -392,33 +396,42 @@ def list_catalogs(self, *, like: str | None = None) -> list[str]:
392
396
sg .select (C .datname )
393
397
.from_ (sg .table ("pg_database" , db = "pg_catalog" ))
394
398
.where (sg .not_ (C .datistemplate ))
399
+ .sql (self .dialect )
395
400
)
396
- with self ._safe_raw_sql (cats ) as cur :
397
- catalogs = list (map (itemgetter (0 ), cur ))
401
+ con = self .con
402
+ with con .cursor () as cursor , con .transaction ():
403
+ catalogs = list (map (itemgetter (0 ), cursor .execute (cats )))
398
404
399
405
return self ._filter_with_like (catalogs , like )
400
406
401
407
def list_databases (
402
408
self , * , like : str | None = None , catalog : str | None = None
403
409
) -> list [str ]:
404
- dbs = sg .select (C .schema_name ).from_ (
405
- sg .table ("schemata" , db = "information_schema" )
410
+ dbs = (
411
+ sg .select (C .schema_name )
412
+ .from_ (sg .table ("schemata" , db = "information_schema" ))
413
+ .sql (self .dialect )
406
414
)
407
- with self ._safe_raw_sql (dbs ) as cur :
408
- databases = list (map (itemgetter (0 ), cur ))
415
+ con = self .con
416
+ with con .cursor () as cursor , con .transaction ():
417
+ databases = list (map (itemgetter (0 ), cursor .execute (dbs )))
409
418
410
419
return self ._filter_with_like (databases , like )
411
420
412
421
@property
413
422
def current_catalog (self ) -> str :
414
- with self ._safe_raw_sql (sg .select (sg .func ("current_database" ))) as cur :
415
- (db ,) = cur .fetchone ()
423
+ sql = sg .select (sg .func ("current_database" )).sql (self .dialect )
424
+ con = self .con
425
+ with con .cursor () as cursor , con .transaction ():
426
+ (db ,) = cursor .execute (sql ).fetchone ()
416
427
return db
417
428
418
429
@property
419
430
def current_database (self ) -> str :
420
- with self ._safe_raw_sql (sg .select (sg .func ("current_schema" ))) as cur :
421
- (schema ,) = cur .fetchone ()
431
+ sql = sg .select (sg .func ("current_schema" )).sql (self .dialect )
432
+ con = self .con
433
+ with con .cursor () as cursor , con .transaction ():
434
+ (schema ,) = cursor .execute (sql ).fetchone ()
422
435
return schema
423
436
424
437
def function (self , name : str , * , database : str | None = None ) -> Callable :
@@ -445,14 +458,16 @@ def function(self, name: str, *, database: str | None = None) -> Callable:
445
458
join_type = "LEFT" ,
446
459
)
447
460
.where (* predicates )
461
+ .sql (self .dialect )
448
462
)
449
463
450
464
def split_name_type (arg : str ) -> tuple [str , dt .DataType ]:
451
465
name , typ = arg .split (" " , 1 )
452
466
return name , self .compiler .type_mapper .from_string (typ )
453
467
454
- with self ._safe_raw_sql (query ) as cur :
455
- rows = cur .fetchall ()
468
+ con = self .con
469
+ with con .cursor () as cursor , con .transaction ():
470
+ rows = cursor .execute (query ).fetchall ()
456
471
457
472
if not rows :
458
473
name = f"{ database } .{ name } " if database else name
@@ -528,12 +543,14 @@ def get_schema(
528
543
c .relname .eq (sge .convert (name )),
529
544
)
530
545
.order_by (a .attnum )
546
+ .sql (self .dialect )
531
547
)
532
548
533
549
type_mapper = self .compiler .type_mapper
534
550
535
- with self ._safe_raw_sql (type_info ) as cur :
536
- rows = cur .fetchall ()
551
+ con = self .con
552
+ with con .cursor () as cursor , con .transaction ():
553
+ rows = cursor .execute (type_info ).fetchall ()
537
554
538
555
if not rows :
539
556
raise com .TableNotFound (name )
@@ -553,18 +570,21 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
553
570
this = sg .table (name ),
554
571
expression = sg .parse_one (query , read = self .dialect ),
555
572
properties = sge .Properties (expressions = [sge .TemporaryProperty ()]),
556
- )
573
+ ).sql (self .dialect )
574
+
557
575
drop_stmt = sge .Drop (kind = "VIEW" , this = sg .table (name ), exists = True ).sql (
558
576
self .dialect
559
577
)
560
578
561
- with self ._safe_raw_sql (create_stmt ):
562
- pass
579
+ con = self .con
580
+ with con .cursor () as cursor , con .transaction ():
581
+ cursor .execute (create_stmt )
582
+
563
583
try :
564
584
return self .get_schema (name )
565
585
finally :
566
- with self . _safe_raw_sql ( drop_stmt ):
567
- pass
586
+ with con . cursor () as cursor , con . transaction ( ):
587
+ cursor . execute ( drop_stmt )
568
588
569
589
def create_database (
570
590
self , name : str , / , * , catalog : str | None = None , force : bool = False
@@ -575,9 +595,10 @@ def create_database(
575
595
)
576
596
sql = sge .Create (
577
597
kind = "SCHEMA" , this = sg .table (name , catalog = catalog ), exists = force
578
- )
579
- with self ._safe_raw_sql (sql ):
580
- pass
598
+ ).sql (self .dialect )
599
+ con = self .con
600
+ with con .cursor () as cursor , con .transaction ():
601
+ cursor .execute (sql )
581
602
582
603
def drop_database (
583
604
self ,
@@ -598,9 +619,11 @@ def drop_database(
598
619
this = sg .table (name , catalog = catalog ),
599
620
exists = force ,
600
621
cascade = cascade ,
601
- )
602
- with self ._safe_raw_sql (sql ):
603
- pass
622
+ ).sql (self .dialect )
623
+
624
+ con = self .con
625
+ with con .cursor () as cursor , con .transaction ():
626
+ cursor .execute (sql )
604
627
605
628
def create_table (
606
629
self ,
@@ -679,19 +702,24 @@ def create_table(
679
702
kind = "TABLE" ,
680
703
this = target ,
681
704
properties = sge .Properties (expressions = properties ),
682
- )
705
+ ). sql ( dialect )
683
706
684
707
this = sg .table (name , catalog = database , quoted = quoted )
685
708
this_no_catalog = sg .table (name , quoted = quoted )
686
709
687
- with self ._safe_raw_sql (create_stmt ) as cur :
710
+ con = self .con
711
+ with con .cursor () as cursor , con .transaction ():
712
+ cursor .execute (create_stmt )
713
+
688
714
if query is not None :
689
715
insert_stmt = sge .Insert (this = table_expr , expression = query ).sql (dialect )
690
- cur .execute (insert_stmt )
716
+ cursor .execute (insert_stmt )
691
717
692
718
if overwrite :
693
- cur .execute (sge .Drop (kind = "TABLE" , this = this , exists = True ).sql (dialect ))
694
- cur .execute (
719
+ cursor .execute (
720
+ sge .Drop (kind = "TABLE" , this = this , exists = True ).sql (dialect )
721
+ )
722
+ cursor .execute (
695
723
f"ALTER TABLE IF EXISTS { table_expr .sql (dialect )} RENAME TO { this_no_catalog .sql (dialect )} "
696
724
)
697
725
@@ -715,16 +743,17 @@ def drop_table(
715
743
kind = "TABLE" ,
716
744
this = sg .table (name , db = database , quoted = self .compiler .quoted ),
717
745
exists = force ,
718
- )
719
- with self ._safe_raw_sql (drop_stmt ):
720
- pass
746
+ ).sql (self .dialect )
747
+ con = self .con
748
+ with con .cursor () as cursor , con .transaction ():
749
+ cursor .execute (drop_stmt )
721
750
722
751
@contextlib .contextmanager
723
752
def _safe_raw_sql (self , query : str | sg .Expression , ** kwargs : Any ):
724
753
with contextlib .suppress (AttributeError ):
725
754
query = query .sql (dialect = self .dialect )
726
755
727
- with self .begin () as cursor :
756
+ with ( con := self .con ). cursor () as cursor , con . transaction () :
728
757
yield cursor .execute (query , ** kwargs )
729
758
730
759
def raw_sql (self , query : str | sg .Expression , ** kwargs : Any ) -> Any :
@@ -757,11 +786,13 @@ def to_pyarrow_batches(
757
786
import pyarrow as pa
758
787
759
788
def _batches (self : Self , * , schema : pa .Schema , query : str ):
789
+ con = self .con
760
790
columns = schema .names
761
791
# server-side cursors need to be uniquely named
762
- with self .begin (
763
- name = util .gen_name ("postgres_cursor" ), withhold = True
764
- ) as cursor :
792
+ with (
793
+ con .cursor (name = util .gen_name ("postgres_cursor" )) as cursor ,
794
+ con .transaction (),
795
+ ):
765
796
cursor .execute (query )
766
797
while batch := cursor .fetchmany (chunk_size ):
767
798
yield pa .RecordBatch .from_pandas (
0 commit comments