3
3
from django .core .exceptions import FieldDoesNotExist
4
4
from django .db import models
5
5
from django .db .models .expressions import Col
6
- from django .db .models .lookups import Transform
6
+ from django .db .models .lookups import Lookup , Transform
7
7
8
8
from ..forms import EmbeddedModelArrayFormField
9
9
from ..query_utils import process_lhs , process_rhs
10
10
from . import EmbeddedModelField
11
11
from .array import ArrayField
12
- from .embedded_model import EMFExact
12
+ from .embedded_model import EMFExact , EMFMixin
13
13
14
14
15
15
class EmbeddedModelArrayField (ArrayField ):
@@ -52,17 +52,8 @@ def get_transform(self, name):
52
52
return KeyTransformFactory (name , self )
53
53
54
54
55
- class ProcessRHSMixin :
56
- def process_rhs (self , compiler , connection ):
57
- if isinstance (self .lhs , KeyTransform ):
58
- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
59
- else :
60
- get_db_prep_value = self .lhs .output_field .get_db_prep_value
61
- return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
62
-
63
-
64
55
@EmbeddedModelArrayField .register_lookup
65
- class EMFArrayExact (EMFExact , ProcessRHSMixin ):
56
+ class EMFArrayExact (EMFExact ):
66
57
def as_mql (self , compiler , connection ):
67
58
lhs_mql = process_lhs (self , compiler , connection )
68
59
value = process_rhs (self , compiler , connection )
@@ -105,12 +96,29 @@ def as_mql(self, compiler, connection):
105
96
106
97
107
98
@EmbeddedModelArrayField .register_lookup
108
- class ArrayOverlap (EMFExact , ProcessRHSMixin ):
99
+ class ArrayOverlap (EMFMixin , Lookup ):
109
100
lookup_name = "overlap"
101
+ get_db_prep_lookup_value_is_iterable = True
102
+
103
+ def process_rhs (self , compiler , connection ):
104
+ values = self .rhs
105
+ if self .get_db_prep_lookup_value_is_iterable :
106
+ values = [values ]
107
+ # Compute how to serialize each value based on the query target.
108
+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
109
+ # field of the subfield. Otherwise, use the base field of the array itself.
110
+ if isinstance (self .lhs , KeyTransform ):
111
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
112
+ else :
113
+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
114
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
110
115
111
116
def as_mql (self , compiler , connection ):
112
117
lhs_mql = process_lhs (self , compiler , connection )
113
118
values = process_rhs (self , compiler , connection )
119
+ # Querying a subfield within the array elements (via nested KeyTransform).
120
+ # Replicates MongoDB's implicit ANY-match by mapping over the array and applying
121
+ # `$in` on the subfield.
114
122
if isinstance (self .lhs , KeyTransform ):
115
123
lhs_mql , inner_lhs_mql = lhs_mql
116
124
return {
@@ -129,11 +137,12 @@ def as_mql(self, compiler, connection):
129
137
}
130
138
conditions = []
131
139
inner_lhs_mql = "$$item"
140
+ # Querying full embedded documents in the array.
141
+ # Builds `$or` conditions and maps them over the array to match any full document.
132
142
for value in values :
133
- if isinstance (value , models .Model ):
134
- value , emf_data = self .model_to_dict (value )
135
- # Get conditions for any nested EmbeddedModelFields.
136
- conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
143
+ value , emf_data = self .model_to_dict (value )
144
+ # Get conditions for any nested EmbeddedModelFields.
145
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
137
146
return {
138
147
"$anyElementTrue" : {
139
148
"$ifNull" : [
0 commit comments