@@ -61,8 +61,25 @@ def get_transform(self, name):
61
61
return KeyTransformFactory (name , self )
62
62
63
63
64
+ class EMFArrayRHSMixin :
65
+ def process_rhs (self , compiler , connection ):
66
+ values = self .rhs
67
+ if not self .get_db_prep_lookup_value_is_iterable :
68
+ values = [values ]
69
+ # Compute how to serialize each value based on the query target.
70
+ # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
71
+ # field of the subfield. Otherwise, use the base field of the array itself.
72
+ if isinstance (self .lhs , KeyTransform ):
73
+ get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
74
+ else :
75
+ get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
76
+ return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
77
+
78
+
64
79
@EmbeddedModelArrayField .register_lookup
65
- class EMFArrayExact (lookups .Exact ):
80
+ class EMFArrayExact (EMFArrayRHSMixin , lookups .Exact ):
81
+ get_db_prep_lookup_value_is_iterable = False
82
+
66
83
def as_mql (self , compiler , connection ):
67
84
if not isinstance (self .lhs , KeyTransform ):
68
85
raise ValueError ("error" )
@@ -85,22 +102,9 @@ def as_mql(self, compiler, connection):
85
102
86
103
87
104
@EmbeddedModelArrayField .register_lookup
88
- class ArrayOverlap (Lookup ):
105
+ class ArrayOverlap (EMFArrayRHSMixin , Lookup ):
89
106
lookup_name = "overlap"
90
- get_db_prep_lookup_value_is_iterable = True
91
-
92
- def process_rhs (self , compiler , connection ):
93
- values = self .rhs
94
- if self .get_db_prep_lookup_value_is_iterable :
95
- values = [values ]
96
- # Compute how to serialize each value based on the query target.
97
- # If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
98
- # field of the subfield. Otherwise, use the base field of the array itself.
99
- if isinstance (self .lhs , KeyTransform ):
100
- get_db_prep_value = self .lhs ._lhs .output_field .get_db_prep_value
101
- else :
102
- get_db_prep_value = self .lhs .output_field .base_field .get_db_prep_value
103
- return None , [get_db_prep_value (v , connection , prepared = True ) for v in values ]
107
+ get_db_prep_lookup_value_is_iterable = False
104
108
105
109
def as_mql (self , compiler , connection ):
106
110
# Querying a subfield within the array elements (via nested KeyTransform).
0 commit comments