Skip to content

Commit 9cbd9ed

Browse files
committed
Add transaction support
1 parent c8f8881 commit 9cbd9ed

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

django_mongodb_backend/base.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def _isnull_operator(a, b):
139139
introspection_class = DatabaseIntrospection
140140
ops_class = DatabaseOperations
141141
validation_class = DatabaseValidation
142+
session = None
142143

143144
def get_collection(self, name, **kwargs):
144145
collection = Collection(self.database, name, **kwargs)
@@ -190,13 +191,28 @@ def _driver_info(self):
190191
return None
191192

192193
def _commit(self):
193-
pass
194+
if self.session:
195+
self.session.commit_transaction()
196+
self.session.end_session()
197+
self.session = None
194198

195199
def _rollback(self):
196200
pass
197201

198-
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
199-
self.autocommit = autocommit
202+
def _start_session(self):
203+
if self.session is None:
204+
self.session = self.connection.start_session()
205+
self.session.start_transaction()
206+
207+
def _start_transaction_under_autocommit(self):
208+
self._start_session()
209+
210+
def _set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
211+
if not autocommit:
212+
self._start_session()
213+
else:
214+
if self.session:
215+
self.commit()
200216

201217
def _close(self):
202218
# Normally called by close(), this method is also called by some tests.

django_mongodb_backend/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,9 @@ def execute_sql(self, returning_fields=None):
685685
@wrap_database_errors
686686
def insert(self, docs, returning_fields=None):
687687
"""Store a list of documents using field columns as element names."""
688-
inserted_ids = self.collection.insert_many(docs).inserted_ids
688+
inserted_ids = self.collection.insert_many(
689+
docs, session=self.connection.session
690+
).inserted_ids
689691
return [(x,) for x in inserted_ids] if returning_fields else []
690692

691693
@cached_property

0 commit comments

Comments
 (0)