Skip to content

Extract ConnectionManager to handle setup connection before using it … #1457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mongoengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Import submodules so that we can expose their __all__
from mongoengine import connection
from mongoengine import connections_manager
from mongoengine import document
from mongoengine import errors
from mongoengine import fields
Expand All @@ -11,6 +12,7 @@
# users can simply use `from mongoengine import connect`, or even
# `from mongoengine import *` and then `connect('testdb')`.
from mongoengine.connection import *
from mongoengine.connections_manager import *
from mongoengine.document import *
from mongoengine.errors import *
from mongoengine.fields import *
Expand All @@ -20,7 +22,8 @@

__all__ = (list(document.__all__) + list(fields.__all__) +
list(connection.__all__) + list(queryset.__all__) +
list(signals.__all__) + list(errors.__all__))
list(signals.__all__) + list(errors.__all__) +
list(connections_manager.__all__))


VERSION = (0, 11, 0)
Expand Down
5 changes: 5 additions & 0 deletions mongoengine/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SemiStrictDict, StrictDict)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
LookUpError, OperationError, ValidationError)

Expand Down Expand Up @@ -667,6 +668,10 @@ def _get_collection_name(cls):
"""
return cls._meta.get('collection', None)

@classmethod
def _get_db_alias(cls):
return cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)

@classmethod
def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
"""Create an instance of a Document (subclass) from a PyMongo
Expand Down
3 changes: 0 additions & 3 deletions mongoengine/base/metaclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,6 @@ def __new__(cls, name, bases, attrs):
(Document, EmbeddedDocument, DictField,
CachedReferenceField) = cls._import_classes()

if issubclass(new_class, Document):
new_class._collection = None

# Add class to the _document_registry
_document_registry[new_class._class_name] = new_class

Expand Down
110 changes: 110 additions & 0 deletions mongoengine/connections_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from collections import defaultdict

from mongoengine.connection import get_connection, get_db

__all__ = ['InvalidCollectionError', 'connection_manager']


class InvalidCollectionError(Exception):
pass


class ConnectionManager(object):
connections_registry = defaultdict(dict)

def get_and_setup(self, doc_cls, alias=None, collection_name=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could use a docstring, especially as a public method. What it doesn, when it should be used, etc.

if alias is None:
alias = doc_cls._get_db_alias()

if collection_name is None:
collection_name = doc_cls._get_collection_name()

_collection = self.connections_registry[alias].get(collection_name)
if not _collection:
_collection = self.get_collection(doc_cls, alias, collection_name)
if doc_cls._meta.get('auto_create_index', True):
doc_cls.ensure_indexes(_collection)
self.connections_registry[alias][collection_name] = _collection
return self.connections_registry[alias][collection_name]

@classmethod
def _get_db(cls, alias):
"""Some Model using other db_alias"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring here is pretty confusing. What is "Some Model"? Why is it using "other db_alias"? Why do we need a private method that only proxies a call to a public get_db?

return get_db(alias)

@classmethod
def get_collection(cls, doc_cls, alias=None, collection_name=None):
"""Returns the collection for the document."""

if alias is None:
alias = doc_cls._get_db_alias()

if collection_name is None:
collection_name = doc_cls._get_collection_name()

db = cls._get_db(alias=alias)

# Create collection as a capped collection if specified
if doc_cls._meta.get('max_size') or doc_cls._meta.get('max_documents'):
# Get max document limit and max byte size from meta
max_size = doc_cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
max_documents = doc_cls._meta.get('max_documents')
# Round up to next 256 bytes as MongoDB would do it to avoid exception
if max_size % 256:
max_size = (max_size // 256 + 1) * 256

if collection_name in db.collection_names():
_collection = db[collection_name]
# The collection already exists, check if its capped
# options match the specified capped options
options = _collection.options()
if (
options.get('max') != max_documents or
options.get('size') != max_size
):
msg = (('Cannot create collection "%s" as a capped '
'collection as it already exists')
% _collection)
raise InvalidCollectionError(msg)
else:
# Create the collection as a capped collection
opts = {'capped': True, 'size': max_size}
if max_documents:
opts['max'] = max_documents
_collection = db.create_collection(
collection_name, **opts
)
else:
_collection = db[collection_name]
return _collection

def drop_collection(self, doc_cls, alias, collection_name):
if alias is None:
alias = doc_cls._get_db_alias()

if collection_name is None:
collection_name = doc_cls._get_collection_name()

if not collection_name:
from mongoengine.queryset import OperationError
raise OperationError('Document %s has no collection defined '
'(is it abstract ?)' % doc_cls)

self.connections_registry[alias][collection_name] = None
db = self._get_db(alias=alias)
db.drop_collection(collection_name)

def drop_database(self, doc_cls, alias=None):
if alias is None:
alias = doc_cls._get_db_alias()

self.connections_registry[alias] = {}
db = self._get_db(alias=alias)
conn = get_connection(alias)
conn.drop_database(db)

def reset(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could use a docstring, too. Additionally, should this method perform any cleanup of the connections before dropping references to them? Will that leave lingering unclosed connections to the database?

self.connections_registry = defaultdict(dict)


connection_manager = ConnectionManager()
85 changes: 1 addition & 84 deletions mongoengine/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,93 +2,10 @@
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db


__all__ = ('switch_db', 'switch_collection', 'no_dereference',
__all__ = ('no_dereference',
'no_sub_classes', 'query_counter')


class switch_db(object):
"""switch_db alias context manager.

Example ::

# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')

class Group(Document):
name = StringField()

Group(name='test').save() # Saves in the default db

with switch_db(Group, 'testdb-1') as Group:
Group(name='hello testdb!').save() # Saves in testdb-1
"""

def __init__(self, cls, db_alias):
"""Construct the switch_db context manager

:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)

def __enter__(self):
"""Change the db_alias and clear the cached collection."""
self.cls._meta['db_alias'] = self.db_alias
self.cls._collection = None
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the db_alias and collection."""
self.cls._meta['db_alias'] = self.ori_db_alias
self.cls._collection = self.collection


class switch_collection(object):
"""switch_collection alias context manager.

Example ::

class Group(Document):
name = StringField()

Group(name='test').save() # Saves in the default db

with switch_collection(Group, 'group1') as Group:
Group(name='hello testdb!').save() # Saves in group1 collection
"""

def __init__(self, cls, collection_name):
"""Construct the switch_collection context manager.

:param cls: the class to change the registered db
:param collection_name: the name of the collection to use
"""
self.cls = cls
self.ori_collection = cls._get_collection()
self.ori_get_collection_name = cls._get_collection_name
self.collection_name = collection_name

def __enter__(self):
"""Change the _get_collection_name and clear the cached collection."""

@classmethod
def _get_collection_name(cls):
return self.collection_name

self.cls._get_collection_name = _get_collection_name
self.cls._collection = None
return self.cls

def __exit__(self, t, value, traceback):
"""Reset the collection."""
self.cls._collection = self.ori_collection
self.cls._get_collection_name = self.ori_get_collection_name


class no_dereference(object):
"""no_dereference context manager.

Expand Down
Loading