diff --git a/CHANGELOG.md b/CHANGELOG.md index 7544af4..30ff327 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ ### Fixed - Fix overflow error attempting to depythonize `u64` values greater than `i64::MAX` to types like `serde_json::Value` +- Fix deserializing `set` and `frozenset` into Rust homogeneous containers ## 0.21.1 - 2024-04-02 diff --git a/src/de.rs b/src/de.rs index b1369c6..8b65825 100644 --- a/src/de.rs +++ b/src/de.rs @@ -2,7 +2,7 @@ use pyo3::{types::*, Bound}; use serde::de::{self, DeserializeOwned, IntoDeserializer}; use serde::Deserialize; -use crate::error::{PythonizeError, Result}; +use crate::error::{ErrorImpl, PythonizeError, Result}; /// Attempt to convert a Python object to an instance of `T` pub fn depythonize<'a, 'py, T>(obj: &'a Bound<'py, PyAny>) -> Result @@ -44,6 +44,19 @@ impl<'a, 'py> Depythonizer<'a, 'py> { } } + fn set_access(&self) -> Result> { + match self.input.downcast::() { + Ok(set) => Ok(PySetAsSequence::from_set(&set)), + Err(e) => { + if let Ok(f) = self.input.downcast::() { + Ok(PySetAsSequence::from_frozenset(&f)) + } else { + Err(e.into()) + } + } + } + } + fn dict_access(&self) -> Result> { PyMappingAccess::new(self.input.downcast()?) } @@ -122,10 +135,9 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { self.deserialize_bytes(visitor) } else if obj.is_instance_of::() { self.deserialize_f64(visitor) - } else if obj.is_instance_of::() - || obj.is_instance_of::() - || obj.downcast::().is_ok() - { + } else if obj.is_instance_of::() || obj.is_instance_of::() { + self.deserialize_seq(visitor) + } else if obj.downcast::().is_ok() { self.deserialize_tuple(obj.len()?, visitor) } else if obj.downcast::().is_ok() { self.deserialize_map(visitor) @@ -238,7 +250,18 @@ impl<'de> de::Deserializer<'de> for &'_ mut Depythonizer<'_, '_> { where V: de::Visitor<'de>, { - visitor.visit_seq(self.sequence_access(None)?) + match self.sequence_access(None) { + Ok(seq) => visitor.visit_seq(seq), + Err(e) => { + // we allow sets to be deserialized as sequences, so try that + if matches!(*e.inner, ErrorImpl::UnexpectedType(_)) { + if let Ok(set) = self.set_access() { + return visitor.visit_seq(set); + } + } + Err(e) + } + } } fn deserialize_tuple(self, len: usize, visitor: V) -> Result @@ -357,6 +380,40 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'_, '_> { } } +struct PySetAsSequence<'py> { + iter: Bound<'py, PyIterator>, +} + +impl<'py> PySetAsSequence<'py> { + fn from_set(set: &Bound<'py, PySet>) -> Self { + Self { + iter: PyIterator::from_bound_object(&set).expect("set is always iterable"), + } + } + + fn from_frozenset(set: &Bound<'py, PyFrozenSet>) -> Self { + Self { + iter: PyIterator::from_bound_object(&set).expect("frozenset is always iterable"), + } + } +} + +impl<'de> de::SeqAccess<'de> for PySetAsSequence<'_> { + type Error = PythonizeError; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: de::DeserializeSeed<'de>, + { + match self.iter.next() { + Some(item) => seed + .deserialize(&mut Depythonizer::from_object(&item?)) + .map(Some), + None => Ok(None), + } + } +} + struct PyMappingAccess<'py> { keys: Bound<'py, PySequence>, values: Bound<'py, PySequence>, @@ -606,6 +663,22 @@ mod test { test_de(code, &expected, &expected_json); } + #[test] + fn test_vec_from_pyset() { + let expected = vec!["foo".to_string()]; + let expected_json = json!(["foo"]); + let code = "{'foo'}"; + test_de(code, &expected, &expected_json); + } + + #[test] + fn test_vec_from_pyfrozenset() { + let expected = vec!["foo".to_string()]; + let expected_json = json!(["foo"]); + let code = "frozenset({'foo'})"; + test_de(code, &expected, &expected_json); + } + #[test] fn test_vec() { let expected = vec![3, 2, 1]; diff --git a/src/error.rs b/src/error.rs index 4aee7ea..9aa5a87 100644 --- a/src/error.rs +++ b/src/error.rs @@ -32,6 +32,15 @@ impl PythonizeError { } } + pub(crate) fn unexpected_type(t: T) -> Self + where + T: ToString, + { + Self { + inner: Box::new(ErrorImpl::UnexpectedType(t.to_string())), + } + } + pub(crate) fn dict_key_not_string() -> Self { Self { inner: Box::new(ErrorImpl::DictKeyNotString),