Skip to content

Commit 9acb7a1

Browse files
committed
add EmbeddedModelField support to ArrayField
1 parent f6d7abc commit 9acb7a1

File tree

3 files changed

+80
-4
lines changed

3 files changed

+80
-4
lines changed

django_mongodb_backend/fields/array.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,12 @@ def db_type(self, connection):
126126
return "array"
127127

128128
def get_db_prep_value(self, value, connection, prepared=False):
129+
from ..fields import EmbeddedModelField
130+
129131
if isinstance(value, list | tuple):
130-
# Workaround for https://code.djangoproject.com/ticket/35982
131-
# (fixed in Django 5.2).
132-
if isinstance(self.base_field, DecimalField):
132+
# DecimalField here is a workaround for
133+
# https://code.djangoproject.com/ticket/35982 (fixed in Django 5.2).
134+
if isinstance(self.base_field, (DecimalField | EmbeddedModelField)):
133135
return [self.base_field.get_db_prep_save(i, connection) for i in value]
134136
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value]
135137
return value

tests/model_fields_/models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,20 @@ class D(EmbeddedModel):
158158
class E(EmbeddedModel):
159159
name = models.CharField(max_length=100)
160160
value = models.IntegerField()
161+
162+
163+
# ArrayField + EmbeddedModelField
164+
class Review(EmbeddedModel):
165+
title = models.CharField(max_length=255)
166+
rating = models.IntegerField()
167+
168+
def __str__(self):
169+
return self.title
170+
171+
172+
class Movie(models.Model):
173+
title = models.CharField(max_length=255)
174+
reviews = ArrayField(EmbeddedModelField(Review), null=True)
175+
176+
def __str__(self):
177+
return self.title

tests/model_fields_/test_embedded_model.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from django_mongodb_backend.fields import EmbeddedModelField
1818
from django_mongodb_backend.models import EmbeddedModel
1919

20-
from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library
20+
from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library, Movie, Review
2121
from .utils import truncate_ms
2222

2323

@@ -83,6 +83,63 @@ def test_pre_save(self):
8383
self.assertGreater(obj.data.auto_now, auto_now_two)
8484

8585

86+
class EmbeddedArrayTests(TestCase):
87+
def test_save_load(self):
88+
reviews = [
89+
Review(title="The best", rating=10),
90+
Review(title="Mediocre", rating=5),
91+
Review(title="Horrible", rating=1),
92+
]
93+
Movie.objects.create(title="Lion King", reviews=reviews)
94+
movie = Movie.objects.get(title="Lion King")
95+
self.assertEqual(movie.reviews[0].title, "The best")
96+
self.assertEqual(movie.reviews[0].rating, 10)
97+
self.assertEqual(movie.reviews[1].title, "Mediocre")
98+
self.assertEqual(movie.reviews[1].rating, 5)
99+
self.assertEqual(movie.reviews[2].title, "Horrible")
100+
self.assertEqual(movie.reviews[2].rating, 1)
101+
self.assertEqual(len(movie.reviews), 3)
102+
103+
def test_save_load_null(self):
104+
movie = Movie.objects.create(title="Lion King")
105+
movie = Movie.objects.get(title="Lion King")
106+
self.assertIsNone(movie.reviews)
107+
108+
109+
class EmbeddedArrayQueryingTests(TestCase):
110+
@classmethod
111+
def setUpTestData(cls):
112+
reviews = [
113+
Review(title="The best", rating=10),
114+
Review(title="Mediocre", rating=5),
115+
Review(title="Horrible", rating=1),
116+
]
117+
cls.clouds = Movie.objects.create(title="Clouds", reviews=reviews)
118+
reviews = [
119+
Review(title="Super", rating=9),
120+
Review(title="Meh", rating=5),
121+
Review(title="Horrible", rating=2),
122+
]
123+
cls.frozen = Movie.objects.create(title="Frozen", reviews=reviews)
124+
reviews = [
125+
Review(title="Excellent", rating=9),
126+
Review(title="Wow", rating=8),
127+
Review(title="Classic", rating=7),
128+
]
129+
cls.bears = Movie.objects.create(title="Bears", reviews=reviews)
130+
131+
def test_filter_with_field(self):
132+
self.assertCountEqual(
133+
Movie.objects.filter(reviews__title="Horrible"), [self.clouds, self.frozen]
134+
)
135+
136+
def test_filter_with_model(self):
137+
self.assertCountEqual(
138+
Movie.objects.filter(reviews=Review(title="Horrible", rating=2)),
139+
[self.clouds, self.frozen],
140+
)
141+
142+
86143
class QueryingTests(TestCase):
87144
@classmethod
88145
def setUpTestData(cls):

0 commit comments

Comments
 (0)