Skip to content

Commit c3097a7

Browse files
authored
fix(postgres): clean up leaky cursor coming from raw_sql (ibis-project#11001)
## Description of changes

This PR fixes a cursor leak: a `raw_sql` call returned a cursor that was never closed. It also adds a SQL query (`docker/postgres/debug.sql`) that lists client backends whose state is `idle in transaction`, which should make idle-transaction problems easier to debug in the future.
1 parent a065ad3 commit c3097a7

File tree

4 files changed

+114
-68
lines changed

4 files changed

+114
-68
lines changed

docker/postgres/debug.sql

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
SELECT
2+
"pid",
3+
"application_name",
4+
"backend_start",
5+
"state",
6+
"wait_event",
7+
"wait_event_type",
8+
"query"
9+
FROM "pg_stat_activity"
10+
WHERE
11+
"backend_type" = 'client backend' AND "state" = 'idle in transaction'
12+
ORDER BY
13+
"backend_start"

ibis/backends/postgres/__init__.py

+83-52
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,14 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
132132
name, schema=schema, columns=True, placeholder="%s"
133133
)
134134

135-
with self.begin() as cursor:
135+
con = self.con
136+
with con.cursor() as cursor, con.transaction():
136137
cursor.execute(create_stmt_sql)
137138
cursor.executemany(sql, data)
138139

139140
@contextlib.contextmanager
140-
def begin(self, *, name: str = "", withhold: bool = False):
141-
with (
142-
(con := self.con).transaction(),
143-
con.cursor(name=name, withhold=withhold) as cursor,
144-
):
141+
def begin(self):
142+
with (con := self.con).cursor() as cursor, con.transaction():
145143
yield cursor
146144

147145
def _fetch_from_cursor(
@@ -278,9 +276,11 @@ def _post_connect(self) -> None:
278276
import psycopg.types
279277
import psycopg.types.hstore
280278

279+
con = self.con
280+
281281
try:
282282
# try to load hstore
283-
with self.begin() as cursor:
283+
with con.cursor() as cursor, con.transaction():
284284
cursor.execute("CREATE EXTENSION IF NOT EXISTS hstore")
285285
psycopg.types.hstore.register_hstore(
286286
psycopg.types.TypeInfo.fetch(self.con, "hstore"), self.con
@@ -290,7 +290,7 @@ def _post_connect(self) -> None:
290290
except TypeError:
291291
pass
292292

293-
with self.begin() as cursor:
293+
with con.cursor() as cursor, con.transaction():
294294
cursor.execute("SET TIMEZONE = UTC")
295295

296296
@property
@@ -300,9 +300,11 @@ def _session_temp_db(self) -> str | None:
300300
# Before that temp table is created, this will return `None`
301301
# After a temp table is created, it will return `pg_temp_N` where N is
302302
# some integer
303-
res = self.raw_sql(
304-
"select nspname from pg_namespace where oid = pg_my_temp_schema()"
305-
).fetchone()
303+
con = self.con
304+
with con.cursor() as cursor, con.transaction():
305+
res = cursor.execute(
306+
"SELECT nspname FROM pg_namespace WHERE oid = pg_my_temp_schema()"
307+
).fetchone()
306308
if res is not None:
307309
return res[0]
308310
return res
@@ -358,8 +360,9 @@ def list_tables(
358360
.sql(self.dialect)
359361
)
360362

361-
with self._safe_raw_sql(sql) as cur:
362-
out = cur.fetchall()
363+
con = self.con
364+
with con.cursor() as cursor, con.transaction():
365+
out = cursor.execute(sql).fetchall()
363366

364367
# Include temporary tables only if no database has been explicitly specified
365368
# to avoid temp tables showing up in all calls to `list_tables`
@@ -381,8 +384,9 @@ def _fetch_temp_tables(self):
381384
.sql(self.dialect)
382385
)
383386

384-
with self._safe_raw_sql(sql) as cur:
385-
out = cur.fetchall()
387+
con = self.con
388+
with con.cursor() as cursor, con.transaction():
389+
out = cursor.execute(sql).fetchall()
386390

387391
return out
388392

@@ -392,33 +396,42 @@ def list_catalogs(self, *, like: str | None = None) -> list[str]:
392396
sg.select(C.datname)
393397
.from_(sg.table("pg_database", db="pg_catalog"))
394398
.where(sg.not_(C.datistemplate))
399+
.sql(self.dialect)
395400
)
396-
with self._safe_raw_sql(cats) as cur:
397-
catalogs = list(map(itemgetter(0), cur))
401+
con = self.con
402+
with con.cursor() as cursor, con.transaction():
403+
catalogs = list(map(itemgetter(0), cursor.execute(cats)))
398404

399405
return self._filter_with_like(catalogs, like)
400406

401407
def list_databases(
402408
self, *, like: str | None = None, catalog: str | None = None
403409
) -> list[str]:
404-
dbs = sg.select(C.schema_name).from_(
405-
sg.table("schemata", db="information_schema")
410+
dbs = (
411+
sg.select(C.schema_name)
412+
.from_(sg.table("schemata", db="information_schema"))
413+
.sql(self.dialect)
406414
)
407-
with self._safe_raw_sql(dbs) as cur:
408-
databases = list(map(itemgetter(0), cur))
415+
con = self.con
416+
with con.cursor() as cursor, con.transaction():
417+
databases = list(map(itemgetter(0), cursor.execute(dbs)))
409418

410419
return self._filter_with_like(databases, like)
411420

412421
@property
413422
def current_catalog(self) -> str:
414-
with self._safe_raw_sql(sg.select(sg.func("current_database"))) as cur:
415-
(db,) = cur.fetchone()
423+
sql = sg.select(sg.func("current_database")).sql(self.dialect)
424+
con = self.con
425+
with con.cursor() as cursor, con.transaction():
426+
(db,) = cursor.execute(sql).fetchone()
416427
return db
417428

418429
@property
419430
def current_database(self) -> str:
420-
with self._safe_raw_sql(sg.select(sg.func("current_schema"))) as cur:
421-
(schema,) = cur.fetchone()
431+
sql = sg.select(sg.func("current_schema")).sql(self.dialect)
432+
con = self.con
433+
with con.cursor() as cursor, con.transaction():
434+
(schema,) = cursor.execute(sql).fetchone()
422435
return schema
423436

424437
def function(self, name: str, *, database: str | None = None) -> Callable:
@@ -445,14 +458,16 @@ def function(self, name: str, *, database: str | None = None) -> Callable:
445458
join_type="LEFT",
446459
)
447460
.where(*predicates)
461+
.sql(self.dialect)
448462
)
449463

450464
def split_name_type(arg: str) -> tuple[str, dt.DataType]:
451465
name, typ = arg.split(" ", 1)
452466
return name, self.compiler.type_mapper.from_string(typ)
453467

454-
with self._safe_raw_sql(query) as cur:
455-
rows = cur.fetchall()
468+
con = self.con
469+
with con.cursor() as cursor, con.transaction():
470+
rows = cursor.execute(query).fetchall()
456471

457472
if not rows:
458473
name = f"{database}.{name}" if database else name
@@ -528,12 +543,14 @@ def get_schema(
528543
c.relname.eq(sge.convert(name)),
529544
)
530545
.order_by(a.attnum)
546+
.sql(self.dialect)
531547
)
532548

533549
type_mapper = self.compiler.type_mapper
534550

535-
with self._safe_raw_sql(type_info) as cur:
536-
rows = cur.fetchall()
551+
con = self.con
552+
with con.cursor() as cursor, con.transaction():
553+
rows = cursor.execute(type_info).fetchall()
537554

538555
if not rows:
539556
raise com.TableNotFound(name)
@@ -553,18 +570,21 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
553570
this=sg.table(name),
554571
expression=sg.parse_one(query, read=self.dialect),
555572
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
556-
)
573+
).sql(self.dialect)
574+
557575
drop_stmt = sge.Drop(kind="VIEW", this=sg.table(name), exists=True).sql(
558576
self.dialect
559577
)
560578

561-
with self._safe_raw_sql(create_stmt):
562-
pass
579+
con = self.con
580+
with con.cursor() as cursor, con.transaction():
581+
cursor.execute(create_stmt)
582+
563583
try:
564584
return self.get_schema(name)
565585
finally:
566-
with self._safe_raw_sql(drop_stmt):
567-
pass
586+
with con.cursor() as cursor, con.transaction():
587+
cursor.execute(drop_stmt)
568588

569589
def create_database(
570590
self, name: str, /, *, catalog: str | None = None, force: bool = False
@@ -575,9 +595,10 @@ def create_database(
575595
)
576596
sql = sge.Create(
577597
kind="SCHEMA", this=sg.table(name, catalog=catalog), exists=force
578-
)
579-
with self._safe_raw_sql(sql):
580-
pass
598+
).sql(self.dialect)
599+
con = self.con
600+
with con.cursor() as cursor, con.transaction():
601+
cursor.execute(sql)
581602

582603
def drop_database(
583604
self,
@@ -598,9 +619,11 @@ def drop_database(
598619
this=sg.table(name, catalog=catalog),
599620
exists=force,
600621
cascade=cascade,
601-
)
602-
with self._safe_raw_sql(sql):
603-
pass
622+
).sql(self.dialect)
623+
624+
con = self.con
625+
with con.cursor() as cursor, con.transaction():
626+
cursor.execute(sql)
604627

605628
def create_table(
606629
self,
@@ -679,19 +702,24 @@ def create_table(
679702
kind="TABLE",
680703
this=target,
681704
properties=sge.Properties(expressions=properties),
682-
)
705+
).sql(dialect)
683706

684707
this = sg.table(name, catalog=database, quoted=quoted)
685708
this_no_catalog = sg.table(name, quoted=quoted)
686709

687-
with self._safe_raw_sql(create_stmt) as cur:
710+
con = self.con
711+
with con.cursor() as cursor, con.transaction():
712+
cursor.execute(create_stmt)
713+
688714
if query is not None:
689715
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(dialect)
690-
cur.execute(insert_stmt)
716+
cursor.execute(insert_stmt)
691717

692718
if overwrite:
693-
cur.execute(sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect))
694-
cur.execute(
719+
cursor.execute(
720+
sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect)
721+
)
722+
cursor.execute(
695723
f"ALTER TABLE IF EXISTS {table_expr.sql(dialect)} RENAME TO {this_no_catalog.sql(dialect)}"
696724
)
697725

@@ -715,16 +743,17 @@ def drop_table(
715743
kind="TABLE",
716744
this=sg.table(name, db=database, quoted=self.compiler.quoted),
717745
exists=force,
718-
)
719-
with self._safe_raw_sql(drop_stmt):
720-
pass
746+
).sql(self.dialect)
747+
con = self.con
748+
with con.cursor() as cursor, con.transaction():
749+
cursor.execute(drop_stmt)
721750

722751
@contextlib.contextmanager
723752
def _safe_raw_sql(self, query: str | sg.Expression, **kwargs: Any):
724753
with contextlib.suppress(AttributeError):
725754
query = query.sql(dialect=self.dialect)
726755

727-
with self.begin() as cursor:
756+
with (con := self.con).cursor() as cursor, con.transaction():
728757
yield cursor.execute(query, **kwargs)
729758

730759
def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
@@ -757,11 +786,13 @@ def to_pyarrow_batches(
757786
import pyarrow as pa
758787

759788
def _batches(self: Self, *, schema: pa.Schema, query: str):
789+
con = self.con
760790
columns = schema.names
761791
# server-side cursors need to be uniquely named
762-
with self.begin(
763-
name=util.gen_name("postgres_cursor"), withhold=True
764-
) as cursor:
792+
with (
793+
con.cursor(name=util.gen_name("postgres_cursor")) as cursor,
794+
con.transaction(),
795+
):
765796
cursor.execute(query)
766797
while batch := cursor.fetchmany(chunk_size):
767798
yield pa.RecordBatch.from_pandas(

ibis/backends/postgres/tests/test_client.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_pgvector_type_load(con, vector_size):
290290
CREATE TABLE itemsvrandom (id bigserial PRIMARY KEY, embedding vector({vector_size}));
291291
"""
292292

293-
with con.raw_sql(query):
293+
with con._safe_raw_sql(query):
294294
pass
295295

296296
t = con.table("itemsvrandom")
@@ -412,7 +412,9 @@ def test_create_geospatial_table_with_srid(con):
412412
f"{column} geometry({dtype}, 4326)"
413413
for column, dtype in zip(column_names, column_types)
414414
)
415-
con.raw_sql(f"CREATE TEMP TABLE {name} ({schema_string})")
415+
with con._safe_raw_sql(f"CREATE TEMP TABLE {name} ({schema_string})"):
416+
pass
417+
416418
schema = con.get_schema(name)
417419
assert schema == ibis.schema(
418420
{
@@ -425,11 +427,11 @@ def test_create_geospatial_table_with_srid(con):
425427
@pytest.fixture(scope="module")
426428
def enum_table(con):
427429
name = gen_name("enum_table")
428-
con.raw_sql("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')")
429-
con.raw_sql(f"CREATE TEMP TABLE {name} (mood mood)")
430-
yield name
431-
con.raw_sql(f"DROP TABLE {name}")
432-
con.raw_sql("DROP TYPE mood")
430+
with con._safe_raw_sql("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')") as cur:
431+
cur.execute(f"CREATE TEMP TABLE {name} (mood mood)")
432+
yield name
433+
cur.execute(f"DROP TABLE {name}")
434+
cur.execute("DROP TYPE mood")
433435

434436

435437
def test_enum_table(con, enum_table):

ibis/backends/tests/test_client.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
PsycoPgUndefinedObject,
3535
Py4JJavaError,
3636
PyAthenaDatabaseError,
37+
PyDruidProgrammingError,
3738
PyODBCProgrammingError,
3839
SnowflakeProgrammingError,
3940
)
@@ -1343,12 +1344,11 @@ def test_create_table_timestamp(con, temp_table):
13431344
schema = ibis.schema(
13441345
dict(zip(string.ascii_letters, map("timestamp({:d})".format, range(10))))
13451346
)
1346-
con.create_table(
1347-
temp_table,
1348-
schema=schema,
1349-
overwrite=True,
1350-
)
1351-
rows = con.raw_sql(f"DESCRIBE {temp_table}").fetchall()
1347+
con.create_table(temp_table, schema=schema, overwrite=True)
1348+
1349+
with con._safe_raw_sql(f"DESCRIBE {temp_table}") as cur:
1350+
rows = cur.fetchall()
1351+
13521352
result = ibis.schema((name, typ) for name, typ, *_ in rows)
13531353
assert result == schema
13541354

@@ -1723,12 +1723,11 @@ def test_cross_database_join(con_create_database, monkeypatch):
17231723

17241724

17251725
@pytest.mark.notimpl(
1726-
["druid"], raises=AttributeError, reason="doesn't implement `raw_sql`"
1726+
["druid"], raises=PyDruidProgrammingError, reason="doesn't implement CREATE syntax"
17271727
)
17281728
@pytest.mark.notimpl(["clickhouse"], reason="create table isn't implemented")
17291729
@pytest.mark.notyet(["flink"], raises=Py4JJavaError)
17301730
@pytest.mark.notyet(["polars"], reason="Doesn't support insert")
1731-
@pytest.mark.notyet(["exasol"], reason="Backend does not support raw_sql")
17321731
@pytest.mark.notimpl(
17331732
["impala", "pyspark", "trino"], reason="Default constraints are not supported"
17341733
)
@@ -1750,7 +1749,8 @@ def test_insert_into_table_missing_columns(con, temp_table):
17501749

17511750
ct_sql = f'CREATE TABLE {raw_ident} ("a" INT DEFAULT 1, "b" INT)'
17521751
sg_expr = sg.parse_one(ct_sql, read="duckdb")
1753-
con.raw_sql(sg_expr.sql(dialect=con.dialect))
1752+
with con._safe_raw_sql(sg_expr.sql(dialect=con.dialect)):
1753+
pass
17541754
con.insert(temp_table, [{"b": 1}])
17551755

17561756
result = con.table(temp_table).to_pyarrow().to_pydict()

0 commit comments

Comments
 (0)