1
1
import difflib
2
2
3
3
from django .core .exceptions import FieldDoesNotExist
4
+ from django .db import models
5
+ from django .db .models .expressions import Col
4
6
from django .db .models .lookups import Transform
5
7
6
8
from ..forms import EmbeddedModelArrayFormField
7
9
from ..query_utils import process_lhs , process_rhs
8
10
from . import EmbeddedModelField
9
11
from .array import ArrayField
10
12
from .embedded_model import EMFExact
11
- from .json import build_json_mql_path
12
13
13
14
14
15
class EmbeddedModelArrayField (ArrayField ):
@@ -55,18 +56,29 @@ def get_transform(self, name):
55
56
@EmbeddedModelArrayField .register_lookup
56
57
class EMFArrayExact (EMFExact ):
57
58
def as_mql (self , compiler , connection ):
58
- mql , key_transforms , json_key_transforms = self .lhs .preprocess_lhs (compiler , connection )
59
- # TODO, maybe a new flow of transform query must be build
60
- # this part must merge the two part of the transform train.
59
+ lhs_mql = process_lhs (self , compiler , connection )
61
60
value = process_rhs (self , compiler , connection )
62
- transforms = build_json_mql_path ("$$this" , key_transforms )
63
- return {
64
- "$reduce" : {
65
- "input" : mql ,
66
- "initialValue" : False ,
67
- "in" : {"$or" : ["$$value" , {"$eq" : [f"$$this.{ transforms } " , value ]}]},
61
+ if isinstance (self .lhs , Col | KeyTransform ):
62
+ if isinstance (value , models .Model ):
63
+ value , emf_data = self .model_to_dict (value )
64
+ # Get conditions for any nested EmbeddedModelFields.
65
+ conditions = self .get_conditions ({"$$item" : (value , emf_data )})
66
+ return {
67
+ "$anyElementTrue" : {
68
+ "$map" : {"input" : lhs_mql , "as" : "item" , "in" : {"$and" : conditions }}
69
+ }
70
+ }
71
+ lhs_mql = process_lhs (self .lhs , compiler , connection )
72
+ return {
73
+ "$anyElementTrue" : {
74
+ "$map" : {
75
+ "input" : lhs_mql ,
76
+ "as" : "item" ,
77
+ "in" : {"$eq" : [f"$$item.{ self .lhs .key_name } " , value ]},
78
+ }
79
+ }
68
80
}
69
- }
81
+ return connection . mongo_operators [ self . lookup_name ]( lhs_mql , value )
70
82
71
83
72
84
class KeyTransform (Transform ):
@@ -77,7 +89,7 @@ def __init__(self, key_name, ref_field, *args, **kwargs):
77
89
self .ref_field = ref_field
78
90
79
91
def get_lookup (self , name ):
80
- return self .ref_field .get_lookup (name )
92
+ return self .output_field .get_lookup (name )
81
93
82
94
def get_transform (self , name ):
83
95
"""
@@ -99,11 +111,12 @@ def get_transform(self, name):
99
111
)
100
112
101
113
def as_mql (self , compiler , connection ):
102
- return process_lhs (self , compiler , connection )
114
+ lhs_mql = process_lhs (self , compiler , connection )
115
+ return f"{ lhs_mql } .{ self .key_name } "
103
116
104
117
@property
105
118
def output_field (self ):
106
- return self .ref_field
119
+ return EmbeddedModelArrayField ( self .ref_field )
107
120
108
121
109
122
class KeyTransformFactory :
0 commit comments