@@ -58,55 +58,56 @@ class EMFArrayExact(EMFExact):
58
58
def as_mql (self , compiler , connection ):
59
59
lhs_mql = process_lhs (self , compiler , connection )
60
60
value = process_rhs (self , compiler , connection )
61
- if isinstance (self .lhs , Col | KeyTransform ):
62
- if isinstance (self .lhs , Col ):
63
- inner_lhs_mql = "$$item"
64
- else :
65
- lhs_mql , inner_lhs_mql = lhs_mql
66
- if isinstance (value , models .Model ):
67
- value , emf_data = self .model_to_dict (value )
68
- # Get conditions for any nested EmbeddedModelFields.
69
- conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
70
- return {
71
- "$anyElementTrue" : {
72
- "$ifNull" : [
73
- {
74
- "$map" : {
75
- "input" : lhs_mql ,
76
- "as" : "item" ,
77
- "in" : {"$and" : conditions },
78
- }
79
- },
80
- [],
81
- ]
82
- }
83
- }
61
+ if isinstance (self .lhs , KeyTransform ):
62
+ lhs_mql , inner_lhs_mql = lhs_mql
63
+ else :
64
+ inner_lhs_mql = "$$item"
65
+ if isinstance (value , models .Model ):
66
+ value , emf_data = self .model_to_dict (value )
67
+ # Get conditions for any nested EmbeddedModelFields.
68
+ conditions = self .get_conditions ({inner_lhs_mql : (value , emf_data )})
84
69
return {
85
70
"$anyElementTrue" : {
86
71
"$ifNull" : [
87
72
{
88
73
"$map" : {
89
74
"input" : lhs_mql ,
90
75
"as" : "item" ,
91
- "in" : {"$eq " : [ inner_lhs_mql , value ] },
76
+ "in" : {"$and " : conditions },
92
77
}
93
78
},
94
79
[],
95
80
]
96
81
}
97
82
}
98
- return connection .mongo_operators [self .lookup_name ](lhs_mql , value )
83
+ return {
84
+ "$anyElementTrue" : {
85
+ "$ifNull" : [
86
+ {
87
+ "$map" : {
88
+ "input" : lhs_mql ,
89
+ "as" : "item" ,
90
+ "in" : {"$eq" : [inner_lhs_mql , value ]},
91
+ }
92
+ },
93
+ [],
94
+ ]
95
+ }
96
+ }
99
97
100
98
101
99
class KeyTransform (Transform ):
102
100
# it should be different class than EMF keytransform even most of the methods are equal.
103
101
def __init__ (self , key_name , base_field , * args , ** kwargs ):
104
102
super ().__init__ (* args , ** kwargs )
105
103
self .base_field = base_field
106
- # TODO: Need to create a column, will refactor this thing.
104
+ self .key_name = key_name
105
+ # The iteration items begins from the base_field, a virtual column with
106
+ # base field output type is created.
107
107
column_target = base_field .clone ()
108
- column_target .db_column = f"$item.{ key_name } "
109
- column_target .set_attributes_from_name (f"$item.{ key_name } " )
108
+ column_name = f"$item.{ key_name } "
109
+ column_target .db_column = column_name
110
+ column_target .set_attributes_from_name (column_name )
110
111
self ._lhs = Col (None , column_target )
111
112
self ._sub_transform = None
112
113
@@ -117,19 +118,8 @@ def __call__(self, this, *args, **kwargs):
117
118
def get_lookup (self , name ):
118
119
return self .output_field .get_lookup (name )
119
120
120
- def get_transform (self , name ):
121
- """
122
- Validate that `name` is either a field of an embedded model or a
123
- lookup on an embedded model's field.
124
- """
125
- if isinstance (self ._lhs , Transform ):
126
- transform = self ._lhs .get_transform (name )
127
- else :
128
- transform = self .base_field .get_transform (name )
129
- if transform :
130
- self ._sub_transform = transform
131
- return self
132
- suggested_lookups = difflib .get_close_matches (name , self .base_field .get_lookups ())
121
+ def _get_missing_field_or_lookup_exception (self , lhs , name ):
122
+ suggested_lookups = difflib .get_close_matches (name , lhs .get_lookups ())
133
123
if suggested_lookups :
134
124
suggested_lookups = " or " .join (suggested_lookups )
135
125
suggestion = f", perhaps you meant { suggested_lookups } ?"
@@ -141,6 +131,25 @@ def get_transform(self, name):
141
131
f"{ suggestion } "
142
132
)
143
133
134
+ def get_transform (self , name ):
135
+ """
136
+ Validate that `name` is either a field of an embedded model or a
137
+ lookup on an embedded model's field.
138
+ """
139
+ # Once the sub lhs is a transform, all the filter are applied over it.
140
+
141
+ transform = (
142
+ self ._lhs .get_transform (name )
143
+ if isinstance (self ._lhs , Transform )
144
+ else self .base_field .get_transform (name )
145
+ )
146
+ if transform :
147
+ self ._sub_transform = transform
148
+ return self
149
+ raise self ._get_missing_field_or_lookup_exception (
150
+ self ._lhs if isinstance (self ._lhs , Transform ) else self .base_field , name
151
+ )
152
+
144
153
def as_mql (self , compiler , connection ):
145
154
inner_lhs_mql = self ._lhs .as_mql (compiler , connection )
146
155
lhs_mql = process_lhs (self , compiler , connection )
0 commit comments