diff --git a/src/de.rs b/src/de.rs index c6a171f..8a839d2 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,3 +1,4 @@ +use pyo3::intern; use pyo3::types::*; use serde::de::{self, IntoDeserializer}; use serde::Deserialize; @@ -9,17 +10,30 @@ pub fn depythonize<'de, T>(obj: &'de PyAny) -> Result where T: Deserialize<'de>, { - let mut depythonizer = Depythonizer::from_object(obj); + let mut depythonizer = Depythonizer::from_object(obj, &None); + T::deserialize(&mut depythonizer) +} + +/// Attempt to convert a Python object to an instance of `T` +pub fn depythonize_on_error<'de, T>( + obj: &'de PyAny, + on_error: &'de Option &PyAny>, +) -> Result +where + T: Deserialize<'de>, +{ + let mut depythonizer = Depythonizer::from_object(obj, on_error); T::deserialize(&mut depythonizer) } pub struct Depythonizer<'de> { input: &'de PyAny, + on_error: &'de Option &PyAny>, } impl<'de> Depythonizer<'de> { - pub fn from_object(input: &'de PyAny) -> Self { - Depythonizer { input } + pub fn from_object(input: &'de PyAny, on_error: &'de Option &PyAny>) -> Self { + Depythonizer { input, on_error } } fn sequence_access(&self, expected_len: Option) -> Result> { @@ -30,12 +44,12 @@ impl<'de> Depythonizer<'de> { Some(expected) if expected != len => { Err(PythonizeError::incorrect_sequence_length(expected, len)) } - _ => Ok(PySequenceAccess::new(seq, len)), + _ => Ok(PySequenceAccess::new(seq, len, self.on_error)), } } fn dict_access(&self) -> Result> { - PyMappingAccess::new(self.input.downcast()?) + PyMappingAccess::new(self.input.downcast()?, self.on_error) } } @@ -63,22 +77,16 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { self.deserialize_unit(visitor) } else if obj.is_instance_of::()? { self.deserialize_bool(visitor) - } else if obj.is_instance_of::()? || obj.is_instance_of::()? { - self.deserialize_bytes(visitor) } else if obj.is_instance_of::()? { self.deserialize_map(visitor) } else if obj.is_instance_of::()? { self.deserialize_f64(visitor) - } else if obj.is_instance_of::()? { - self.deserialize_tuple(obj.len()?, visitor) } else if obj.is_instance_of::()? { self.deserialize_i64(visitor) } else if obj.is_instance_of::()? { self.deserialize_tuple(obj.len()?, visitor) } else if obj.is_instance_of::()? { self.deserialize_i64(visitor) - } else if obj.is_instance_of::()? { - self.deserialize_tuple(obj.len()?, visitor) } else if obj.is_instance_of::()? { self.deserialize_str(visitor) } else if obj.is_instance_of::()? { @@ -90,9 +98,17 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { } else if obj.downcast::().is_ok() { self.deserialize_map(visitor) } else { - Err(PythonizeError::unsupported_type( - obj.get_type().name().unwrap_or(""), - )) + match self.on_error { + Some(on_error) => { + // Do we recurse infinitely if on_error returns something we can't + // deserialize? + self.input = on_error(self.input); + self.deserialize_any(visitor) + } + None => Err(PythonizeError::unsupported_type( + obj.get_type().name().unwrap_or(""), + )), + } } } @@ -259,7 +275,7 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { .downcast() .map_err(|_| PythonizeError::dict_key_not_string())?; let value = d.get_item(variant).unwrap(); - let mut de = Depythonizer::from_object(value); + let mut de = Depythonizer::from_object(value, self.on_error); visitor.visit_enum(PyEnumAccess::new(&mut de, variant)) } else if item.is_instance_of::()? { let s: &PyString = self.input.downcast()?; @@ -292,11 +308,17 @@ struct PySequenceAccess<'a> { seq: &'a PySequence, index: usize, len: usize, + on_error: &'a Option &PyAny>, } impl<'a> PySequenceAccess<'a> { - fn new(seq: &'a PySequence, len: usize) -> Self { - Self { seq, index: 0, len } + fn new(seq: &'a PySequence, len: usize, on_error: &'a Option &PyAny>) -> Self { + Self { + seq, + index: 0, + len, + on_error, + } } } @@ -308,7 +330,8 @@ impl<'de> de::SeqAccess<'de> for PySequenceAccess<'de> { T: de::DeserializeSeed<'de>, { if self.index < self.len { - let mut item_de = Depythonizer::from_object(self.seq.get_item(self.index)?); + let mut item_de = + Depythonizer::from_object(self.seq.get_item(self.index)?, self.on_error); self.index += 1; seed.deserialize(&mut item_de).map(Some) } else { @@ -323,10 +346,11 @@ struct PyMappingAccess<'de> { key_idx: usize, val_idx: usize, len: usize, + on_error: &'de Option &PyAny>, } impl<'de> PyMappingAccess<'de> { - fn new(map: &'de PyMapping) -> Result { + fn new(map: &'de PyMapping, on_error: &'de Option &PyAny>) -> Result { let keys = map.keys()?; let values = map.values()?; let len = map.len()?; @@ -336,6 +360,7 @@ impl<'de> PyMappingAccess<'de> { key_idx: 0, val_idx: 0, len, + on_error, }) } } @@ -348,7 +373,18 @@ impl<'de> de::MapAccess<'de> for PyMappingAccess<'de> { K: de::DeserializeSeed<'de>, { if self.key_idx < self.len { - let mut item_de = Depythonizer::from_object(self.keys.get_item(self.key_idx)?); + let key = self.keys.get_item(self.key_idx)?; + let py = key.py(); + let key = match key { + key if key.is_instance_of::()? => key, + key if key.is_none() => intern!(py, "null"), + key if key.is_instance_of::()? => match key.is_true()? { + true => intern!(py, "true"), + false => intern!(py, "false"), + }, + key => key.str()?, + }; + let mut item_de = Depythonizer::from_object(key, self.on_error); self.key_idx += 1; seed.deserialize(&mut item_de).map(Some) } else { @@ -360,7 +396,8 @@ impl<'de> de::MapAccess<'de> for PyMappingAccess<'de> { where V: de::DeserializeSeed<'de>, { - let mut item_de = Depythonizer::from_object(self.values.get_item(self.val_idx)?); + let mut item_de = + Depythonizer::from_object(self.values.get_item(self.val_idx)?, self.on_error); self.val_idx += 1; seed.deserialize(&mut item_de) } diff --git a/src/lib.rs b/src/lib.rs index 139f684..9c7b609 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,7 +40,7 @@ mod de; mod error; mod ser; -pub use crate::de::depythonize; +pub use crate::de::{depythonize, depythonize_on_error}; pub use crate::error::{PythonizeError, Result}; pub use crate::ser::{ pythonize, pythonize_custom, PythonizeDictType, PythonizeListType, PythonizeTypes,