diff --git a/elasticsearch/dsl/field.py b/elasticsearch/dsl/field.py index 1aa7a4bca..d4c5a6e76 100644 --- a/elasticsearch/dsl/field.py +++ b/elasticsearch/dsl/field.py @@ -119,9 +119,16 @@ def __init__( def __getitem__(self, subfield: str) -> "Field": return cast(Field, self._params.get("fields", {})[subfield]) - def _serialize(self, data: Any) -> Any: + def _serialize(self, data: Any, skip_empty: bool) -> Any: return data + def _safe_serialize(self, data: Any, skip_empty: bool) -> Any: + try: + return self._serialize(data, skip_empty) + except TypeError: + # older method signature, without skip_empty + return self._serialize(data) # type: ignore[call-arg] + def _deserialize(self, data: Any) -> Any: return data @@ -133,10 +140,16 @@ def empty(self) -> Optional[Any]: return AttrList([]) return self._empty() - def serialize(self, data: Any) -> Any: + def serialize(self, data: Any, skip_empty: bool = True) -> Any: if isinstance(data, (list, AttrList, tuple)): - return list(map(self._serialize, cast(Iterable[Any], data))) - return self._serialize(data) + return list( + map( + self._safe_serialize, + cast(Iterable[Any], data), + [skip_empty] * len(data), + ) + ) + return self._safe_serialize(data, skip_empty) def deserialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): @@ -186,7 +199,7 @@ def _deserialize(self, data: Any) -> Range["_SupportsComparison"]: data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr] return Range(data) - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None if not isinstance(data, collections.abc.Mapping): @@ -550,7 +563,7 @@ def _deserialize(self, data: Any) -> "InnerDoc": return self._wrap(data) def _serialize( - self, data: Optional[Union[Dict[str, Any], "InnerDoc"]] + self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool ) -> Optional[Dict[str, Any]]: if data is None: return None @@ -559,7 +572,7 @@ def _serialize( if isinstance(data, collections.abc.Mapping): return data - return data.to_dict() + return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: data = super().clean(data) @@ -768,7 +781,7 @@ def clean(self, data: str) -> str: def _deserialize(self, data: Any) -> bytes: return base64.b64decode(data) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return base64.b64encode(data).decode() @@ -2619,7 +2632,7 @@ def _deserialize(self, data: Any) -> Union["IPv4Address", "IPv6Address"]: # the ipaddress library for pypy only accepts unicode. return ipaddress.ip_address(unicode(data)) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return str(data) @@ -3367,7 +3380,7 @@ def __init__( def _deserialize(self, data: Any) -> "Query": return Q(data) # type: ignore[no-any-return] - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None return data.to_dict() # type: ignore[no-any-return] diff --git a/elasticsearch/dsl/utils.py b/elasticsearch/dsl/utils.py index 127a48cc2..cce3c052c 100644 --- a/elasticsearch/dsl/utils.py +++ b/elasticsearch/dsl/utils.py @@ -603,7 +603,7 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]: # if this is a mapped field, f = self.__get_field(k) if f and f._coerce: - v = f.serialize(v) + v = f.serialize(v, skip_empty=skip_empty) # if someone assigned AttrList, unwrap it if isinstance(v, AttrList): diff --git a/test_elasticsearch/test_dsl/test_integration/_async/test_document.py b/test_elasticsearch/test_dsl/test_integration/_async/test_document.py index 99f475cf1..3d769c606 100644 --- a/test_elasticsearch/test_dsl/test_integration/_async/test_document.py +++ b/test_elasticsearch/test_dsl/test_integration/_async/test_document.py @@ -630,7 +630,9 @@ async def test_can_save_to_different_index( async def test_save_without_skip_empty_will_include_empty_fields( async_write_client: AsyncElasticsearch, ) -> None: - test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) + test_repo = Repository( + field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42} + ) assert await test_repo.save(index="test-document", skip_empty=False) assert_doc_equals( @@ -638,7 +640,12 @@ async def test_save_without_skip_empty_will_include_empty_fields( "found": True, "_index": "test-document", "_id": "42", - "_source": {"field_1": [], "field_2": None, "field_3": {}}, + "_source": { + "field_1": [], + "field_2": None, + "field_3": {}, + "owner": {"name": None}, + }, }, await async_write_client.get(index="test-document", id=42), ) diff --git a/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py b/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py index 05dd05fd9..a005d45bf 100644 --- a/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py +++ b/test_elasticsearch/test_dsl/test_integration/_sync/test_document.py @@ -624,7 +624,9 @@ def test_can_save_to_different_index( def test_save_without_skip_empty_will_include_empty_fields( write_client: Elasticsearch, ) -> None: - test_repo = Repository(field_1=[], field_2=None, field_3={}, meta={"id": 42}) + test_repo = Repository( + field_1=[], field_2=None, field_3={}, owner={"name": None}, meta={"id": 42} + ) assert test_repo.save(index="test-document", skip_empty=False) assert_doc_equals( @@ -632,7 +634,12 @@ def test_save_without_skip_empty_will_include_empty_fields( "found": True, "_index": "test-document", "_id": "42", - "_source": {"field_1": [], "field_2": None, "field_3": {}}, + "_source": { + "field_1": [], + "field_2": None, + "field_3": {}, + "owner": {"name": None}, + }, }, write_client.get(index="test-document", id=42), ) diff --git a/utils/templates/field.py.tpl b/utils/templates/field.py.tpl index 8a4c73f33..8699d852e 100644 --- a/utils/templates/field.py.tpl +++ b/utils/templates/field.py.tpl @@ -119,9 +119,16 @@ class Field(DslBase): def __getitem__(self, subfield: str) -> "Field": return cast(Field, self._params.get("fields", {})[subfield]) - def _serialize(self, data: Any) -> Any: + def _serialize(self, data: Any, skip_empty: bool) -> Any: return data + def _safe_serialize(self, data: Any, skip_empty: bool) -> Any: + try: + return self._serialize(data, skip_empty) + except TypeError: + # older method signature, without skip_empty + return self._serialize(data) # type: ignore[call-arg] + def _deserialize(self, data: Any) -> Any: return data @@ -133,10 +140,10 @@ class Field(DslBase): return AttrList([]) return self._empty() - def serialize(self, data: Any) -> Any: + def serialize(self, data: Any, skip_empty: bool = True) -> Any: if isinstance(data, (list, AttrList, tuple)): - return list(map(self._serialize, cast(Iterable[Any], data))) - return self._serialize(data) + return list(map(self._safe_serialize, cast(Iterable[Any], data), [skip_empty] * len(data))) + return self._safe_serialize(data, skip_empty) def deserialize(self, data: Any) -> Any: if isinstance(data, (list, AttrList, tuple)): @@ -186,7 +193,7 @@ class RangeField(Field): data = {k: self._core_field.deserialize(v) for k, v in data.items()} # type: ignore[union-attr] return Range(data) - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None if not isinstance(data, collections.abc.Mapping): @@ -318,7 +325,7 @@ class {{ k.name }}({{ k.parent }}): return self._wrap(data) def _serialize( - self, data: Optional[Union[Dict[str, Any], "InnerDoc"]] + self, data: Optional[Union[Dict[str, Any], "InnerDoc"]], skip_empty: bool ) -> Optional[Dict[str, Any]]: if data is None: return None @@ -327,7 +334,7 @@ class {{ k.name }}({{ k.parent }}): if isinstance(data, collections.abc.Mapping): return data - return data.to_dict() + return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: data = super().clean(data) @@ -433,7 +440,7 @@ class {{ k.name }}({{ k.parent }}): # the ipaddress library for pypy only accepts unicode. return ipaddress.ip_address(unicode(data)) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return str(data) @@ -448,7 +455,7 @@ class {{ k.name }}({{ k.parent }}): def _deserialize(self, data: Any) -> bytes: return base64.b64decode(data) - def _serialize(self, data: Any) -> Optional[str]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None return base64.b64encode(data).decode() @@ -458,7 +465,7 @@ class {{ k.name }}({{ k.parent }}): def _deserialize(self, data: Any) -> "Query": return Q(data) # type: ignore[no-any-return] - def _serialize(self, data: Any) -> Optional[Dict[str, Any]]: + def _serialize(self, data: Any, skip_empty: bool) -> Optional[Dict[str, Any]]: if data is None: return None return data.to_dict() # type: ignore[no-any-return]