Skip to content

Commit f32ee9e

Browse files
committed
Fix rhs serialization
1 parent c1f8ae5 commit f32ee9e

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,25 @@ def get_transform(self, name):
6161
return KeyTransformFactory(name, self)
6262

6363

64+
class EMFArrayRHSMixin:
65+
def process_rhs(self, compiler, connection):
66+
values = self.rhs
67+
if not self.get_db_prep_lookup_value_is_iterable:
68+
values = [values]
69+
# Compute how to serialize each value based on the query target.
70+
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
71+
# field of the subfield. Otherwise, use the base field of the array itself.
72+
if isinstance(self.lhs, KeyTransform):
73+
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
74+
else:
75+
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
76+
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
77+
78+
6479
@EmbeddedModelArrayField.register_lookup
65-
class EMFArrayExact(lookups.Exact):
80+
class EMFArrayExact(EMFArrayRHSMixin, lookups.Exact):
81+
get_db_prep_lookup_value_is_iterable = False
82+
6683
def as_mql(self, compiler, connection):
6784
if not isinstance(self.lhs, KeyTransform):
6885
raise ValueError("error")
@@ -85,22 +102,9 @@ def as_mql(self, compiler, connection):
85102

86103

87104
@EmbeddedModelArrayField.register_lookup
88-
class ArrayOverlap(Lookup):
105+
class ArrayOverlap(EMFArrayRHSMixin, Lookup):
89106
lookup_name = "overlap"
90-
get_db_prep_lookup_value_is_iterable = True
91-
92-
def process_rhs(self, compiler, connection):
93-
values = self.rhs
94-
if self.get_db_prep_lookup_value_is_iterable:
95-
values = [values]
96-
# Compute how to serialize each value based on the query target.
97-
# If querying a subfield inside the array (i.e., a nested KeyTransform), use the output
98-
# field of the subfield. Otherwise, use the base field of the array itself.
99-
if isinstance(self.lhs, KeyTransform):
100-
get_db_prep_value = self.lhs._lhs.output_field.get_db_prep_value
101-
else:
102-
get_db_prep_value = self.lhs.output_field.base_field.get_db_prep_value
103-
return None, [get_db_prep_value(v, connection, prepared=True) for v in values]
107+
get_db_prep_lookup_value_is_iterable = False
104108

105109
def as_mql(self, compiler, connection):
106110
# Querying a subfield within the array elements (via nested KeyTransform).

0 commit comments

Comments
 (0)