Skip to content

Commit faa22a8

Browse files
Avoid IntoPyObjectExt::into_bound_py_any for None value (#188)
Modifications: + Signature-generator can avoid `IntoPyObjectExt::into_bound_py_any` for `None` value. It's useful for the external types that only impl FromPyObject, such as cases in relf/egobox#237 ```rust #[gen_stub_pyfunction] #[pyfunction(signature = (array=None) )] fn my_func(array: Option<numpy::PyReadonlyArray2<f64>>) {} ``` + Add missed newline before function's doc + [support Enum's variant document](cd10de3) + [signature support external repr](a31991a) When `eval(repr(py_value))` == `py_value`, we can use repr for external types Before: ```python def default_value(num:Number=...) -> Number: ... ``` After ```python def default_value(num:Number=Number.FLOAT) -> Number: ... ``` Thx! --------- Co-authored-by: Toshiki Teramura <[email protected]>
1 parent cac5a60 commit faa22a8

File tree

8 files changed

+121
-49
lines changed

8 files changed

+121
-49
lines changed

examples/pure/pure.pyi

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ class Number(Enum):
4040

4141
class NumberRenameAll(Enum):
4242
FLOAT = ...
43+
r"""
44+
Float variant
45+
"""
4346
INTEGER = ...
4447

4548
def ahash_dict() -> builtins.dict[builtins.str, builtins.int]: ...
@@ -48,17 +51,21 @@ def create_a(x:builtins.int=2) -> A: ...
4851

4952
def create_dict(n:builtins.int) -> builtins.dict[builtins.int, builtins.list[builtins.int]]: ...
5053

51-
def default_value(num:Number=...) -> Number: ...
54+
def default_value(num:Number=Number.FLOAT) -> Number: ...
5255

5356
def echo_path(path:builtins.str | os.PathLike | pathlib.Path) -> pathlib.Path: ...
5457

58+
def print_c(c:typing.Optional[builtins.int]=None) -> None: ...
59+
5560
def read_dict(dict:typing.Mapping[builtins.int, typing.Mapping[builtins.int, builtins.int]]) -> None: ...
5661

57-
def str_len(x:builtins.str) -> builtins.int: r"""
62+
def str_len(x:builtins.str) -> builtins.int:
63+
r"""
5864
Returns the length of the string.
5965
"""
6066

61-
def sum(v:typing.Sequence[builtins.int]) -> builtins.int: r"""
67+
def sum(v:typing.Sequence[builtins.int]) -> builtins.int:
68+
r"""
6269
Returns the sum of two numbers as a string.
6370
"""
6471

examples/pure/src/lib.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,31 @@ fn create_a(x: usize) -> A {
7575
#[derive(Debug)]
7676
struct B;
7777

78+
/// `C` only impl `FromPyObject`
79+
#[derive(Debug)]
80+
struct C {
81+
x: usize,
82+
}
83+
#[gen_stub_pyfunction]
84+
#[pyfunction(signature = (c=None))]
85+
fn print_c(c: Option<C>) {
86+
if let Some(c) = c {
87+
println!("{}", c.x);
88+
} else {
89+
println!("None");
90+
}
91+
}
92+
impl FromPyObject<'_> for C {
93+
fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
94+
Ok(C { x: ob.extract()? })
95+
}
96+
}
97+
impl pyo3_stub_gen::PyStubType for C {
98+
fn type_output() -> pyo3_stub_gen::TypeInfo {
99+
usize::type_output()
100+
}
101+
}
102+
78103
create_exception!(pure, MyError, PyRuntimeError);
79104

80105
/// Returns the length of the string.
@@ -115,6 +140,7 @@ pub enum Number {
115140
#[pyo3(rename_all = "UPPERCASE")]
116141
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
117142
pub enum NumberRenameAll {
143+
/// Float variant
118144
Float,
119145
Integer,
120146
}
@@ -159,6 +185,7 @@ fn pure(m: &Bound<PyModule>) -> PyResult<()> {
159185
m.add_function(wrap_pyfunction!(create_dict, m)?)?;
160186
m.add_function(wrap_pyfunction!(read_dict, m)?)?;
161187
m.add_function(wrap_pyfunction!(create_a, m)?)?;
188+
m.add_function(wrap_pyfunction!(print_c, m)?)?;
162189
m.add_function(wrap_pyfunction!(str_len, m)?)?;
163190
m.add_function(wrap_pyfunction!(echo_path, m)?)?;
164191
m.add_function(wrap_pyfunction!(ahash_dict, m)?)?;

pyo3-stub-gen-derive/src/gen_stub/pyclass_enum.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ pub struct PyEnumInfo {
88
pyclass_name: String,
99
enum_type: Type,
1010
module: Option<String>,
11-
variants: Vec<String>,
11+
variants: Vec<(String, String)>,
1212
doc: String,
1313
}
1414

@@ -54,7 +54,7 @@ impl TryFrom<ItemEnum> for PyEnumInfo {
5454
let pyclass_name = pyclass_name.unwrap_or_else(|| ident.to_string());
5555
let variants = variants
5656
.into_iter()
57-
.map(|var| -> Result<String> {
57+
.map(|var| -> Result<(String, String)> {
5858
let mut var_name = None;
5959
for attr in parse_pyo3_attrs(&var.attrs)? {
6060
if let Attr::Name(name) = attr {
@@ -65,9 +65,10 @@ impl TryFrom<ItemEnum> for PyEnumInfo {
6565
if let Some(renaming_rule) = renaming_rule {
6666
var_name = renaming_rule.apply(&var_name);
6767
}
68-
Ok(var_name)
68+
let var_doc = extract_documents(&var.attrs).join("\n");
69+
Ok((var_name, var_doc))
6970
})
70-
.collect::<Result<Vec<String>>>()?;
71+
.collect::<Result<Vec<(String, String)>>>()?;
7172
Ok(Self {
7273
doc,
7374
enum_type: struct_type,
@@ -88,6 +89,10 @@ impl ToTokens for PyEnumInfo {
8889
module,
8990
} = self;
9091
let module = quote_option(module);
92+
let variants: Vec<_> = variants
93+
.iter()
94+
.map(|(name, doc)| quote! {(#name,#doc)})
95+
.collect();
9196
tokens.append_all(quote! {
9297
::pyo3_stub_gen::type_info::PyEnumInfo {
9398
pyclass_name: #pyclass_name,

pyo3-stub-gen-derive/src/gen_stub/signature.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,27 @@ impl ToTokens for ArgsWithSignature<'_> {
9595
SignatureArg::Assign(ident, _eq, value) => {
9696
let name = ident.to_string();
9797
let ty = args_map.get(&name).unwrap();
98+
let default = if value.to_token_stream().to_string() == "None" {
99+
quote! {
100+
"None".to_string()
101+
}
102+
} else {
103+
quote! {
104+
::pyo3::prepare_freethreaded_python();
105+
::pyo3::Python::with_gil(|py| -> String {
106+
let v: #ty = #value;
107+
::pyo3_stub_gen::util::fmt_py_obj(py, v)
108+
})
109+
}
110+
};
98111
quote! {
99112
::pyo3_stub_gen::type_info::ArgInfo {
100113
name: #name,
101114
r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_input,
102115
signature: Some(pyo3_stub_gen::type_info::SignatureArg::Assign{
103116
default: {
104117
static DEFAULT: std::sync::LazyLock<String> = std::sync::LazyLock::new(|| {
105-
::pyo3::prepare_freethreaded_python();
106-
::pyo3::Python::with_gil(|py| -> String {
107-
let v: #ty = #value;
108-
if let Ok(py_obj) = <#ty as ::pyo3::IntoPyObjectExt>::into_bound_py_any(v, py) {
109-
::pyo3_stub_gen::util::fmt_py_obj(&py_obj)
110-
} else {
111-
"...".to_owned()
112-
}
113-
})
118+
#default
114119
});
115120
&DEFAULT
116121
}

pyo3-stub-gen/src/generate/enum_.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::fmt;
66
pub struct EnumDef {
77
pub name: &'static str,
88
pub doc: &'static str,
9-
pub variants: &'static [&'static str],
9+
pub variants: &'static [(&'static str, &'static str)],
1010
pub methods: Vec<MethodDef>,
1111
pub members: Vec<MemberDef>,
1212
}
@@ -28,8 +28,9 @@ impl fmt::Display for EnumDef {
2828
writeln!(f, "class {}(Enum):", self.name)?;
2929
let indent = indent();
3030
docstring::write_docstring(f, self.doc, indent)?;
31-
for variants in self.variants {
32-
writeln!(f, "{indent}{} = ...", variants)?;
31+
for (variant, variant_doc) in self.variants {
32+
writeln!(f, "{indent}{} = ...", variant)?;
33+
docstring::write_docstring(f, variant_doc, indent)?;
3334
}
3435
for member in &self.members {
3536
writeln!(f)?;

pyo3-stub-gen/src/generate/function.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ impl fmt::Display for FunctionDef {
4444

4545
let doc = self.doc;
4646
if !doc.is_empty() {
47+
writeln!(f)?;
4748
docstring::write_docstring(f, self.doc, indent())?;
4849
} else {
4950
writeln!(f, " ...")?;

pyo3-stub-gen/src/type_info.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ pub struct PyEnumInfo {
136136
pub module: Option<&'static str>,
137137
/// Docstring
138138
pub doc: &'static str,
139-
/// Variants of enum
140-
pub variants: &'static [&'static str],
139+
/// Variants of enum (name, doc)
140+
pub variants: &'static [(&'static str, &'static str)],
141141
}
142142

143143
inventory::collect!(PyEnumInfo);

pyo3-stub-gen/src/util.rs

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use pyo3::{prelude::*, types::*};
2+
use std::ffi::CString;
23

34
pub fn all_builtin_types(any: &Bound<'_, PyAny>) -> bool {
45
if any.is_instance_of::<PyString>()
@@ -33,19 +34,37 @@ pub fn all_builtin_types(any: &Bound<'_, PyAny>) -> bool {
3334
false
3435
}
3536

36-
pub fn fmt_py_obj(any: &Bound<'_, PyAny>) -> String {
37-
if all_builtin_types(any) {
38-
if let Ok(py_str) = any.repr() {
39-
return py_str.to_string();
37+
/// whether eval(repr(any)) == any
38+
pub fn valid_external_repr(any: &Bound<'_, PyAny>) -> Option<bool> {
39+
let globals = get_globals(any).ok()?;
40+
let fmt_str = any.repr().ok()?.to_string();
41+
let fmt_cstr = CString::new(fmt_str.clone()).ok()?;
42+
let new_any = any.py().eval(&fmt_cstr, Some(&globals), None).ok()?;
43+
new_any.eq(any).ok()
44+
}
45+
46+
fn get_globals<'py>(any: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyDict>> {
47+
let type_object = any.get_type();
48+
let type_name = type_object.getattr("__name__")?;
49+
let type_name: &str = type_name.extract()?;
50+
let globals = PyDict::new(any.py());
51+
globals.set_item(type_name, type_object)?;
52+
Ok(globals)
53+
}
54+
55+
pub fn fmt_py_obj<'py, T: pyo3::IntoPyObjectExt<'py>>(py: Python<'py>, obj: T) -> String {
56+
if let Ok(any) = obj.into_bound_py_any(py) {
57+
if all_builtin_types(&any) || valid_external_repr(&any).is_some_and(|valid| valid) {
58+
if let Ok(py_str) = any.repr() {
59+
return py_str.to_string();
60+
}
4061
}
4162
}
4263
"...".to_owned()
4364
}
4465

4566
#[cfg(test)]
4667
mod test {
47-
use pyo3::IntoPyObjectExt;
48-
4968
use super::*;
5069
#[pyclass]
5170
#[derive(Debug)]
@@ -57,62 +76,69 @@ mod test {
5776
let dict = PyDict::new(py);
5877
_ = dict.set_item("k1", "v1");
5978
_ = dict.set_item("k2", 2);
60-
assert_eq!("{'k1': 'v1', 'k2': 2}", fmt_py_obj(&dict));
79+
assert_eq!("{'k1': 'v1', 'k2': 2}", fmt_py_obj(py, &dict));
6180
// class A variable can not be formatted
6281
_ = dict.set_item("k3", A {});
63-
assert_eq!("...", fmt_py_obj(&dict));
82+
assert_eq!("...", fmt_py_obj(py, &dict));
6483
})
6584
}
6685
#[test]
6786
fn test_fmt_list() {
6887
pyo3::prepare_freethreaded_python();
6988
Python::with_gil(|py| {
7089
let list = PyList::new(py, [1, 2]).unwrap();
71-
assert_eq!("[1, 2]", fmt_py_obj(&list));
90+
assert_eq!("[1, 2]", fmt_py_obj(py, &list));
7291
// class A variable can not be formatted
7392
let list = PyList::new(py, [A {}, A {}]).unwrap();
74-
assert_eq!("...", fmt_py_obj(&list));
93+
assert_eq!("...", fmt_py_obj(py, &list));
7594
})
7695
}
7796
#[test]
7897
fn test_fmt_tuple() {
7998
pyo3::prepare_freethreaded_python();
8099
Python::with_gil(|py| {
81100
let tuple = PyTuple::new(py, [1, 2]).unwrap();
82-
assert_eq!("(1, 2)", fmt_py_obj(&tuple));
101+
assert_eq!("(1, 2)", fmt_py_obj(py, tuple));
83102
let tuple = PyTuple::new(py, [1]).unwrap();
84-
assert_eq!("(1,)", fmt_py_obj(&tuple));
103+
assert_eq!("(1,)", fmt_py_obj(py, tuple));
85104
// class A variable can not be formatted
86105
let tuple = PyTuple::new(py, [A {}]).unwrap();
87-
assert_eq!("...", fmt_py_obj(&tuple));
106+
assert_eq!("...", fmt_py_obj(py, tuple));
88107
})
89108
}
90109
#[test]
91110
fn test_fmt_other() {
92111
pyo3::prepare_freethreaded_python();
93112
Python::with_gil(|py| {
94113
// str
95-
assert_eq!("'123'", fmt_py_obj(&"123".into_bound_py_any(py).unwrap()));
96-
assert_eq!(
97-
"\"don't\"",
98-
fmt_py_obj(&"don't".into_bound_py_any(py).unwrap())
99-
);
100-
assert_eq!(
101-
"'str\\\\'",
102-
fmt_py_obj(&"str\\".into_bound_py_any(py).unwrap())
103-
);
114+
assert_eq!("'123'", fmt_py_obj(py, &"123"));
115+
assert_eq!("\"don't\"", fmt_py_obj(py, &"don't"));
116+
assert_eq!("'str\\\\'", fmt_py_obj(py, &"str\\"));
104117
// bool
105-
assert_eq!("True", fmt_py_obj(&true.into_bound_py_any(py).unwrap()));
106-
assert_eq!("False", fmt_py_obj(&false.into_bound_py_any(py).unwrap()));
118+
assert_eq!("True", fmt_py_obj(py, true));
119+
assert_eq!("False", fmt_py_obj(py, false));
107120
// int
108-
assert_eq!("123", fmt_py_obj(&123.into_bound_py_any(py).unwrap()));
121+
assert_eq!("123", fmt_py_obj(py, 123));
109122
// float
110-
assert_eq!("1.23", fmt_py_obj(&1.23.into_bound_py_any(py).unwrap()));
123+
assert_eq!("1.23", fmt_py_obj(py, 1.23));
111124
// None
112125
let none: Option<usize> = None;
113-
assert_eq!("None", fmt_py_obj(&none.into_bound_py_any(py).unwrap()));
126+
assert_eq!("None", fmt_py_obj(py, none));
114127
// class A variable can not be formatted
115-
assert_eq!("...", fmt_py_obj(&A {}.into_bound_py_any(py).unwrap()));
128+
assert_eq!("...", fmt_py_obj(py, A {}));
116129
})
117130
}
131+
#[test]
132+
fn test_fmt_enum() {
133+
#[pyclass(eq, eq_int)]
134+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
135+
pub enum Number {
136+
Float,
137+
Integer,
138+
}
139+
pyo3::prepare_freethreaded_python();
140+
Python::with_gil(|py| {
141+
assert_eq!("Number.Float", fmt_py_obj(py, Number::Float));
142+
});
143+
}
118144
}

0 commit comments

Comments
 (0)