1+ from contextvars import ContextVar
2+ from threading import Lock
13from typing import (
24 Any ,
35 AsyncGenerator ,
@@ -67,6 +69,50 @@ def __init__(self, engine: AsyncEngine, **session_options):
6769 await session.commit()
6870 ```
6971 """
72+ self ._session_lock = Lock ()
73+ self ._session_enter_count = 0
74+ self ._session_context_var : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session_context_var" , default = None )
75+
76+ @property
77+ def session (self ) -> Optional [AsyncSession ]:
78+ """Return an instance of Session local to the current async context.
79+
80+ Note: Must register middleware in fastapi application to get session in request.
81+
82+ Example:
83+ ```Python
84+ app = FastAPI()
85+ app.add_middleware(BaseHTTPMiddleware, db.asgi_dispatch)
86+
87+ @app.get('/get_user')
88+ async def get_user(id:int):
89+ return await db.session.get(User,id)
90+ ```
91+ In ordinary methods, session will return None. You can get it through:
92+ ```Python
93+ async with db:
94+ db.session.get(User,id)
95+ ```
96+ """
97+ return self ._session_context_var .get () if self ._session_enter_count > 0 else None
98+
99+ async def __aenter__ (self ):
100+ with self ._session_lock :
101+ session = self .session
102+ if session is None :
103+ session = self .session_maker ()
104+ self ._session_context_var_token = self ._session_context_var .set (session )
105+ self ._session_enter_count += 1
106+ return session
107+
108+ async def __aexit__ (self , exc_type , exc_value , traceback ):
109+ with self ._session_lock :
110+ self ._session_enter_count -= 1
111+ if self ._session_enter_count <= 0 :
112+ session = self ._session_context_var .get ()
113+ if session is not None :
114+ await session .close ()
115+ self ._session_context_var .reset (self ._session_context_var_token )
70116
71117 @classmethod
72118 def create (cls , url : str , * , session_options : Mapping [str , Any ] = None , ** kwargs ) -> "AsyncDatabase" :
@@ -159,7 +205,11 @@ async def execute(
159205 if executor is None or not isinstance (executor , (AsyncSession , AsyncConnection )):
160206 need_close = True
161207 if is_session :
162- executor = self .session_maker ()
208+ executor = self .session
209+ if executor is None :
210+ executor = self .session_maker ()
211+ else :
212+ need_close = False
163213 kw ["bind_arguments" ] = bind_arguments
164214 else :
165215 executor = await self .engine .connect ()
@@ -189,12 +239,12 @@ async def scalar(
189239 """
190240 need_close = False
191241 if session is None or not isinstance (session , AsyncSession ):
192- need_close = True
193- executor = self . session_maker ()
194- else :
195- executor = session
196- async with ExecutorContextManager (executor , need_close = need_close ) as executor :
197- result = await executor .scalar (
242+ session = self . session
243+ if session is None :
244+ need_close = True
245+ session = self . session_maker ()
246+ async with ExecutorContextManager (session , need_close = need_close ) as session :
247+ result = await session .scalar (
198248 statement ,
199249 params ,
200250 execution_options = execution_options ,
@@ -220,13 +270,13 @@ async def scalars_all(
220270 """
221271 need_close = False
222272 if session is None or not isinstance (session , AsyncSession ):
223- need_close = True
224- executor = self . session_maker ()
225- else :
226- executor = session
227- async with ExecutorContextManager (executor , need_close = need_close ) as executor :
273+ session = self . session
274+ if session is None :
275+ need_close = True
276+ session = self . session_maker ()
277+ async with ExecutorContextManager (session , need_close = need_close ) as session :
228278 result = (
229- await executor .scalars (
279+ await session .scalars (
230280 statement ,
231281 params ,
232282 execution_options = execution_options ,
@@ -289,12 +339,12 @@ async def get(
289339 """
290340 need_close = False
291341 if session is None or not isinstance (session , AsyncSession ):
292- need_close = True
293- executor = self . session_maker ()
294- else :
295- executor = session
296- async with ExecutorContextManager (executor , need_close = need_close ) as executor :
297- result = await executor .get (
342+ session = self . session
343+ if session is None :
344+ need_close = True
345+ session = self . session_maker ()
346+ async with ExecutorContextManager (session , need_close = need_close ) as session :
347+ result = await session .get (
298348 entity ,
299349 ident ,
300350 options = options ,
@@ -306,9 +356,13 @@ async def get(
306356
307357 async def delete (self , instance : Any ) -> None :
308358 """Deletes an instance object."""
309- async with self .session_maker () as session :
310- async with session .begin ():
311- await session .delete (instance )
359+ if self .session is not None :
360+ await self .session .delete (instance )
361+ await self .session .commit ()
362+ else :
363+ async with self .session_maker () as session :
364+ async with session .begin ():
365+ await session .delete (instance )
312366
313367 async def save (self , * instances : Any , refresh : bool = False , session : Optional [AsyncSession ] = None ) -> None :
314368 """
@@ -320,15 +374,15 @@ async def save(self, *instances: Any, refresh: bool = False, session: Optional[A
320374 """
321375 need_close = False
322376 if session is None or not isinstance (session , AsyncSession ):
323- need_close = True
324- executor = self . session_maker ()
325- else :
326- executor = session
327- async with ExecutorContextManager (executor , need_close = need_close ) as executor :
328- executor .add_all (instances )
329- await executor .commit ()
377+ session = self . session
378+ if session is None :
379+ need_close = True
380+ session = self . session_maker ()
381+ async with ExecutorContextManager (session , need_close = need_close ) as session :
382+ session .add_all (instances )
383+ await session .commit ()
330384 if refresh :
331- [await executor .refresh (instance ) for instance in instances ]
385+ [await session .refresh (instance ) for instance in instances ]
332386
333387 async def run_sync (
334388 self ,
@@ -374,9 +428,17 @@ def get_user(session:Session,id:int):
374428 APIs which will be properly adapted to the greenlet context.
375429 """
376430 need_close = False
431+ if executor is None and is_session :
432+ executor = self .session
377433 if executor is None or not isinstance (executor , (AsyncSession , AsyncConnection )):
378- need_close = True
379- executor = self .session_maker () if is_session else await self .engine .connect ()
434+ if is_session :
435+ executor = self .session
436+ if executor is None :
437+ executor = self .session_maker ()
438+ need_close = True
439+ else :
440+ executor = await self .engine .connect ()
441+ need_close = True
380442 async with ExecutorContextManager (executor , need_close = need_close ) as executor :
381443 result = await executor .run_sync (fn , * args , ** kwargs )
382444 if on_close_pre :
@@ -394,6 +456,38 @@ def __init__(self, engine: Engine, **session_options):
394456 self .engine : Engine = engine
395457 session_options .setdefault ("class_" , Session )
396458 self .session_maker : Callable [..., Session ] = sessionmaker (self .engine , ** session_options )
459+ self ._session_lock = Lock ()
460+ self ._session_enter_count = 0
461+ self ._session_context_var : ContextVar [Optional [Session ]] = ContextVar ("_session_context_var" , default = None )
462+
463+ @property
464+ def session (self ) -> Optional [Session ]:
465+ """Return an instance of Session local to the current context."""
466+ return self ._session_context_var .get () if self ._session_enter_count > 0 else None
467+
468+ def __enter__ (self ):
469+ with self ._session_lock :
470+ session = self .session
471+ if session is None :
472+ session = self .session_maker ()
473+ self ._session_context_var_token = self ._session_context_var .set (session )
474+ self ._session_enter_count += 1
475+ return session
476+
477+ def __exit__ (self , exc_type , exc_value , traceback ):
478+ with self ._session_lock :
479+ self ._session_enter_count -= 1
480+ if self ._session_enter_count <= 0 :
481+ session = self ._session_context_var .get ()
482+ if session is not None :
483+ session .close ()
484+ self ._session_context_var .reset (self ._session_context_var_token )
485+
486+ async def __aenter__ (self ):
487+ return self .__enter__ ()
488+
489+ async def __aexit__ (self , exc_type , exc_value , traceback ):
490+ return self .__exit__ (exc_type , exc_value , traceback )
397491
398492 @classmethod
399493 def create (cls , url : str , * , session_options : Optional [Mapping [str , Any ]] = None , ** kwargs ) -> "Database" :
@@ -420,10 +514,16 @@ def execute(
420514 ** kw : Any ,
421515 ) -> Union [Result , _T ]:
422516 need_close = False
517+ if executor is None and is_session :
518+ executor = self .session
423519 if executor is None or not isinstance (executor , (Session , Connection )):
424520 need_close = True
425521 if is_session :
426- executor = self .session_maker ()
522+ executor = self .session
523+ if executor is None :
524+ executor = self .session_maker ()
525+ else :
526+ need_close = False
427527 kw ["bind_arguments" ] = bind_arguments
428528 else :
429529 executor = self .engine .connect ()
@@ -447,12 +547,12 @@ def scalar(
447547 ) -> Any :
448548 need_close = False
449549 if session is None or not isinstance (session , Session ):
450- need_close = True
451- executor = self . session_maker ()
452- else :
453- executor = session
454- with ExecutorContextManager (executor , need_close = need_close ) as executor :
455- result = executor .scalar (
550+ session = self . session
551+ if session is None :
552+ need_close = True
553+ session = self . session_maker ()
554+ with ExecutorContextManager (session , need_close = need_close ) as session :
555+ result = session .scalar (
456556 statement ,
457557 params ,
458558 execution_options = execution_options ,
@@ -473,12 +573,12 @@ def scalars_all(
473573 ) -> List [Any ]:
474574 need_close = False
475575 if session is None or not isinstance (session , Session ):
476- need_close = True
477- executor = self . session_maker ()
478- else :
479- executor = session
480- with ExecutorContextManager (executor , need_close = need_close ) as executor :
481- result = executor .scalars (
576+ session = self . session
577+ if session is None :
578+ need_close = True
579+ session = self . session_maker ()
580+ with ExecutorContextManager (session , need_close = need_close ) as session :
581+ result = session .scalars (
482582 statement ,
483583 params ,
484584 execution_options = execution_options ,
@@ -501,12 +601,12 @@ def get(
501601 ) -> Optional [_T ]:
502602 need_close = False
503603 if session is None or not isinstance (session , Session ):
504- need_close = True
505- executor = self . session_maker ()
506- else :
507- executor = session
508- with ExecutorContextManager (executor , need_close = need_close ) as executor :
509- result = executor .get (
604+ session = self . session
605+ if session is None :
606+ need_close = True
607+ session = self . session_maker ()
608+ with ExecutorContextManager (session , need_close = need_close ) as session :
609+ result = session .get (
510610 entity ,
511611 ident ,
512612 options = options ,
@@ -517,22 +617,26 @@ def get(
517617 return result
518618
519619 def delete (self , instance : Any ) -> None :
520- with self .session_maker () as session :
521- with session .begin ():
522- session .delete (instance )
620+ if self .session is not None :
621+ self .session .delete (instance )
622+ self .session .commit ()
623+ else :
624+ with self .session_maker () as session :
625+ with session .begin ():
626+ session .delete (instance )
523627
524628 def save (self , * instances : Any , refresh : bool = False , session : Optional [Session ] = None ) -> None :
525629 need_close = False
526630 if session is None or not isinstance (session , Session ):
527- need_close = True
528- executor = self . session_maker ()
529- else :
530- executor = session
531- with ExecutorContextManager (executor , need_close = need_close ) as executor :
532- executor .add_all (instances )
533- executor .commit ()
631+ session = self . session
632+ if session is None :
633+ need_close = True
634+ session = self . session_maker ()
635+ with ExecutorContextManager (session , need_close = need_close ) as session :
636+ session .add_all (instances )
637+ session .commit ()
534638 if refresh :
535- [executor .refresh (instance ) for instance in instances ]
639+ [session .refresh (instance ) for instance in instances ]
536640
537641 def run_sync (
538642 self ,
@@ -547,7 +651,14 @@ def run_sync(
547651 need_close = False
548652 if executor is None or not isinstance (executor , (Session , Connection )):
549653 need_close = True
550- executor = self .session_maker () if is_session else self .engine .connect ()
654+ if is_session :
655+ executor = self .session
656+ if executor is None :
657+ executor = self .session_maker ()
658+ else :
659+ need_close = False
660+ else :
661+ executor = self .engine .connect ()
551662 with ExecutorContextManager (executor , need_close = need_close ) as executor :
552663 result = fn (executor , * args , ** kwargs )
553664 if on_close_pre :
0 commit comments