Skip to content

Commit 82647e2

Browse files
committed
chore: Add types from mongo-types
1 parent 4d3ab60 commit 82647e2

40 files changed

+1941
-253
lines changed

.github/workflows/github-actions.yml

+12
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,18 @@ jobs:
4949
- run: bash .github/workflows/install_ci_python_dep.sh
5050
- run: pre-commit run -a
5151

52+
type-check:
53+
# Can be moved to pre-commit, separate step for now.
54+
runs-on: ubuntu-latest
55+
steps:
56+
- uses: actions/checkout@v3
57+
- uses: actions/setup-python@v4
58+
with:
59+
python-version: '3.9'
60+
check-latest: true
61+
- run: bash .github/workflows/install_ci_python_typing_deps.sh
62+
- run: mypy mongoengine tests
63+
5264
test:
5365
# Test suite run against recent python versions
5466
# and against a few combination of MongoDB and pymongo
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
pip install --upgrade pip
3+
pip install mypy==1.13.0 typing-extensions mongomock types-Pygments types-cffi types-colorama types-pyOpenSSL types-python-dateutil types-requests types-setuptools
4+
pip install -e '.[test]'

mongoengine/__init__.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
from mongoengine.signals import * # noqa: F401
2121

2222
__all__ = (
23-
list(document.__all__)
24-
+ list(fields.__all__)
25-
+ list(connection.__all__)
26-
+ list(queryset.__all__)
27-
+ list(signals.__all__)
28-
+ list(errors.__all__)
23+
document.__all__
24+
+ fields.__all__
25+
+ connection.__all__
26+
+ queryset.__all__
27+
+ signals.__all__
28+
+ errors.__all__
2929
)
3030

3131

mongoengine/_typing.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from typing import TYPE_CHECKING, TypeVar
2+
3+
if TYPE_CHECKING:
4+
from mongoengine.queryset.queryset import QuerySet
5+
6+
QS = TypeVar("QS", bound="QuerySet")

mongoengine/base/common.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
from __future__ import annotations
2+
13
import warnings
4+
from typing import TYPE_CHECKING
25

36
from mongoengine.errors import NotRegistered
47

5-
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")
8+
if TYPE_CHECKING:
9+
from mongoengine.document import Document
610

11+
__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry")
712

813
UPDATE_OPERATORS = {
914
"set",
@@ -24,7 +29,7 @@
2429
}
2530

2631

27-
_document_registry = {}
32+
_document_registry: dict[str, type[Document]] = {}
2833

2934

3035
class _DocumentRegistry:
@@ -33,7 +38,7 @@ class _DocumentRegistry:
3338
"""
3439

3540
@staticmethod
36-
def get(name):
41+
def get(name: str) -> type[Document]:
3742
doc = _document_registry.get(name, None)
3843
if not doc:
3944
# Possible old style name
@@ -58,7 +63,7 @@ def get(name):
5863
return doc
5964

6065
@staticmethod
61-
def register(DocCls):
66+
def register(DocCls: type[Document]) -> None:
6267
ExistingDocCls = _document_registry.get(DocCls._class_name)
6368
if (
6469
ExistingDocCls is not None
@@ -76,7 +81,7 @@ def register(DocCls):
7681
_document_registry[DocCls._class_name] = DocCls
7782

7883
@staticmethod
79-
def unregister(doc_cls_name):
84+
def unregister(doc_cls_name: str):
8085
_document_registry.pop(doc_cls_name)
8186

8287

mongoengine/base/datastructures.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1+
from __future__ import annotations
2+
13
import weakref
4+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
25

3-
from bson import DBRef
6+
from bson import DBRef, ObjectId
47

58
from mongoengine.common import _import_class
69
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
710

11+
if TYPE_CHECKING:
12+
from mongoengine import Document
13+
14+
_T = TypeVar("_T", bound="Document")
15+
816
__all__ = (
917
"BaseDict",
1018
"StrictDict",
@@ -356,7 +364,7 @@ def update(self, **update):
356364
class StrictDict:
357365
__slots__ = ()
358366
_special_fields = {"get", "pop", "iteritems", "items", "keys", "create"}
359-
_classes = {}
367+
_classes: dict[str, Any] = {}
360368

361369
def __init__(self, **kwargs):
362370
for k, v in kwargs.items():
@@ -435,7 +443,7 @@ def __repr__(self):
435443
return cls._classes[allowed_keys]
436444

437445

438-
class LazyReference(DBRef):
446+
class LazyReference(Generic[_T], DBRef):
439447
__slots__ = ("_cached_doc", "passthrough", "document_type")
440448

441449
def fetch(self, force=False):
@@ -449,19 +457,21 @@ def fetch(self, force=False):
449457
def pk(self):
450458
return self.id
451459

452-
def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
460+
def __init__(
461+
self, document_type: type[_T], pk: ObjectId, cached_doc=None, passthrough=False
462+
):
453463
self.document_type = document_type
454464
self._cached_doc = cached_doc
455465
self.passthrough = passthrough
456-
super().__init__(self.document_type._get_collection_name(), pk)
466+
super().__init__(self.document_type._get_collection_name(), pk) # type: ignore[arg-type]
457467

458-
def __getitem__(self, name):
468+
def __getitem__(self, name: str) -> Any:
459469
if not self.passthrough:
460470
raise KeyError()
461471
document = self.fetch()
462472
return document[name]
463473

464-
def __getattr__(self, name):
474+
def __getattr__(self, name: str) -> Any:
465475
if not object.__getattribute__(self, "passthrough"):
466476
raise AttributeError()
467477
document = self.fetch()

mongoengine/base/document.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
# mypy: disable-error-code="attr-defined,union-attr,assignment"
2+
from __future__ import annotations
3+
14
import copy
25
import numbers
36
import warnings
47
from functools import partial
8+
from typing import TYPE_CHECKING, Any
59

610
import pymongo
711
from bson import SON, DBRef, ObjectId, json_util
12+
from typing_extensions import Self
813

914
from mongoengine import signals
1015
from mongoengine.base.common import _DocumentRegistry
@@ -15,7 +20,7 @@
1520
LazyReference,
1621
StrictDict,
1722
)
18-
from mongoengine.base.fields import ComplexBaseField
23+
from mongoengine.base.fields import BaseField, ComplexBaseField
1924
from mongoengine.common import _import_class
2025
from mongoengine.errors import (
2126
FieldDoesNotExist,
@@ -26,12 +31,15 @@
2631
)
2732
from mongoengine.pymongo_support import LEGACY_JSON_OPTIONS
2833

34+
if TYPE_CHECKING:
35+
from mongoengine.fields import DynamicField
36+
2937
__all__ = ("BaseDocument", "NON_FIELD_ERRORS")
3038

3139
NON_FIELD_ERRORS = "__all__"
3240

3341
try:
34-
GEOHAYSTACK = pymongo.GEOHAYSTACK
42+
GEOHAYSTACK = pymongo.GEOHAYSTACK # type: ignore[attr-defined]
3543
except AttributeError:
3644
GEOHAYSTACK = None
3745

@@ -62,7 +70,12 @@ class BaseDocument:
6270
_dynamic_lock = True
6371
STRICT = False
6472

65-
def __init__(self, *args, **values):
73+
# Fields, added by metaclass
74+
_class_name: str
75+
_fields: dict[str, BaseField]
76+
_meta: dict[str, Any]
77+
78+
def __init__(self, *args, **values) -> None:
6679
"""
6780
Initialise a document or an embedded document.
6881
@@ -103,7 +116,7 @@ def __init__(self, *args, **values):
103116
else:
104117
self._data = {}
105118

106-
self._dynamic_fields = SON()
119+
self._dynamic_fields: SON[str, DynamicField] = SON()
107120

108121
# Assign default values for fields
109122
# not set in the constructor
@@ -329,13 +342,15 @@ def get_text_score(self):
329342

330343
return self._data["_text_score"]
331344

332-
def to_mongo(self, use_db_field=True, fields=None):
345+
def to_mongo(
346+
self, use_db_field: bool = True, fields: list[str] | None = None
347+
) -> SON[Any, Any]:
333348
"""
334349
Return as SON data ready for use with MongoDB.
335350
"""
336351
fields = fields or []
337352

338-
data = SON()
353+
data: SON[str, Any] = SON()
339354
data["_id"] = None
340355
data["_cls"] = self._class_name
341356

@@ -354,7 +369,7 @@ def to_mongo(self, use_db_field=True, fields=None):
354369

355370
if value is not None:
356371
f_inputs = field.to_mongo.__code__.co_varnames
357-
ex_vars = {}
372+
ex_vars: dict[str, Any] = {}
358373
if fields and "fields" in f_inputs:
359374
key = "%s." % field_name
360375
embedded_fields = [
@@ -370,7 +385,7 @@ def to_mongo(self, use_db_field=True, fields=None):
370385

371386
# Handle self generating fields
372387
if value is None and field._auto_gen:
373-
value = field.generate()
388+
value = field.generate() # type: ignore[attr-defined]
374389
self._data[field_name] = value
375390

376391
if value is not None or field.null:
@@ -385,7 +400,7 @@ def to_mongo(self, use_db_field=True, fields=None):
385400

386401
return data
387402

388-
def validate(self, clean=True):
403+
def validate(self, clean: bool = True) -> None:
389404
"""Ensure that all fields' values are valid and that required fields
390405
are present.
391406
@@ -439,7 +454,7 @@ def validate(self, clean=True):
439454
message = f"ValidationError ({self._class_name}:{pk}) "
440455
raise ValidationError(message, errors=errors)
441456

442-
def to_json(self, *args, **kwargs):
457+
def to_json(self, *args: Any, **kwargs: Any) -> str:
443458
"""Convert this document to JSON.
444459
445460
:param use_db_field: Serialize field names as they appear in
@@ -461,7 +476,7 @@ def to_json(self, *args, **kwargs):
461476
return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)
462477

463478
@classmethod
464-
def from_json(cls, json_data, created=False, **kwargs):
479+
def from_json(cls, json_data: str, created: bool = False, **kwargs: Any) -> Self:
465480
"""Converts json data to a Document instance.
466481
467482
:param str json_data: The json data to load into the Document.
@@ -687,7 +702,7 @@ def _get_changed_fields(self):
687702
self._nestable_types_changed_fields(changed_fields, key, data)
688703
return changed_fields
689704

690-
def _delta(self):
705+
def _delta(self) -> tuple[dict[str, Any], dict[str, Any]]:
691706
"""Returns the delta (set, unset) of the changes for a document.
692707
Gets any values that have been explicitly changed.
693708
"""
@@ -771,14 +786,16 @@ def _delta(self):
771786
return set_data, unset_data
772787

773788
@classmethod
774-
def _get_collection_name(cls):
789+
def _get_collection_name(cls) -> str | None:
775790
"""Return the collection name for this class. None for abstract
776791
class.
777792
"""
778793
return cls._meta.get("collection", None)
779794

780795
@classmethod
781-
def _from_son(cls, son, _auto_dereference=True, created=False):
796+
def _from_son(
797+
cls, son: dict[str, Any], _auto_dereference: bool = True, created: bool = False
798+
) -> Self:
782799
"""Create an instance of a Document (subclass) from a PyMongo SON (dict)"""
783800
if son and not isinstance(son, dict):
784801
raise ValueError(

0 commit comments

Comments
 (0)