diff --git a/mongoengine/__init__.py b/mongoengine/__init__.py index f8969592b..b7cb5c33f 100644 --- a/mongoengine/__init__.py +++ b/mongoengine/__init__.py @@ -1,5 +1,6 @@ # Import submodules so that we can expose their __all__ from mongoengine import connection +from mongoengine import connections_manager from mongoengine import document from mongoengine import errors from mongoengine import fields @@ -11,6 +12,7 @@ # users can simply use `from mongoengine import connect`, or even # `from mongoengine import *` and then `connect('testdb')`. from mongoengine.connection import * +from mongoengine.connections_manager import * from mongoengine.document import * from mongoengine.errors import * from mongoengine.fields import * @@ -20,7 +22,8 @@ __all__ = (list(document.__all__) + list(fields.__all__) + list(connection.__all__) + list(queryset.__all__) + - list(signals.__all__) + list(errors.__all__)) + list(signals.__all__) + list(errors.__all__) + + list(connections_manager.__all__)) VERSION = (0, 11, 0) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 1667215d4..910be4ea7 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -16,6 +16,7 @@ SemiStrictDict, StrictDict) from mongoengine.base.fields import ComplexBaseField from mongoengine.common import _import_class +from mongoengine.connection import DEFAULT_CONNECTION_NAME from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, LookUpError, OperationError, ValidationError) @@ -667,6 +668,10 @@ def _get_collection_name(cls): """ return cls._meta.get('collection', None) + @classmethod + def _get_db_alias(cls): + return cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) + @classmethod def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False): """Create an instance of a Document (subclass) from a PyMongo diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index 481408bf0..bd528b4ba 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -149,9 +149,6 @@ def __new__(cls, name, bases, attrs): (Document, EmbeddedDocument, DictField, CachedReferenceField) = cls._import_classes() - if issubclass(new_class, Document): - new_class._collection = None - # Add class to the _document_registry _document_registry[new_class._class_name] = new_class diff --git a/mongoengine/connections_manager.py b/mongoengine/connections_manager.py new file mode 100644 index 000000000..578a862aa --- /dev/null +++ b/mongoengine/connections_manager.py @@ -0,0 +1,110 @@ +from collections import defaultdict + +from mongoengine.connection import get_connection, get_db + +__all__ = ['InvalidCollectionError', 'connection_manager'] + + +class InvalidCollectionError(Exception): + pass + + +class ConnectionManager(object): + connections_registry = defaultdict(dict) + + def get_and_setup(self, doc_cls, alias=None, collection_name=None): + if alias is None: + alias = doc_cls._get_db_alias() + + if collection_name is None: + collection_name = doc_cls._get_collection_name() + + _collection = self.connections_registry[alias].get(collection_name) + if not _collection: + _collection = self.get_collection(doc_cls, alias, collection_name) + if doc_cls._meta.get('auto_create_index', True): + doc_cls.ensure_indexes(_collection) + self.connections_registry[alias][collection_name] = _collection + return self.connections_registry[alias][collection_name] + + @classmethod + def _get_db(cls, alias): + """Some Model using other db_alias""" + return get_db(alias) + + @classmethod + def get_collection(cls, doc_cls, alias=None, collection_name=None): + """Returns the collection for the document.""" + + if alias is None: + alias = doc_cls._get_db_alias() + + if collection_name is None: + collection_name = doc_cls._get_collection_name() + + db = cls._get_db(alias=alias) + + # Create collection as a capped collection if specified + if doc_cls._meta.get('max_size') or doc_cls._meta.get('max_documents'): + # Get max document limit and max byte size from meta + max_size = doc_cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default + max_documents = doc_cls._meta.get('max_documents') + # Round up to next 256 bytes as MongoDB would do it to avoid exception + if max_size % 256: + max_size = (max_size // 256 + 1) * 256 + + if collection_name in db.collection_names(): + _collection = db[collection_name] + # The collection already exists, check if its capped + # options match the specified capped options + options = _collection.options() + if ( + options.get('max') != max_documents or + options.get('size') != max_size + ): + msg = (('Cannot create collection "%s" as a capped ' + 'collection as it already exists') + % _collection) + raise InvalidCollectionError(msg) + else: + # Create the collection as a capped collection + opts = {'capped': True, 'size': max_size} + if max_documents: + opts['max'] = max_documents + _collection = db.create_collection( + collection_name, **opts + ) + else: + _collection = db[collection_name] + return _collection + + def drop_collection(self, doc_cls, alias, collection_name): + if alias is None: + alias = doc_cls._get_db_alias() + + if collection_name is None: + collection_name = doc_cls._get_collection_name() + + if not collection_name: + from mongoengine.queryset import OperationError + raise OperationError('Document %s has no collection defined ' + '(is it abstract ?)' % doc_cls) + + self.connections_registry[alias][collection_name] = None + db = self._get_db(alias=alias) + db.drop_collection(collection_name) + + def drop_database(self, doc_cls, alias=None): + if alias is None: + alias = doc_cls._get_db_alias() + + self.connections_registry[alias] = {} + db = self._get_db(alias=alias) + conn = get_connection(alias) + conn.drop_database(db) + + def reset(self): + self.connections_registry = defaultdict(dict) + + +connection_manager = ConnectionManager() diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index c477575e8..71d151aa2 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -2,93 +2,10 @@ from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -__all__ = ('switch_db', 'switch_collection', 'no_dereference', +__all__ = ('no_dereference', 'no_sub_classes', 'query_counter') -class switch_db(object): - """switch_db alias context manager. - - Example :: - - # Register connections - register_connection('default', 'mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group(name='test').save() # Saves in the default db - - with switch_db(Group, 'testdb-1') as Group: - Group(name='hello testdb!').save() # Saves in testdb-1 - """ - - def __init__(self, cls, db_alias): - """Construct the switch_db context manager - - :param cls: the class to change the registered db - :param db_alias: the name of the specific database to use - """ - self.cls = cls - self.collection = cls._get_collection() - self.db_alias = db_alias - self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME) - - def __enter__(self): - """Change the db_alias and clear the cached collection.""" - self.cls._meta['db_alias'] = self.db_alias - self.cls._collection = None - return self.cls - - def __exit__(self, t, value, traceback): - """Reset the db_alias and collection.""" - self.cls._meta['db_alias'] = self.ori_db_alias - self.cls._collection = self.collection - - -class switch_collection(object): - """switch_collection alias context manager. - - Example :: - - class Group(Document): - name = StringField() - - Group(name='test').save() # Saves in the default db - - with switch_collection(Group, 'group1') as Group: - Group(name='hello testdb!').save() # Saves in group1 collection - """ - - def __init__(self, cls, collection_name): - """Construct the switch_collection context manager. - - :param cls: the class to change the registered db - :param collection_name: the name of the collection to use - """ - self.cls = cls - self.ori_collection = cls._get_collection() - self.ori_get_collection_name = cls._get_collection_name - self.collection_name = collection_name - - def __enter__(self): - """Change the _get_collection_name and clear the cached collection.""" - - @classmethod - def _get_collection_name(cls): - return self.collection_name - - self.cls._get_collection_name = _get_collection_name - self.cls._collection = None - return self.cls - - def __exit__(self, t, value, traceback): - """Reset the collection.""" - self.cls._collection = self.ori_collection - self.cls._get_collection_name = self.ori_get_collection_name - - class no_dereference(object): """no_dereference context manager. diff --git a/mongoengine/document.py b/mongoengine/document.py index e86a45d9b..74e20f2af 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -11,8 +11,7 @@ DocumentMetaclass, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document) from mongoengine.common import _import_class -from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db -from mongoengine.context_managers import switch_collection, switch_db +from mongoengine.connections_manager import connection_manager from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, SaveConditionError) from mongoengine.python_support import IS_PYMONGO_3 @@ -21,7 +20,7 @@ __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', - 'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument') + 'NotUniqueError', 'MapReduceDocument') def includes_cls(fields): @@ -35,10 +34,6 @@ def includes_cls(fields): return first_field == '_cls' -class InvalidCollectionError(Exception): - pass - - class EmbeddedDocument(BaseDocument): """A :class:`~mongoengine.Document` that isn't stored in its own collection. :class:`~mongoengine.EmbeddedDocument`\ s should be used as @@ -160,52 +155,6 @@ def pk(self, value): """Set the primary key.""" return setattr(self, self._meta['id_field'], value) - @classmethod - def _get_db(cls): - """Some Model using other db_alias""" - return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) - - @classmethod - def _get_collection(cls): - """Returns the collection for the document.""" - # TODO: use new get_collection() with PyMongo3 ? - if not hasattr(cls, '_collection') or cls._collection is None: - db = cls._get_db() - collection_name = cls._get_collection_name() - # Create collection as a capped collection if specified - if cls._meta.get('max_size') or cls._meta.get('max_documents'): - # Get max document limit and max byte size from meta - max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default - max_documents = cls._meta.get('max_documents') - # Round up to next 256 bytes as MongoDB would do it to avoid exception - if max_size % 256: - max_size = (max_size // 256 + 1) * 256 - - if collection_name in db.collection_names(): - cls._collection = db[collection_name] - # The collection already exists, check if its capped - # options match the specified capped options - options = cls._collection.options() - if options.get('max') != max_documents or \ - options.get('size') != max_size: - msg = (('Cannot create collection "%s" as a capped ' - 'collection as it already exists') - % cls._collection) - raise InvalidCollectionError(msg) - else: - # Create the collection as a capped collection - opts = {'capped': True, 'size': max_size} - if max_documents: - opts['max'] = max_documents - cls._collection = db.create_collection( - collection_name, **opts - ) - else: - cls._collection = db[collection_name] - if cls._meta.get('auto_create_index', True): - cls.ensure_indexes() - return cls._collection - def to_mongo(self, *args, **kwargs): data = super(Document, self).to_mongo(*args, **kwargs) @@ -261,7 +210,9 @@ def modify(self, query=None, **update): def save(self, force_insert=False, validate=True, clean=True, write_concern=None, cascade=None, cascade_kwargs=None, - _refs=None, save_condition=None, signal_kwargs=None, **kwargs): + _refs=None, save_condition=None, signal_kwargs=None, + alias=None, collection_name=None, + **kwargs): """Save the :class:`~mongoengine.Document` to the database. If the document already exists, it will be updated, otherwise it will be created. @@ -330,9 +281,7 @@ def save(self, force_insert=False, validate=True, clean=True, created=created, **signal_kwargs) try: - collection = self._get_collection() - if self._meta.get('auto_create_index', True): - self.ensure_indexes() + collection = connection_manager.get_and_setup(self.__class__, alias=alias, collection_name=collection_name) if created: if force_insert: object_id = collection.insert(doc, **write_concern) @@ -457,7 +406,7 @@ def cascade_save(self, **kwargs): def _qs(self): """Return the queryset to use for updating / reloading / deletions.""" if not hasattr(self, '__objects'): - self.__objects = QuerySet(self, self._get_collection()) + self.__objects = QuerySet(self, connection_manager.get_and_setup(self.__class__)) return self.__objects @property @@ -531,65 +480,6 @@ def delete(self, signal_kwargs=None, **write_concern): raise OperationError(message) signals.post_delete.send(self.__class__, document=self, **signal_kwargs) - def switch_db(self, db_alias, keep_created=True): - """ - Temporarily switch the database for a document instance. - - Only really useful for archiving off data and calling `save()`:: - - user = User.objects.get(id=user_id) - user.switch_db('archive-db') - user.save() - - :param str db_alias: The database alias to use for saving the document - - :param bool keep_created: keep self._created value after switching db, else is reset to True - - - .. seealso:: - Use :class:`~mongoengine.context_managers.switch_collection` - if you need to read from another collection - """ - with switch_db(self.__class__, db_alias) as cls: - collection = cls._get_collection() - db = cls._get_db() - self._get_collection = lambda: collection - self._get_db = lambda: db - self._collection = collection - self._created = True if not keep_created else self._created - self.__objects = self._qs - self.__objects._collection_obj = collection - return self - - def switch_collection(self, collection_name, keep_created=True): - """ - Temporarily switch the collection for a document instance. - - Only really useful for archiving off data and calling `save()`:: - - user = User.objects.get(id=user_id) - user.switch_collection('old-users') - user.save() - - :param str collection_name: The database alias to use for saving the - document - - :param bool keep_created: keep self._created value after switching collection, else is reset to True - - - .. seealso:: - Use :class:`~mongoengine.context_managers.switch_db` - if you need to read from another database - """ - with switch_collection(self.__class__, collection_name) as cls: - collection = cls._get_collection() - self._get_collection = lambda: collection - self._collection = collection - self._created = True if not keep_created else self._created - self.__objects = self._qs - self.__objects._collection_obj = collection - return self - def select_related(self, max_depth=1): """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. @@ -693,7 +583,7 @@ def register_delete_rule(cls, document_cls, field_name, rule): klass._meta['delete_rules'] = delete_rules @classmethod - def drop_collection(cls): + def drop_collection(cls, alias=None, collection_name=None): """Drops the entire collection associated with this :class:`~mongoengine.Document` type from the database. @@ -703,13 +593,17 @@ def drop_collection(cls): .. versionchanged:: 0.10.7 :class:`OperationError` exception raised if no collection available """ - col_name = cls._get_collection_name() - if not col_name: - raise OperationError('Document %s has no collection defined ' - '(is it abstract ?)' % cls) - cls._collection = None - db = cls._get_db() - db.drop_collection(col_name) + connection_manager.drop_collection(cls, alias=alias, collection_name=collection_name) + + @classmethod + def _get_collection(cls, alias=None, collection_name=None): + return connection_manager.get_and_setup(cls, alias=alias, collection_name=collection_name) + + @classmethod + def _get_db(cls, alias=None): + if alias is None: + alias = cls._get_db_alias() + return connection_manager._get_db(alias) @classmethod def create_index(cls, keys, background=False, **kwargs): @@ -733,9 +627,9 @@ def create_index(cls, keys, background=False, **kwargs): index_spec.update(kwargs) if IS_PYMONGO_3: - return cls._get_collection().create_index(fields, **index_spec) + return connection_manager.get_collection(cls).create_index(fields, **index_spec) else: - return cls._get_collection().ensure_index(fields, **index_spec) + return connection_manager.get_collection(cls).ensure_index(fields, **index_spec) @classmethod def ensure_index(cls, key_or_list, drop_dups=False, background=False, @@ -758,7 +652,7 @@ def ensure_index(cls, key_or_list, drop_dups=False, background=False, return cls.create_index(key_or_list, background=background, **kwargs) @classmethod - def ensure_indexes(cls): + def ensure_indexes(cls, collection): """Checks the document meta data and ensures all the indexes exist. Global defaults can be set in the meta - see :doc:`guide/defining-documents` @@ -774,7 +668,6 @@ def ensure_indexes(cls): msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning) - collection = cls._get_collection() # 746: when connection is via mongos, the read preference is not necessarily an indication that # this code runs on a secondary if not collection.is_mongos and collection.read_preference > 1: @@ -845,13 +738,13 @@ def get_classes(cls): if (isinstance(base_cls, TopLevelDocumentMetaclass) and base_cls != Document and not base_cls._meta.get('abstract') and - base_cls._get_collection().full_name == cls._get_collection().full_name and + connection_manager.get_collection(base_cls).full_name == connection_manager.get_collection(cls).full_name and base_cls not in classes): classes.append(base_cls) get_classes(base_cls) for subclass in cls.__subclasses__(): if (isinstance(base_cls, TopLevelDocumentMetaclass) and - subclass._get_collection().full_name == cls._get_collection().full_name and + connection_manager.get_collection(subclass).full_name == connection_manager.get_collection(cls).full_name and subclass not in classes): classes.append(subclass) get_classes(subclass) @@ -892,7 +785,7 @@ def compare_indexes(cls): required = cls.list_indexes() existing = [info['key'] - for info in cls._get_collection().index_information().values()] + for info in connection_manager.get_collection(cls).index_information().values()] missing = [index for index in required if index not in existing] extra = [index for index in existing if index not in required] diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 3ee978b8b..3df70567d 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -18,7 +18,6 @@ from mongoengine.base import get_document from mongoengine.common import _import_class from mongoengine.connection import get_db -from mongoengine.context_managers import switch_db from mongoengine.errors import (InvalidQueryError, LookUpError, NotUniqueError, OperationError) from mongoengine.python_support import IS_PYMONGO_3 @@ -448,7 +447,7 @@ def delete(self, write_concern=None, _from_doc_delete=False, if rule == CASCADE: cascade_refs = set() if cascade_refs is None else cascade_refs # Handle recursive reference - if doc._collection == document_cls._collection: + if doc._get_collection() == document_cls._get_collection(): for ref in queryset: cascade_refs.add(ref.id) refs = document_cls.objects(**{field_name + '__in': self, @@ -703,8 +702,7 @@ def using(self, alias): .. versionadded:: 0.9 """ - with switch_db(self._document, alias) as cls: - collection = cls._get_collection() + collection = self._document._get_collection(alias=alias) return self.clone_into(self.__class__(self._document, collection)) diff --git a/tests/document/class_methods.py b/tests/document/class_methods.py index dd3addb76..3531d9add 100644 --- a/tests/document/class_methods.py +++ b/tests/document/class_methods.py @@ -100,7 +100,7 @@ class BlogPost(Document): BlogPost.drop_collection() - BlogPost.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) BlogPost.ensure_index(['author', 'description']) @@ -137,8 +137,8 @@ class BlogPostWithTags(BlogPost): BlogPost.drop_collection() - BlogPost.ensure_indexes() - BlogPostWithTags.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) + BlogPostWithTags.ensure_indexes(BlogPost._get_collection()) self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) BlogPostWithTags.ensure_index(['author', 'tag_list']) @@ -179,9 +179,9 @@ class BlogPostWithCustomField(BlogPost): 'indexes': [('author', 'custom')] } - BlogPost.ensure_indexes() - BlogPostWithTags.ensure_indexes() - BlogPostWithCustomField.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) + BlogPostWithTags.ensure_indexes(BlogPost._get_collection()) + BlogPostWithCustomField.ensure_indexes(BlogPost._get_collection()) self.assertEqual(BlogPost.compare_indexes(), { 'missing': [], 'extra': [] }) self.assertEqual(BlogPostWithTags.compare_indexes(), { 'missing': [], 'extra': [] }) @@ -217,9 +217,9 @@ class BlogPostWithTagsAndExtraText(BlogPostWithTags): BlogPost.drop_collection() - BlogPost.ensure_indexes() - BlogPostWithTags.ensure_indexes() - BlogPostWithTagsAndExtraText.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) + BlogPostWithTags.ensure_indexes(BlogPost._get_collection()) + BlogPostWithTagsAndExtraText.ensure_indexes(BlogPost._get_collection()) self.assertEqual(BlogPost.list_indexes(), BlogPostWithTags.list_indexes()) diff --git a/tests/document/indexes.py b/tests/document/indexes.py index af93e7db2..89d3b11ea 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -31,7 +31,7 @@ class Person(Document): self.Person = Person def tearDown(self): - self.connection.drop_database(self.db) + connection_manager.drop_database(doc_cls=self.Person) def test_indexes_document(self): """Ensure that indexes are used when meta[indexes] is specified for @@ -64,7 +64,7 @@ class BlogPost(InheritFrom): {'fields': [('category', 1), ('addDate', -1)]}] self.assertEqual(expected_specs, BlogPost._meta['index_specs']) - BlogPost.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') self.assertEqual(len(info), 4) @@ -93,7 +93,7 @@ class BlogPost(InheritFrom): ('addDate', -1)]}] self.assertEqual(expected_specs, BlogPost._meta['index_specs']) - BlogPost.ensure_indexes() + BlogPost.ensure_indexes(BlogPost._get_collection()) info = BlogPost.objects._collection.index_information() # _id, '-date', 'tags', ('cat', 'date') # NB: there is no index on _cls by itself, since @@ -113,7 +113,7 @@ class ExtendedBlogPost(BlogPost): BlogPost.drop_collection() - ExtendedBlogPost.ensure_indexes() + ExtendedBlogPost.ensure_indexes(ExtendedBlogPost._get_collection()) info = ExtendedBlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] for expected in expected_specs: @@ -167,7 +167,7 @@ class A(Document): self.assertEqual([('title', 1)], A._meta['index_specs'][0]['fields']) A._get_collection().drop_indexes() - A.ensure_indexes() + A.ensure_indexes(A._get_collection()) info = A._get_collection().index_information() self.assertEqual(len(info.keys()), 2) @@ -195,7 +195,7 @@ class MyDoc(Document): [{'fields': [('keywords', 1)]}]) # Force index creation - MyDoc.ensure_indexes() + MyDoc.ensure_indexes(MyDoc._get_collection()) self.assertEqual(MyDoc._meta['index_specs'], [{'fields': [('keywords', 1)]}]) @@ -243,7 +243,7 @@ class Place(Document): self.assertEqual([{'fields': [('location.point', '2d')]}], Place._meta['index_specs']) - Place.ensure_indexes() + Place.ensure_indexes(Place._get_collection()) info = Place._get_collection().index_information() info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('location.point', '2d')] in info) @@ -266,7 +266,7 @@ class Place(Document): self.assertEqual([{'fields': [('current.location.point', '2d')]}], Place._meta['index_specs']) - Place.ensure_indexes() + Place.ensure_indexes(Place._get_collection()) info = Place._get_collection().index_information() info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('current.location.point', '2d')] in info) @@ -286,7 +286,7 @@ class Place(Document): self.assertEqual([{'fields': [('location.point', '2dsphere')]}], Place._meta['index_specs']) - Place.ensure_indexes() + Place.ensure_indexes(Place._get_collection()) info = Place._get_collection().index_information() info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('location.point', '2dsphere')] in info) @@ -308,7 +308,7 @@ class Place(Document): self.assertEqual([{'fields': [('location.point', 'geoHaystack'), ('name', 1)]}], Place._meta['index_specs']) - Place.ensure_indexes() + Place.ensure_indexes(Place._get_collection()) info = Place._get_collection().index_information() info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('location.point', 'geoHaystack')] in info) @@ -409,7 +409,7 @@ class MongoUser(User): info = User.objects._collection.index_information() self.assertEqual(info.keys(), ['_id_']) - User.ensure_indexes() + User.ensure_indexes(User._get_collection()) info = User.objects._collection.index_information() self.assertEqual(sorted(info.keys()), ['_cls_1_user_guid_1', '_id_']) User.drop_collection() @@ -472,7 +472,7 @@ class RecursiveDocument(Document): recursive_obj = EmbeddedDocumentField(RecursiveObject) meta = {'allow_inheritance': True} - RecursiveDocument.ensure_indexes() + RecursiveDocument.ensure_indexes(RecursiveDocument._get_collection()) info = RecursiveDocument._get_collection().index_information() self.assertEqual(sorted(info.keys()), ['_cls_1', '_id_']) @@ -806,7 +806,6 @@ class BlogPost(Document): 'unique': True}]} except UnboundLocalError: self.fail('Unbound local error at index + pk definition') - info = BlogPost.objects._collection.index_information() info = [value['key'] for key, value in info.iteritems()] index_item = [('_id', 1), ('comments.comment_id', 1)] @@ -924,8 +923,8 @@ class BlogPost(Document): post1 = BlogPost(title='test1', slug='test') post1.save() - # Drop the Database - connection.drop_database('tempdatabase') + # Drop the collection + BlogPost.drop_collection() # Re-create Post #1 post1 = BlogPost(title='test1', slug='test') @@ -969,8 +968,8 @@ class TestChildDoc(TestDoc): } TestDoc.drop_collection() - TestDoc.ensure_indexes() - TestChildDoc.ensure_indexes() + TestDoc.ensure_indexes(TestDoc._get_collection()) + TestChildDoc.ensure_indexes(TestChildDoc._get_collection()) index_info = TestDoc._get_collection().index_information() for key in index_info: @@ -1017,7 +1016,7 @@ class TestDoc(Document): } TestDoc.drop_collection() - TestDoc.ensure_indexes() + TestDoc.ensure_indexes(TestDoc._get_collection()) index_info = TestDoc._get_collection().index_information() self.assertTrue('shard_1_1__cls_1_txt_1_1' in index_info) diff --git a/tests/document/inheritance.py b/tests/document/inheritance.py index 2897e1d15..fe2f23795 100644 --- a/tests/document/inheritance.py +++ b/tests/document/inheritance.py @@ -215,7 +215,7 @@ class C(A, B): B.drop_collection() C.drop_collection() - C.ensure_indexes() + C.ensure_indexes(C._get_collection()) self.assertEqual( sorted([idx['key'] for idx in C._get_collection().index_information().values()]), diff --git a/tests/document/instance.py b/tests/document/instance.py index d961f034c..fbfb07f3f 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -19,7 +19,7 @@ InvalidQueryError, NotUniqueError, FieldDoesNotExist, SaveConditionError) from mongoengine.queryset import NULLIFY, Q -from mongoengine.context_managers import switch_db, query_counter +from mongoengine.context_managers import query_counter from mongoengine import signals TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), @@ -105,6 +105,8 @@ class Log(Document): 'max_documents': 11, } + connection_manager.reset() + # Accessing Document.objects creates the collection with self.assertRaises(InvalidCollectionError): Log.objects @@ -1312,6 +1314,8 @@ def test_update_unique_field(self): class Doc(Document): name = StringField(unique=True) + connection_manager.reset() + doc1 = Doc(name="first").save() doc2 = Doc(name="second").save() @@ -2603,56 +2607,6 @@ def __str__(self): })]), "1,2") - def test_switch_db_instance(self): - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group.drop_collection() - with switch_db(Group, 'testdb-1') as Group: - Group.drop_collection() - - Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) - - group = Group.objects.first() - group.switch_db('testdb-1') - group.name = "hello - testdb!" - group.save() - - with switch_db(Group, 'testdb-1') as Group: - group = Group.objects.first() - self.assertEqual("hello - testdb!", group.name) - - group = Group.objects.first() - self.assertEqual("hello - default", group.name) - - # Slightly contrived now - perform an update - # Only works as they have the same object_id - group.switch_db('testdb-1') - group.update(set__name="hello - update") - - with switch_db(Group, 'testdb-1') as Group: - group = Group.objects.first() - self.assertEqual("hello - update", group.name) - Group.drop_collection() - self.assertEqual(0, Group.objects.count()) - - group = Group.objects.first() - self.assertEqual("hello - default", group.name) - - # Totally contrived now - perform a delete - # Only works as they have the same object_id - group.switch_db('testdb-1') - group.delete() - - with switch_db(Group, 'testdb-1') as Group: - self.assertEqual(0, Group.objects.count()) - - group = Group.objects.first() - self.assertEqual("hello - default", group.name) - def test_load_undefined_fields(self): class User(Document): name = StringField() diff --git a/tests/fields/geo.py b/tests/fields/geo.py index 1c5bccc0b..f2414de10 100644 --- a/tests/fields/geo.py +++ b/tests/fields/geo.py @@ -356,7 +356,7 @@ class Log(Document): self.assertEqual([], Log._geo_indices()) Log.drop_collection() - Log.ensure_indexes() + Log.ensure_indexes(Log._get_collection()) info = Log._get_collection().index_information() self.assertEqual(info["location_2dsphere_datetime_1"]["key"], @@ -376,7 +376,7 @@ class Log(Document): self.assertEqual([], Log._geo_indices()) Log.drop_collection() - Log.ensure_indexes() + Log.ensure_indexes(Log._get_collection()) info = Log._get_collection().index_information() self.assertEqual(info["location_2dsphere_datetime_1"]["key"], diff --git a/tests/queryset/modify.py b/tests/queryset/modify.py index 607937f68..f7547cbfc 100644 --- a/tests/queryset/modify.py +++ b/tests/queryset/modify.py @@ -17,7 +17,7 @@ def setUp(self): Doc.drop_collection() def assertDbEqual(self, docs): - self.assertEqual(list(Doc._collection.find().sort("id")), docs) + self.assertEqual(list(Doc._get_collection().find().sort("id")), docs) def test_modify(self): Doc(id=0, value=0).save() diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index a15807a5b..a4499915a 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -13,7 +13,7 @@ from mongoengine import * from mongoengine.connection import get_connection, get_db -from mongoengine.context_managers import query_counter, switch_db +from mongoengine.context_managers import query_counter from mongoengine.errors import InvalidQueryError from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, @@ -810,7 +810,7 @@ class Blog(Document): self.assertEqual(q, 99) Blog.drop_collection() - Blog.ensure_indexes() + Blog.ensure_indexes(Blog._get_collection()) with query_counter() as q: self.assertEqual(q, 0) @@ -3709,15 +3709,14 @@ class Number2(Document): n = IntField() Number2.drop_collection() - with switch_db(Number2, 'test2') as Number2: - Number2.drop_collection() + Number2.drop_collection(alias='test2') for i in range(1, 10): t = Number2(n=i) - t.switch_db('test2') - t.save() + t.save(alias='test2') self.assertEqual(len(Number2.objects.using('test2')), 9) + self.assertEqual(len(Number2.objects()), 0) def test_unset_reference(self): class Comment(Document): @@ -4179,7 +4178,7 @@ class Test(Document): Test.drop_collection() Test.objects(test='foo').update_one(upsert=True, set__test='foo') - self.assertFalse('_cls' in Test._collection.find_one()) + self.assertFalse('_cls' in Test._get_collection().find_one()) class Test(Document): meta = {'allow_inheritance': True} @@ -4188,7 +4187,7 @@ class Test(Document): Test.drop_collection() Test.objects(test='foo').update_one(upsert=True, set__test='foo') - self.assertTrue('_cls' in Test._collection.find_one()) + self.assertTrue('_cls' in Test._get_collection().find_one()) def test_update_upsert_looks_like_a_digit(self): class MyDoc(DynamicDocument): diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 0f6bf815e..2540ec1a0 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -2,65 +2,12 @@ from mongoengine import * from mongoengine.connection import get_db -from mongoengine.context_managers import (switch_db, switch_collection, - no_sub_classes, no_dereference, - query_counter) +from mongoengine.context_managers import ( + no_sub_classes, no_dereference, query_counter) class ContextManagersTest(unittest.TestCase): - def test_switch_db_context_manager(self): - connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group.drop_collection() - - Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) - - with switch_db(Group, 'testdb-1') as Group: - - self.assertEqual(0, Group.objects.count()) - - Group(name="hello").save() - - self.assertEqual(1, Group.objects.count()) - - Group.drop_collection() - self.assertEqual(0, Group.objects.count()) - - self.assertEqual(1, Group.objects.count()) - - def test_switch_collection_context_manager(self): - connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group.drop_collection() - with switch_collection(Group, 'group1') as Group: - Group.drop_collection() - - Group(name="hello - group").save() - self.assertEqual(1, Group.objects.count()) - - with switch_collection(Group, 'group1') as Group: - - self.assertEqual(0, Group.objects.count()) - - Group(name="hello - group1").save() - - self.assertEqual(1, Group.objects.count()) - - Group.drop_collection() - self.assertEqual(0, Group.objects.count()) - - self.assertEqual(1, Group.objects.count()) - def test_no_dereference_context_manager_object_id(self): """Ensure that DBRef items in ListFields aren't dereferenced. """ diff --git a/tests/test_signals.py b/tests/test_signals.py index df687d0ee..17dbc6e98 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -388,30 +388,17 @@ def test_signals_with_explicit_doc_ids(self): def test_signals_with_switch_collection(self): ei = self.ExplicitId(id=123) - ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) - ei.switch_collection("explicit__1") - self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) - - ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) - ei.switch_collection("explicit__1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save, collection_name='explicit__1'), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save, collection_name='explicit__1'), ['Is updated']) def test_signals_with_switch_db(self): connect('mongoenginetest') register_connection('testdb-1', 'mongoenginetest2') ei = self.ExplicitId(id=123) - ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) - ei.switch_db("testdb-1") - self.assertEqual(self.get_signal_output(ei.save), ['Is updated']) + self.assertEqual(self.get_signal_output(ei.save, alias='testdb-1'), ['Is created']) + self.assertEqual(self.get_signal_output(ei.save, alias='testdb-1'), ['Is updated']) - ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) - ei.switch_db("testdb-1", keep_created=False) - self.assertEqual(self.get_signal_output(ei.save), ['Is created']) def test_signals_bulk_insert(self): def bulk_set_active_post():