From 9af2ec4dcfac555d54b8507ea67c1ff396ec809f Mon Sep 17 00:00:00 2001 From: Shane Date: Wed, 22 Mar 2023 14:18:18 -0700 Subject: [PATCH] Relax input requirements in deserialize_str Instead of attempting to downcast the input to a PyString, rely on the fact that for str objects, calling __str__ is a no-op, while for other objects it may allow us to support neat use cases. An example of one such use case is turning Python's uuid.UUID into Rust's uuid::Uuid. If you call Uuid::depythonize you end up in ::deserialize (if you enabled its serde feature). Since Depythonizer doesn't override Deserialize::is_human_readable, the implementation tries to deserialize with deserialize_str. At that point, the Deserializer has a value of type uuid.UUID (actually PyAny, but it really is uuid.UUID), and the visitor expects a value of type &str. We can make that happen by using __str__. --- Cargo.toml | 1 + src/de.rs | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 97816a1..1868743 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } pyo3 = { version = "0.18.0", default-features = false, features = ["auto-initialize", "macros"] } serde_json = "1.0" maplit = "1.0.2" +uuid = { version = "1.3.0", features = ["serde", "v4"]} diff --git a/src/de.rs b/src/de.rs index c6a171f..621b883 100644 --- a/src/de.rs +++ b/src/de.rs @@ -129,7 +129,7 @@ impl<'a, 'de> de::Deserializer<'de> for &'a mut Depythonizer<'de> { where V: de::Visitor<'de>, { - let s: &PyString = self.input.downcast()?; + let s: &PyString = self.input.str()?; visitor.visit_str(s.to_str()?) } @@ -428,15 +428,38 @@ mod test { use maplit::hashmap; use pyo3::Python; use serde_json::{json, Value as JsonValue}; + use uuid::Uuid; fn test_de(code: &str, expected: &T, expected_json: &JsonValue) where T: de::DeserializeOwned + PartialEq + std::fmt::Debug, { + test_de_with_imports(code, expected, expected_json, &[]); + } + + fn test_de_with_imports( + code: &str, + expected: &T, + expected_json: &JsonValue, + imports: &[&str], + ) where + T: de::DeserializeOwned + PartialEq + std::fmt::Debug, + { + let imports: Vec<_> = imports + .iter() + .map(|module| format!("import {}", module)) + .collect(); + + let import_statements = imports.join("\n"); Python::with_gil(|py| { let locals = PyDict::new(py); - py.run(&format!("obj = {}", code), None, Some(locals)) - .unwrap(); + py.run( + &format!("{}\nobj = {}", import_statements, code), + None, + Some(locals), + ) + .unwrap(); + let obj = locals.get_item("obj").unwrap(); let actual: T = depythonize(obj).unwrap(); assert_eq!(&actual, expected); @@ -713,4 +736,12 @@ mod test { let code = "{'name': 'SomeFoo', 'bar': {'value': 13, 'variant': {'Tuple': [-1.5, 8]}}}"; test_de(code, &expected, &expected_json); } + + #[test] + fn test_uuid() { + let expected = Uuid::new_v4(); + let expected_json = json!(expected.to_string().replace("-", "")); + let code = format!("uuid.UUID('{}')", expected.to_string()); + test_de_with_imports(&code, &expected, &expected_json, &["uuid"]); + } }