77 Generator ,
88 Mapping ,
99 Optional ,
10- Sequence ,
1110 TypeVar ,
1211 Union ,
1312)
2928 from sqlalchemy .orm import Session
3029 from sqlalchemy .ext .asyncio import AsyncSession
3130
32- from sqlalchemy_database ._abc_async_database import AbcAsyncDatabase
31+ from sqlalchemy_database ._abc_async_database import AbcAsyncDatabase , to_thread
3332
3433_P = ParamSpec ("_P" )
3534_T = TypeVar ("_T" )
3635_R = TypeVar ("_R" )
3736
38- _ExecuteParams = Union [Mapping [Any , Any ], Sequence [Mapping [Any , Any ]]]
39- _ExecuteOptions = Mapping [Any , Any ]
40-
4137
4238class AsyncDatabase (AbcAsyncDatabase ):
4339 """`sqlalchemy` asynchronous database client"""
@@ -258,7 +254,13 @@ def run_sync(
258254
259255 def asyncify (self , db : Union [AsyncSession , AsyncDatabase ], fn : Callable [_P , _T ]) -> Callable [_P , Awaitable [_T ]]:
260256 """Convert the given sync function that runs in the context of a sync session to
261- an async function that runs in the context of an async session."""
257+ an async function that runs in the context of an async session.
258+ Args:
259+ db: Async database client or session
260+ fn: sync function
261+ Returns:
262+ Returns the async function.
263+ """
262264 session = db if isinstance (db , AsyncSession ) else db .session
263265
264266 @functools .wraps (fn )
@@ -301,23 +303,29 @@ def __enter__(self):
301303 self ._token = self .db ._session_scope .set (self ._scope )
302304 return self .db .session
303305
306+ def _close_session (self , session : Session , exc_type ):
307+ if exc_type is not None :
308+ session .rollback ()
309+ elif self .db .commit_on_exit :
310+ session .commit ()
311+ session .close ()
312+
304313 def __exit__ (self , exc_type , exc_value , traceback ):
305- if isinstance (self ._scope , self ._SessionCls ):
314+ if not ( self . _scope and isinstance (self ._scope , self ._SessionCls ) ):
306315 """If the scope is a session, it will not be closed."""
307- self .db .scoped_session .registry .clear ()
308- else :
309- if exc_type is not None :
310- self .db .session .rollback ()
311- elif self .db .commit_on_exit :
312- self .db .session .commit ()
313- self .db .scoped_session .remove ()
316+ self ._close_session (self .db .session , exc_type )
317+ self .db .scoped_session .registry .clear ()
314318 self .db ._session_scope .reset (self ._token )
315319
316320 async def __aenter__ (self ):
317321 return self .__enter__ ()
318322
319323 async def __aexit__ (self , exc_type , exc_value , traceback ):
320- self .__exit__ (exc_type , exc_value , traceback )
324+ if not (self ._scope and isinstance (self ._scope , self ._SessionCls )):
325+ """If the scope is a session, it will not be closed."""
326+ await to_thread (self ._close_session , self .db .session , exc_type )
327+ self .db .scoped_session .registry .clear ()
328+ self .db ._session_scope .reset (self ._token )
321329
322330
323331class AsyncSessionContextVarManager (SessionContextVarManager ):
@@ -331,13 +339,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
331339
332340 async def __aexit__ (self , exc_type , exc_value , traceback ):
333341 self .db : AsyncDatabase
334- if isinstance (self ._scope , AsyncSession ):
342+ if not ( self . _scope and isinstance (self ._scope , self . _SessionCls ) ):
335343 """If the scope is a session, it will not be closed."""
336- self .db .scoped_session .registry .clear ()
337- else :
344+ session = self .db .session
338345 if exc_type is not None :
339- await self . db . session .rollback ()
346+ await session .rollback ()
340347 elif self .db .commit_on_exit :
341- await self .db .session .commit ()
342- await self .db .scoped_session .remove ()
348+ await session .commit ()
349+ await session .close ()
350+ self .db .scoped_session .registry .clear ()
343351 self .db ._session_scope .reset (self ._token )
0 commit comments