Skip to content

Commit 6bc7246

Browse files
committed
proc-macro added for pyo3-stub-gen to generate stub pyi
1 parent e6d7fa7 commit 6bc7246

File tree

9 files changed

+51
-1
lines changed

9 files changed

+51
-1
lines changed

python/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ readme = "README.md"
1111

1212
[lib]
1313
name = "egobox"
14-
crate-type = ["cdylib"]
14+
crate-type = ["cdylib", "rlib"]
1515

1616
[features]
1717
default = []
@@ -48,3 +48,8 @@ serde_json.workspace = true
4848
ctrlc.workspace = true
4949

5050
argmin_testfunctions.workspace = true
51+
pyo3-stub-gen = { path = "../../pyo3-stub-gen/pyo3-stub-gen", features = ["numpy"] }
52+
53+
[[bin]]
54+
name = "stub_gen"
55+
doc = false

python/py.typed

Whitespace-only changes.

python/src/bin/stub_gen.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
use pyo3_stub_gen::Result;
2+
3+
fn main() -> Result<()> {
4+
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
5+
let stub = egobox::stub_info()?;
6+
stub.generate()?;
7+
Ok(())
8+
}

python/src/egor.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use ndarray::{concatenate, Array1, Array2, ArrayView2, Axis};
1616
use numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, ToPyArray};
1717
use pyo3::exceptions::PyValueError;
1818
use pyo3::prelude::*;
19+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
1920

2021
/// Utility function converting `xlimits` float data list specifying bounds of x components
2122
/// to x specified as a list of XType.Float types [egobox.XType]
@@ -25,6 +26,7 @@ use pyo3::prelude::*;
2526
///
2627
/// # Returns
2728
/// xtypes: nx-size list of XSpec(XType(FLOAT), [lower_bound, upper_bounds]) where `nx` is the dimension of x
29+
#[gen_stub_pyfunction]
2830
#[pyfunction]
2931
pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject> {
3032
if xlimits.is_empty() || xlimits[0].is_empty() {
@@ -149,6 +151,7 @@ pub(crate) fn to_specs(py: Python, xlimits: Vec<Vec<f64>>) -> PyResult<PyObject>
149151
/// seed (int >= 0)
150152
/// Random generator seed to allow computation reproducibility.
151153
///
154+
#[gen_stub_pyclass]
152155
#[pyclass]
153156
pub(crate) struct Egor {
154157
pub xspecs: PyObject,
@@ -174,6 +177,7 @@ pub(crate) struct Egor {
174177
pub seed: Option<u64>,
175178
}
176179

180+
#[gen_stub_pyclass]
177181
#[pyclass]
178182
pub(crate) struct OptimResult {
179183
#[pyo3(get)]
@@ -186,6 +190,7 @@ pub(crate) struct OptimResult {
186190
y_doe: Py<PyArray2<f64>>,
187191
}
188192

193+
#[gen_stub_pymethods]
189194
#[pymethods]
190195
impl Egor {
191196
#[new]

python/src/gp_mix.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use ndarray::{Array1, Array2, Axis, Ix1, Ix2, Zip};
2222
use ndarray_rand::rand::SeedableRng;
2323
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2, PyReadonlyArrayDyn};
2424
use pyo3::prelude::*;
25+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
2526
use rand_xoshiro::Xoshiro256Plus;
2627

2728
/// Gaussian processes mixture builder
@@ -69,6 +70,7 @@ use rand_xoshiro::Xoshiro256Plus;
6970
/// seed (int >= 0)
7071
/// Random generator seed to allow computation reproducibility.
7172
///
73+
#[gen_stub_pyclass]
7274
#[pyclass]
7375
pub(crate) struct GpMix {
7476
pub n_clusters: usize,
@@ -82,6 +84,7 @@ pub(crate) struct GpMix {
8284
pub seed: Option<u64>,
8385
}
8486

87+
#[gen_stub_pymethods]
8588
#[pymethods]
8689
impl GpMix {
8790
#[new]
@@ -218,9 +221,11 @@ impl GpMix {
218221
}
219222

220223
/// A trained Gaussian processes mixture
224+
#[gen_stub_pyclass]
221225
#[pyclass]
222226
pub(crate) struct Gpx(Box<GpMixture>);
223227

228+
#[gen_stub_pymethods]
224229
#[pymethods]
225230
impl Gpx {
226231
/// Get Gaussian processes mixture builder aka `GpMix`

python/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use types::*;
1414

1515
use env_logger::{Builder, Env};
1616
use pyo3::prelude::*;
17+
use pyo3_stub_gen::define_stub_info_gatherer;
1718

1819
#[doc(hidden)]
1920
#[pymodule]
@@ -55,3 +56,6 @@ fn egobox(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
5556

5657
Ok(())
5758
}
59+
60+
// Define a function to gather stub information.
61+
define_stub_info_gatherer!(stub_info);

python/src/sampling.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ use egobox_doe::{LhsKind, SamplingMethod};
33
use egobox_ego::gpmix::mixint::MixintContext;
44
use numpy::{IntoPyArray, PyArray2};
55
use pyo3::prelude::*;
6+
use pyo3_stub_gen::derive::{gen_stub_pyclass_enum, gen_stub_pyfunction};
67

8+
#[gen_stub_pyclass_enum]
79
#[pyclass(eq, eq_int, rename_all = "SCREAMING_SNAKE_CASE")]
810
#[derive(Debug, Clone, Copy, PartialEq)]
911
pub enum Sampling {
@@ -27,6 +29,7 @@ pub enum Sampling {
2729
/// # Returns
2830
/// ndarray of shape (n_samples, n_variables)
2931
///
32+
#[gen_stub_pyfunction]
3033
#[pyfunction]
3134
#[pyo3(signature = (method, xspecs, n_samples, seed=None))]
3235
pub fn sampling(
@@ -89,6 +92,7 @@ pub fn sampling(
8992
/// # Returns
9093
/// ndarray of shape (n_samples, n_variables)
9194
///
95+
#[gen_stub_pyfunction]
9296
#[pyfunction]
9397
#[pyo3(signature = (xspecs, n_samples, seed=None))]
9498
pub(crate) fn lhs(

python/src/sparse_gp_mix.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use ndarray::{Array1, Array2, Axis, Ix1, Ix2, Zip};
2121
use ndarray_rand::rand::SeedableRng;
2222
use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2};
2323
use pyo3::prelude::*;
24+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
2425
use rand_xoshiro::Xoshiro256Plus;
2526

2627
/// Sparse Gaussian processes mixture builder
@@ -58,6 +59,7 @@ use rand_xoshiro::Xoshiro256Plus;
5859
/// seed (int >= 0)
5960
/// Random generator seed to allow computation reproducibility.
6061
///
62+
#[gen_stub_pyclass]
6163
#[pyclass]
6264
pub(crate) struct SparseGpMix {
6365
pub correlation_spec: CorrelationSpec,
@@ -71,6 +73,7 @@ pub(crate) struct SparseGpMix {
7173
pub seed: Option<u64>,
7274
}
7375

76+
#[gen_stub_pymethods]
7477
#[pymethods]
7578
impl SparseGpMix {
7679
#[new]
@@ -216,9 +219,11 @@ impl SparseGpMix {
216219
}
217220

218221
/// A trained Gaussian processes mixture
222+
#[gen_stub_pyclass]
219223
#[pyclass]
220224
pub(crate) struct SparseGpx(Box<GpMixture>);
221225

226+
#[gen_stub_pymethods]
222227
#[pymethods]
223228
impl SparseGpx {
224229
/// Get Gaussian processes mixture builder aka `GpSparse`

python/src/types.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use pyo3::prelude::*;
2+
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
23

4+
#[gen_stub_pyclass_enum]
35
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
46
#[derive(Debug, Clone, PartialEq)]
57
pub enum Recombination {
@@ -12,10 +14,12 @@ pub enum Recombination {
1214
Smooth = 1,
1315
}
1416

17+
#[gen_stub_pyclass]
1518
#[pyclass]
1619
#[derive(Clone)]
1720
pub(crate) struct RegressionSpec(pub(crate) u8);
1821

22+
#[gen_stub_pymethods]
1923
#[pymethods]
2024
impl RegressionSpec {
2125
#[classattr]
@@ -28,10 +32,12 @@ impl RegressionSpec {
2832
pub(crate) const QUADRATIC: u8 = egobox_moe::RegressionSpec::QUADRATIC.bits();
2933
}
3034

35+
#[gen_stub_pyclass]
3136
#[pyclass]
3237
#[derive(Clone)]
3338
pub(crate) struct CorrelationSpec(pub(crate) u8);
3439

40+
#[gen_stub_pymethods]
3541
#[pymethods]
3642
impl CorrelationSpec {
3743
#[classattr]
@@ -48,6 +54,7 @@ impl CorrelationSpec {
4854
pub(crate) const MATERN52: u8 = egobox_moe::CorrelationSpec::MATERN52.bits();
4955
}
5056

57+
#[gen_stub_pyclass_enum]
5158
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
5259
#[derive(Debug, Clone, Copy, PartialEq)]
5360
pub(crate) enum InfillStrategy {
@@ -56,6 +63,7 @@ pub(crate) enum InfillStrategy {
5663
Wb2s = 3,
5764
}
5865

66+
#[gen_stub_pyclass_enum]
5967
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
6068
#[derive(Debug, Clone, Copy, PartialEq)]
6169
pub(crate) enum ParInfillStrategy {
@@ -65,13 +73,15 @@ pub(crate) enum ParInfillStrategy {
6573
Clmin = 4,
6674
}
6775

76+
#[gen_stub_pyclass_enum]
6877
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
6978
#[derive(Debug, Clone, Copy, PartialEq)]
7079
pub(crate) enum InfillOptimizer {
7180
Cobyla = 1,
7281
Slsqp = 2,
7382
}
7483

84+
#[gen_stub_pyclass]
7585
#[pyclass]
7686
#[derive(Clone, Copy)]
7787
pub(crate) struct ExpectedOptimum {
@@ -93,6 +103,7 @@ impl ExpectedOptimum {
93103
}
94104
}
95105

106+
#[gen_stub_pyclass_enum]
96107
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
97108
#[derive(Clone, Copy, Debug, PartialEq)]
98109
pub(crate) enum XType {
@@ -102,6 +113,7 @@ pub(crate) enum XType {
102113
Enum = 4,
103114
}
104115

116+
#[gen_stub_pyclass]
105117
#[pyclass]
106118
#[derive(FromPyObject, Debug)]
107119
pub(crate) struct XSpec {
@@ -113,6 +125,7 @@ pub(crate) struct XSpec {
113125
pub(crate) tags: Vec<String>,
114126
}
115127

128+
#[gen_stub_pymethods]
116129
#[pymethods]
117130
impl XSpec {
118131
#[new]
@@ -126,6 +139,7 @@ impl XSpec {
126139
}
127140
}
128141

142+
#[gen_stub_pyclass_enum]
129143
#[pyclass(eq, eq_int, rename_all = "UPPERCASE")]
130144
#[derive(Debug, Clone, Copy, PartialEq)]
131145
pub(crate) enum SparseMethod {

0 commit comments

Comments
 (0)