@@ -49,11 +49,20 @@ def get_transform(self, name):
49
49
transform = super ().get_transform (name )
50
50
if transform :
51
51
return transform
52
- return KeyTransformFactory (name , self .base_field )
52
+ return KeyTransformFactory (name , self )
53
+
54
+
55
+ class ProcessRHSMixin :
56
+ def process_rhs (self , compiler , connection ):
57
+ if isinstance (self .lhs , KeyTransform ):
58
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
59
+ else :
60
+ get_db_prep_value = self .lhs .output_field .get_db_prep_value
61
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in self .rhs ]
53
62
54
63
55
64
@EmbeddedModelArrayField .register_lookup
56
- class EMFArrayExact (EMFExact ):
65
+ class EMFArrayExact (EMFExact , ProcessRHSMixin ):
57
66
def as_mql (self , compiler , connection ):
58
67
lhs_mql = process_lhs (self , compiler , connection )
59
68
value = process_rhs (self , compiler , connection )
@@ -95,15 +104,61 @@ def as_mql(self, compiler, connection):
95
104
}
96
105
97
106
107
+ @EmbeddedModelArrayField .register_lookup
108
+ class ArrayOverlap (EMFExact , ProcessRHSMixin ):
109
+ lookup_name = "overlap"
110
+
111
+ def as_mql (self , compiler , connection ):
112
+ lhs_mql = process_lhs (self , compiler , connection )
113
+ values = process_rhs (self , compiler , connection )
114
+ if isinstance (self .lhs , KeyTransform ):
115
+ lhs_mql , inner_lhs_mql = lhs_mql
116
+ return {
117
+ "$anyElementTrue" : {
118
+ "$ifNull" : [
119
+ {
120
+ "$map" : {
121
+ "input" : lhs_mql ,
122
+ "as" : "item" ,
123
+ "in" : {"$in" : [inner_lhs_mql , values ]},
124
+ }
125
+ },
126
+ [],
127
+ ]
128
+ }
129
+ }
130
+ conditions = []
131
+ inner_lhs_mql = "$$item"
132
+ for value in values :
133
+ if isinstance (value , models .Model ):
134
+ value , emf_data = self .model_to_dict (value )
135
+ # Get conditions for any nested EmbeddedModelFields.
136
+ conditions .append ({"$and" : self .get_conditions ({inner_lhs_mql : (value , emf_data )})})
137
+ return {
138
+ "$anyElementTrue" : {
139
+ "$ifNull" : [
140
+ {
141
+ "$map" : {
142
+ "input" : lhs_mql ,
143
+ "as" : "item" ,
144
+ "in" : {"$or" : conditions },
145
+ }
146
+ },
147
+ [],
148
+ ]
149
+ }
150
+ }
151
+
152
+
98
153
class KeyTransform (Transform ):
99
154
# it should be different class than EMF keytransform even most of the methods are equal.
100
- def __init__ (self , key_name , base_field , * args , ** kwargs ):
155
+ def __init__ (self , key_name , array_field , * args , ** kwargs ):
101
156
super ().__init__ (* args , ** kwargs )
102
- self .base_field = base_field
157
+ self .array_field = array_field
103
158
self .key_name = key_name
104
159
# The iteration items begins from the base_field, a virtual column with
105
160
# base field output type is created.
106
- column_target = base_field .clone ()
161
+ column_target = array_field . base_field . embedded_model . _meta . get_field ( key_name ) .clone ()
107
162
column_name = f"$item.{ key_name } "
108
163
column_target .db_column = column_name
109
164
column_target .set_attributes_from_name (column_name )
@@ -126,7 +181,7 @@ def _get_missing_field_or_lookup_exception(self, lhs, name):
126
181
suggestion = "."
127
182
raise FieldDoesNotExist (
128
183
f"Unsupported lookup '{ name } ' for "
129
- f"{ self .base_field .__class__ .__name__ } '{ self .base_field .name } '"
184
+ f"{ self .array_field . base_field .__class__ .__name__ } '{ self . array_field .base_field .name } '"
130
185
f"{ suggestion } "
131
186
)
132
187
@@ -139,7 +194,9 @@ def get_transform(self, name):
139
194
transform = (
140
195
self ._lhs .get_transform (name )
141
196
if isinstance (self ._lhs , Transform )
142
- else self .base_field .embedded_model ._meta .get_field (self .key_name ).get_transform (name )
197
+ else self .array_field .base_field .embedded_model ._meta .get_field (
198
+ self .key_name
199
+ ).get_transform (name )
143
200
)
144
201
if transform :
145
202
self ._sub_transform = transform
@@ -155,7 +212,7 @@ def as_mql(self, compiler, connection):
155
212
156
213
@property
157
214
def output_field (self ):
158
- return EmbeddedModelArrayField ( self .base_field )
215
+ return self .array_field
159
216
160
217
161
218
class KeyTransformFactory :
0 commit comments