Skip to content

Commit 8b57307

Browse files
committed
some issues triaged
1 parent e26ee4f commit 8b57307

File tree

4 files changed

+208
-15
lines changed

4 files changed

+208
-15
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import difflib
2+
3+
from django.core.exceptions import FieldDoesNotExist
4+
from django.db.models import Field, lookups
5+
from django.db.models.expressions import Col
6+
from django.db.models.lookups import Lookup, Transform
7+
8+
from .. import forms
9+
from ..query_utils import process_lhs, process_rhs
10+
from . import EmbeddedModelField
11+
from .array import ArrayField
12+
13+
14+
class EmbeddedModelArrayField(ArrayField):
15+
ALLOWED_LOOKUPS = {"exact", "len", "overlap"}
16+
17+
def __init__(self, embedded_model, **kwargs):
18+
if "size" in kwargs:
19+
raise ValueError("EmbeddedModelArrayField does not support size.")
20+
super().__init__(EmbeddedModelField(embedded_model), **kwargs)
21+
self.embedded_model = embedded_model
22+
23+
def deconstruct(self):
24+
name, path, args, kwargs = super().deconstruct()
25+
if path == "django_mongodb_backend.fields.embedded_model_array.EmbeddedModelArrayField":
26+
path = "django_mongodb_backend.fields.EmbeddedModelArrayField"
27+
kwargs["embedded_model"] = self.embedded_model
28+
del kwargs["base_field"]
29+
return name, path, args, kwargs
30+
31+
def get_db_prep_value(self, value, connection, prepared=False):
32+
if isinstance(value, list | tuple):
33+
# Must call get_db_prep_save() rather than get_db_prep_value()
34+
# to transform model instances to dicts.
35+
return [self.base_field.get_db_prep_save(i, connection) for i in value]
36+
if value is not None:
37+
raise TypeError(
38+
f"Expected list of {self.embedded_model!r} instances, not {type(value)!r}."
39+
)
40+
return value
41+
42+
def formfield(self, **kwargs):
43+
# Skip ArrayField.formfield() which has some differeences, including
44+
# unneeded "base_field" and "max_length" instead of "max_num".
45+
return Field.formfield(
46+
self,
47+
**{
48+
"form_class": forms.EmbeddedModelArrayField,
49+
"model": self.base_field.embedded_model,
50+
"max_num": self.max_size,
51+
"prefix": self.name,
52+
**kwargs,
53+
},
54+
)
55+
56+
def get_transform(self, name):
57+
transform = super().get_transform(name)
58+
if transform:
59+
return transform
60+
return KeyTransformFactory(name, self)
61+
62+
def get_lookup(self, name):
63+
return super().get_lookup(name) if name in self.ALLOWED_LOOKUPS else None
64+
65+
66+
class EMFArrayRHSMixin:
67+
def check_lhs(self):
68+
if not isinstance(self.lhs, KeyTransform):
69+
raise ValueError(
70+
"Cannot apply this lookup directly to EmbeddedModelArrayField. "
71+
"Try querying one of its embedded fields instead."
72+
)
73+
74+
def process_rhs(self, compiler, connection):
75+
values = self.rhs
76+
# Value must be serealized based on the query target.
77+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
78+
# field of the subfield. Otherwise, use the base field of the array itself.
79+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
80+
return None, [get_db_prep_value(values, connection, prepared=True)]
81+
82+
83+
@EmbeddedModelArrayField.register_lookup
84+
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
85+
def as_mql(self, compiler, connection):
86+
self.check_lhs()
87+
lhs_mql, inner_lhs_mql = process_lhs(self, compiler, connection)
88+
value = process_rhs(self, compiler, connection)
89+
return {
90+
"$anyElementTrue": {
91+
"$ifNull": [
92+
{
93+
"$map": {
94+
"input": lhs_mql,
95+
"as": "item",
96+
"in": {"$eq": [inner_lhs_mql, value]},
97+
}
98+
},
99+
[],
100+
]
101+
}
102+
}
103+
104+
105+
@EmbeddedModelArrayField.register_lookup
106+
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
107+
lookup_name = "overlap"
108+
109+
def as_mql(self, compiler, connection):
110+
self.check_lhs()
111+
# Querying a subfield within the array elements (via nested KeyTransform).
112+
# Replicates MongoDB's implicit ANY-match by mapping over the array and applying
113+
# `$in` on the subfield.
114+
lhs_mql = process_lhs(self, compiler, connection)
115+
values = process_rhs(self, compiler, connection)
116+
lhs_mql, inner_lhs_mql = lhs_mql
117+
return {
118+
"$anyElementTrue": {
119+
"$ifNull": [
120+
{
121+
"$map": {
122+
"input": lhs_mql,
123+
"as": "item",
124+
"in": {"$in": [inner_lhs_mql, values]},
125+
}
126+
},
127+
[],
128+
]
129+
}
130+
}
131+
132+
133+
class KeyTransform(Transform):
134+
def __init__(self, key_name, array_field, *args, **kwargs):
135+
super().__init__(*args, **kwargs)
136+
self.array_field = array_field
137+
self.key_name = key_name
138+
# The iteration items begins from the base_field, a virtual column with
139+
# base field output type is created.
140+
column_target = array_field.embedded_model._meta.get_field(key_name).clone()
141+
column_name = f"$item.{key_name}"
142+
column_target.db_column = column_name
143+
column_target.set_attributes_from_name(column_name)
144+
self._lhs = Col(None, column_target)
145+
self._sub_transform = None
146+
147+
def __call__(self, this, *args, **kwargs):
148+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
149+
return self
150+
151+
def get_lookup(self, name):
152+
return self.output_field.get_lookup(name)
153+
154+
def get_transform(self, name):
155+
"""
156+
Validate that `name` is either a field of an embedded model or a
157+
lookup on an embedded model's field.
158+
"""
159+
# Once the sub lhs is a transform, all the filter are applied over it.
160+
# Otherwise get transform from EMF.
161+
if transform := self._lhs.get_transform(name):
162+
self._sub_transform = transform
163+
return self
164+
output_field = self._lhs.output_field
165+
suggested_lookups = difflib.get_close_matches(name, output_field.get_lookups())
166+
if suggested_lookups:
167+
suggested_lookups = " or ".join(suggested_lookups)
168+
suggestion = f", perhaps you meant {suggested_lookups}?"
169+
else:
170+
suggestion = ""
171+
raise FieldDoesNotExist(
172+
f"Unsupported lookup '{name}' for "
173+
f"EmbeddedModelArrayField of '{output_field.__class__.__name__}'"
174+
f"{suggestion}"
175+
)
176+
177+
def as_mql(self, compiler, connection):
178+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
179+
lhs_mql = process_lhs(self, compiler, connection)
180+
return lhs_mql, inner_lhs_mql
181+
182+
@property
183+
def output_field(self):
184+
return self.array_field
185+
186+
187+
class KeyTransformFactory:
188+
def __init__(self, key_name, base_field):
189+
self.key_name = key_name
190+
self.base_field = base_field
191+
192+
def __call__(self, *args, **kwargs):
193+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from .functions import register_functions
2+
from .lookups import register_lookups
23

34
register_functions()
5+
register_lookups()

django_mongodb_backend_gis/features.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,10 @@ def django_test_expected_failures(self):
1616
# SRIDs aren't supported.
1717
"gis_tests.geogapp.tests.GeographyTest.test05_geography_layermapping",
1818
"gis_tests.geoapp.tests.GeoModelTest.test_proxy",
19-
# 'WithinLookup' object has no attribute 'as_mql'
20-
# "gis_tests.relatedapp.tests.RelatedGeoModelTest.test10_combine",
2119
# GEOSException: Calling transform() with no SRID set is not supported
2220
"gis_tests.relatedapp.tests.RelatedGeoModelTest.test06_f_expressions",
2321
# 'Adapter' object has no attribute 'srid'
2422
"gis_tests.geoapp.test_expressions.GeoExpressionsTests.test_geometry_value_annotation",
25-
# To triage:
26-
"gis_tests.geoapp.test_expressions.GeoExpressionsTests.test_multiple_annotation",
27-
# 'ContainsLookup' object has no attribute 'as_mql'.
28-
"gis_tests.geoapp.test_regress.GeoRegressionTests.test_empty_count",
29-
"gis_tests.geoapp.tests.GeoLookupTest.test_contains_contained_lookups",
3023
# Object of type ObjectId is not JSON serializable
3124
"gis_tests.geoapp.test_serializers.GeoJSONSerializerTests.test_fields_option",
3225
# LinearRing requires at least 4 points, got 1.
@@ -36,14 +29,9 @@ def django_test_expected_failures(self):
3629
"gis_tests.geoapp.test_serializers.GeoJSONSerializerTests.test_id_field_option",
3730
"gis_tests.geoapp.test_serializers.GeoJSONSerializerTests.test_serialization_base",
3831
"gis_tests.geoapp.test_serializers.GeoJSONSerializerTests.test_srid_option",
39-
# 'DisjointLookup' object has no attribute 'as_mql'
40-
"gis_tests.geoapp.tests.GeoLookupTest.test_disjoint_lookup",
41-
# 'SameAsLookup' object has no attribute 'as_mql'
42-
"gis_tests.geoapp.tests.GeoLookupTest.test_equals_lookups",
43-
# 'WithinLookup' object has no attribute 'as_mql'
44-
"gis_tests.geoapp.tests.GeoLookupTest.test_subquery_annotation",
45-
"gis_tests.geoapp.tests.GeoQuerySetTest.test_within_subquery",
46-
# issubclass() arg 1 must be a class
32+
# GeometryField is not supported (the type of Geometry isn't
33+
# stored so that it can be initialized by the database converter).
34+
# Error in database converter: issubclass() arg 1 must be a class
4735
"gis_tests.geoapp.tests.GeoModelTest.test_geometryfield",
4836
# KeyError: 'within' connection.ops.gis_operators[self.lookup_name]
4937
"gis_tests.geoapp.tests.GeoModelTest.test_gis_query_as_string",

django_mongodb_backend_gis/lookups.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from django.contrib.gis.db.models.lookups import GISLookup
2+
from django.db import NotSupportedError
3+
4+
5+
def gis_lookup(self, compiler, connection): # noqa: ARG001
6+
raise NotSupportedError(f"MongoDB does not support the {self.lookup_name} lookup.")
7+
8+
9+
def register_lookups():
10+
GISLookup.as_mql = gis_lookup

0 commit comments

Comments
 (0)