Skip to content

Commit c007370

Browse files
committed
fix: Use a thread to commit and close the sync session.
1 parent 07dfaf3 commit c007370

File tree

1 file changed

+29
-21
lines changed

1 file changed

+29
-21
lines changed

sqlalchemy_database/database.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Generator,
88
Mapping,
99
Optional,
10-
Sequence,
1110
TypeVar,
1211
Union,
1312
)
@@ -29,15 +28,12 @@
2928
from sqlalchemy.orm import Session
3029
from sqlalchemy.ext.asyncio import AsyncSession
3130

32-
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
31+
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase, to_thread
3332

3433
_P = ParamSpec("_P")
3534
_T = TypeVar("_T")
3635
_R = TypeVar("_R")
3736

38-
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
39-
_ExecuteOptions = Mapping[Any, Any]
40-
4137

4238
class AsyncDatabase(AbcAsyncDatabase):
4339
"""`sqlalchemy` asynchronous database client"""
@@ -258,7 +254,13 @@ def run_sync(
258254

259255
def asyncify(self, db: Union[AsyncSession, AsyncDatabase], fn: Callable[_P, _T]) -> Callable[_P, Awaitable[_T]]:
260256
"""Convert the given sync function that runs in the context of a sync session to
261-
an async function that runs in the context of an async session."""
257+
an async function that runs in the context of an async session.
258+
Args:
259+
db: Async database client or session
260+
fn: sync function
261+
Returns:
262+
Returns the async function.
263+
"""
262264
session = db if isinstance(db, AsyncSession) else db.session
263265

264266
@functools.wraps(fn)
@@ -301,23 +303,29 @@ def __enter__(self):
301303
self._token = self.db._session_scope.set(self._scope)
302304
return self.db.session
303305

306+
def _close_session(self, session: Session, exc_type):
307+
if exc_type is not None:
308+
session.rollback()
309+
elif self.db.commit_on_exit:
310+
session.commit()
311+
session.close()
312+
304313
def __exit__(self, exc_type, exc_value, traceback):
305-
if isinstance(self._scope, self._SessionCls):
314+
if not (self._scope and isinstance(self._scope, self._SessionCls)):
306315
"""If the scope is a session, it will not be closed."""
307-
self.db.scoped_session.registry.clear()
308-
else:
309-
if exc_type is not None:
310-
self.db.session.rollback()
311-
elif self.db.commit_on_exit:
312-
self.db.session.commit()
313-
self.db.scoped_session.remove()
316+
self._close_session(self.db.session, exc_type)
317+
self.db.scoped_session.registry.clear()
314318
self.db._session_scope.reset(self._token)
315319

316320
async def __aenter__(self):
317321
return self.__enter__()
318322

319323
async def __aexit__(self, exc_type, exc_value, traceback):
320-
self.__exit__(exc_type, exc_value, traceback)
324+
if not (self._scope and isinstance(self._scope, self._SessionCls)):
325+
"""If the scope is a session, it will not be closed."""
326+
await to_thread(self._close_session, self.db.session, exc_type)
327+
self.db.scoped_session.registry.clear()
328+
self.db._session_scope.reset(self._token)
321329

322330

323331
class AsyncSessionContextVarManager(SessionContextVarManager):
@@ -331,13 +339,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
331339

332340
async def __aexit__(self, exc_type, exc_value, traceback):
333341
self.db: AsyncDatabase
334-
if isinstance(self._scope, AsyncSession):
342+
if not (self._scope and isinstance(self._scope, self._SessionCls)):
335343
"""If the scope is a session, it will not be closed."""
336-
self.db.scoped_session.registry.clear()
337-
else:
344+
session = self.db.session
338345
if exc_type is not None:
339-
await self.db.session.rollback()
346+
await session.rollback()
340347
elif self.db.commit_on_exit:
341-
await self.db.session.commit()
342-
await self.db.scoped_session.remove()
348+
await session.commit()
349+
await session.close()
350+
self.db.scoped_session.registry.clear()
343351
self.db._session_scope.reset(self._token)

0 commit comments

Comments
 (0)