Skip to content

Commit a453939

Browse files
committed
feat: Save the session instance to a context variable
1 parent 7ab2104 commit a453939

File tree

4 files changed

+269
-64
lines changed

4 files changed

+269
-64
lines changed

sqlalchemy_database/database.py

Lines changed: 174 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextvars import ContextVar
2+
from threading import Lock
13
from typing import (
24
Any,
35
AsyncGenerator,
@@ -67,6 +69,50 @@ def __init__(self, engine: AsyncEngine, **session_options):
6769
await session.commit()
6870
```
6971
"""
72+
self._session_lock = Lock()
73+
self._session_enter_count = 0
74+
self._session_context_var: ContextVar[Optional[AsyncSession]] = ContextVar("_session_context_var", default=None)
75+
76+
@property
77+
def session(self) -> Optional[AsyncSession]:
78+
"""Return an instance of Session local to the current async context.
79+
80+
Note: Must register middleware in fastapi application to get session in request.
81+
82+
Example:
83+
```Python
84+
app = FastAPI()
85+
app.add_middleware(BaseHTTPMiddleware, db.asgi_dispatch)
86+
87+
@app.get('/get_user')
88+
async def get_user(id:int):
89+
return await db.session.get(User,id)
90+
```
91+
In ordinary methods, session will return None. You can get it through:
92+
```Python
93+
async with db:
94+
db.session.get(User,id)
95+
```
96+
"""
97+
return self._session_context_var.get() if self._session_enter_count > 0 else None
98+
99+
async def __aenter__(self):
100+
with self._session_lock:
101+
session = self.session
102+
if session is None:
103+
session = self.session_maker()
104+
self._session_context_var_token = self._session_context_var.set(session)
105+
self._session_enter_count += 1
106+
return session
107+
108+
async def __aexit__(self, exc_type, exc_value, traceback):
109+
with self._session_lock:
110+
self._session_enter_count -= 1
111+
if self._session_enter_count <= 0:
112+
session = self._session_context_var.get()
113+
if session is not None:
114+
await session.close()
115+
self._session_context_var.reset(self._session_context_var_token)
70116

71117
@classmethod
72118
def create(cls, url: str, *, session_options: Mapping[str, Any] = None, **kwargs) -> "AsyncDatabase":
@@ -159,7 +205,11 @@ async def execute(
159205
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
160206
need_close = True
161207
if is_session:
162-
executor = self.session_maker()
208+
executor = self.session
209+
if executor is None:
210+
executor = self.session_maker()
211+
else:
212+
need_close = False
163213
kw["bind_arguments"] = bind_arguments
164214
else:
165215
executor = await self.engine.connect()
@@ -189,12 +239,12 @@ async def scalar(
189239
"""
190240
need_close = False
191241
if session is None or not isinstance(session, AsyncSession):
192-
need_close = True
193-
executor = self.session_maker()
194-
else:
195-
executor = session
196-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
197-
result = await executor.scalar(
242+
session = self.session
243+
if session is None:
244+
need_close = True
245+
session = self.session_maker()
246+
async with ExecutorContextManager(session, need_close=need_close) as session:
247+
result = await session.scalar(
198248
statement,
199249
params,
200250
execution_options=execution_options,
@@ -220,13 +270,13 @@ async def scalars_all(
220270
"""
221271
need_close = False
222272
if session is None or not isinstance(session, AsyncSession):
223-
need_close = True
224-
executor = self.session_maker()
225-
else:
226-
executor = session
227-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
273+
session = self.session
274+
if session is None:
275+
need_close = True
276+
session = self.session_maker()
277+
async with ExecutorContextManager(session, need_close=need_close) as session:
228278
result = (
229-
await executor.scalars(
279+
await session.scalars(
230280
statement,
231281
params,
232282
execution_options=execution_options,
@@ -289,12 +339,12 @@ async def get(
289339
"""
290340
need_close = False
291341
if session is None or not isinstance(session, AsyncSession):
292-
need_close = True
293-
executor = self.session_maker()
294-
else:
295-
executor = session
296-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
297-
result = await executor.get(
342+
session = self.session
343+
if session is None:
344+
need_close = True
345+
session = self.session_maker()
346+
async with ExecutorContextManager(session, need_close=need_close) as session:
347+
result = await session.get(
298348
entity,
299349
ident,
300350
options=options,
@@ -306,9 +356,13 @@ async def get(
306356

307357
async def delete(self, instance: Any) -> None:
308358
"""Deletes an instance object."""
309-
async with self.session_maker() as session:
310-
async with session.begin():
311-
await session.delete(instance)
359+
if self.session is not None:
360+
await self.session.delete(instance)
361+
await self.session.commit()
362+
else:
363+
async with self.session_maker() as session:
364+
async with session.begin():
365+
await session.delete(instance)
312366

313367
async def save(self, *instances: Any, refresh: bool = False, session: Optional[AsyncSession] = None) -> None:
314368
"""
@@ -320,15 +374,15 @@ async def save(self, *instances: Any, refresh: bool = False, session: Optional[A
320374
"""
321375
need_close = False
322376
if session is None or not isinstance(session, AsyncSession):
323-
need_close = True
324-
executor = self.session_maker()
325-
else:
326-
executor = session
327-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
328-
executor.add_all(instances)
329-
await executor.commit()
377+
session = self.session
378+
if session is None:
379+
need_close = True
380+
session = self.session_maker()
381+
async with ExecutorContextManager(session, need_close=need_close) as session:
382+
session.add_all(instances)
383+
await session.commit()
330384
if refresh:
331-
[await executor.refresh(instance) for instance in instances]
385+
[await session.refresh(instance) for instance in instances]
332386

333387
async def run_sync(
334388
self,
@@ -374,9 +428,17 @@ def get_user(session:Session,id:int):
374428
APIs which will be properly adapted to the greenlet context.
375429
"""
376430
need_close = False
431+
if executor is None and is_session:
432+
executor = self.session
377433
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
378-
need_close = True
379-
executor = self.session_maker() if is_session else await self.engine.connect()
434+
if is_session:
435+
executor = self.session
436+
if executor is None:
437+
executor = self.session_maker()
438+
need_close = True
439+
else:
440+
executor = await self.engine.connect()
441+
need_close = True
380442
async with ExecutorContextManager(executor, need_close=need_close) as executor:
381443
result = await executor.run_sync(fn, *args, **kwargs)
382444
if on_close_pre:
@@ -394,6 +456,38 @@ def __init__(self, engine: Engine, **session_options):
394456
self.engine: Engine = engine
395457
session_options.setdefault("class_", Session)
396458
self.session_maker: Callable[..., Session] = sessionmaker(self.engine, **session_options)
459+
self._session_lock = Lock()
460+
self._session_enter_count = 0
461+
self._session_context_var: ContextVar[Optional[Session]] = ContextVar("_session_context_var", default=None)
462+
463+
@property
464+
def session(self) -> Optional[Session]:
465+
"""Return an instance of Session local to the current context."""
466+
return self._session_context_var.get() if self._session_enter_count > 0 else None
467+
468+
def __enter__(self):
469+
with self._session_lock:
470+
session = self.session
471+
if session is None:
472+
session = self.session_maker()
473+
self._session_context_var_token = self._session_context_var.set(session)
474+
self._session_enter_count += 1
475+
return session
476+
477+
def __exit__(self, exc_type, exc_value, traceback):
478+
with self._session_lock:
479+
self._session_enter_count -= 1
480+
if self._session_enter_count <= 0:
481+
session = self._session_context_var.get()
482+
if session is not None:
483+
session.close()
484+
self._session_context_var.reset(self._session_context_var_token)
485+
486+
async def __aenter__(self):
487+
return self.__enter__()
488+
489+
async def __aexit__(self, exc_type, exc_value, traceback):
490+
return self.__exit__(exc_type, exc_value, traceback)
397491

398492
@classmethod
399493
def create(cls, url: str, *, session_options: Optional[Mapping[str, Any]] = None, **kwargs) -> "Database":
@@ -420,10 +514,16 @@ def execute(
420514
**kw: Any,
421515
) -> Union[Result, _T]:
422516
need_close = False
517+
if executor is None and is_session:
518+
executor = self.session
423519
if executor is None or not isinstance(executor, (Session, Connection)):
424520
need_close = True
425521
if is_session:
426-
executor = self.session_maker()
522+
executor = self.session
523+
if executor is None:
524+
executor = self.session_maker()
525+
else:
526+
need_close = False
427527
kw["bind_arguments"] = bind_arguments
428528
else:
429529
executor = self.engine.connect()
@@ -447,12 +547,12 @@ def scalar(
447547
) -> Any:
448548
need_close = False
449549
if session is None or not isinstance(session, Session):
450-
need_close = True
451-
executor = self.session_maker()
452-
else:
453-
executor = session
454-
with ExecutorContextManager(executor, need_close=need_close) as executor:
455-
result = executor.scalar(
550+
session = self.session
551+
if session is None:
552+
need_close = True
553+
session = self.session_maker()
554+
with ExecutorContextManager(session, need_close=need_close) as session:
555+
result = session.scalar(
456556
statement,
457557
params,
458558
execution_options=execution_options,
@@ -473,12 +573,12 @@ def scalars_all(
473573
) -> List[Any]:
474574
need_close = False
475575
if session is None or not isinstance(session, Session):
476-
need_close = True
477-
executor = self.session_maker()
478-
else:
479-
executor = session
480-
with ExecutorContextManager(executor, need_close=need_close) as executor:
481-
result = executor.scalars(
576+
session = self.session
577+
if session is None:
578+
need_close = True
579+
session = self.session_maker()
580+
with ExecutorContextManager(session, need_close=need_close) as session:
581+
result = session.scalars(
482582
statement,
483583
params,
484584
execution_options=execution_options,
@@ -501,12 +601,12 @@ def get(
501601
) -> Optional[_T]:
502602
need_close = False
503603
if session is None or not isinstance(session, Session):
504-
need_close = True
505-
executor = self.session_maker()
506-
else:
507-
executor = session
508-
with ExecutorContextManager(executor, need_close=need_close) as executor:
509-
result = executor.get(
604+
session = self.session
605+
if session is None:
606+
need_close = True
607+
session = self.session_maker()
608+
with ExecutorContextManager(session, need_close=need_close) as session:
609+
result = session.get(
510610
entity,
511611
ident,
512612
options=options,
@@ -517,22 +617,26 @@ def get(
517617
return result
518618

519619
def delete(self, instance: Any) -> None:
520-
with self.session_maker() as session:
521-
with session.begin():
522-
session.delete(instance)
620+
if self.session is not None:
621+
self.session.delete(instance)
622+
self.session.commit()
623+
else:
624+
with self.session_maker() as session:
625+
with session.begin():
626+
session.delete(instance)
523627

524628
def save(self, *instances: Any, refresh: bool = False, session: Optional[Session] = None) -> None:
525629
need_close = False
526630
if session is None or not isinstance(session, Session):
527-
need_close = True
528-
executor = self.session_maker()
529-
else:
530-
executor = session
531-
with ExecutorContextManager(executor, need_close=need_close) as executor:
532-
executor.add_all(instances)
533-
executor.commit()
631+
session = self.session
632+
if session is None:
633+
need_close = True
634+
session = self.session_maker()
635+
with ExecutorContextManager(session, need_close=need_close) as session:
636+
session.add_all(instances)
637+
session.commit()
534638
if refresh:
535-
[executor.refresh(instance) for instance in instances]
639+
[session.refresh(instance) for instance in instances]
536640

537641
def run_sync(
538642
self,
@@ -547,7 +651,14 @@ def run_sync(
547651
need_close = False
548652
if executor is None or not isinstance(executor, (Session, Connection)):
549653
need_close = True
550-
executor = self.session_maker() if is_session else self.engine.connect()
654+
if is_session:
655+
executor = self.session
656+
if executor is None:
657+
executor = self.session_maker()
658+
else:
659+
need_close = False
660+
else:
661+
executor = self.engine.connect()
551662
with ExecutorContextManager(executor, need_close=need_close) as executor:
552663
result = fn(executor, *args, **kwargs)
553664
if on_close_pre:

tests/test_AbcAsyncDatabase.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy.orm import Session
77

88
from sqlalchemy_database import AsyncDatabase, Database
9-
from tests.conftest import Base, User, async_db, sync_db
9+
from tests.conftest import Base, Group, User, async_db, sync_db
1010

1111

1212
@pytest.fixture(params=[async_db, sync_db])
@@ -124,3 +124,16 @@ def get_user(session: Session, user_id: int):
124124

125125
user_id = await db.async_run_sync(get_user, 2, on_close_pre=lambda r: r.id)
126126
assert user_id == 2
127+
128+
129+
async def test_async_session_context_var(db, fake_users):
130+
async with db:
131+
# test db function
132+
user = await db.async_get(User, 1)
133+
assert user.id == 1
134+
group = Group(name="group1")
135+
await db.async_save(group, refresh=True)
136+
assert group.id == 1
137+
user.group_id = group.id
138+
139+
assert db.session is None

0 commit comments

Comments
 (0)