2
2
"""Utility function for variable selection and bart interpretability."""
3
3
4
4
import warnings
5
- from typing import Any , Callable , Optional , Union
5
+ from collections .abc import Callable
6
+ from typing import Any , TypeVar
6
7
7
8
import matplotlib .pyplot as plt
8
9
import numpy as np
18
19
19
20
from .tree import Tree
20
21
21
- TensorLike = Union [ npt .NDArray , pt .TensorVariable ]
22
+ TensorLike = TypeVar ( "TensorLike" , npt .NDArray , pt .TensorVariable )
22
23
23
24
24
25
def _sample_posterior (
25
26
all_trees : list [list [Tree ]],
26
27
X : TensorLike ,
27
28
rng : np .random .Generator ,
28
- size : Optional [ Union [ int , tuple [int , ...]]] = None ,
29
- excluded : Optional [ list [int ]] = None ,
29
+ size : int | tuple [int , ...] | None = None ,
30
+ excluded : list [int ] | None = None ,
30
31
shape : int = 1 ,
31
32
) -> npt .NDArray :
32
33
"""
@@ -51,7 +52,7 @@ def _sample_posterior(
51
52
X = X .eval ()
52
53
53
54
if size is None :
54
- size_iter : Union [ list , tuple ] = (1 ,)
55
+ size_iter : list | tuple = (1 ,)
55
56
elif isinstance (size , int ):
56
57
size_iter = [size ]
57
58
else :
@@ -78,9 +79,9 @@ def _sample_posterior(
78
79
79
80
def plot_convergence (
80
81
idata : Any ,
81
- var_name : Optional [ str ] = None ,
82
+ var_name : str | None = None ,
82
83
kind : str = "ecdf" ,
83
- figsize : Optional [ tuple [float , float ]] = None ,
84
+ figsize : tuple [float , float ] | None = None ,
84
85
ax = None ,
85
86
) -> None :
86
87
"""
@@ -114,23 +115,23 @@ def plot_convergence(
114
115
def plot_ice (
115
116
bartrv : Variable ,
116
117
X : npt .NDArray ,
117
- Y : Optional [ npt .NDArray ] = None ,
118
- var_idx : Optional [ list [int ]] = None ,
119
- var_discrete : Optional [ list [int ]] = None ,
120
- func : Optional [ Callable ] = None ,
121
- centered : Optional [ bool ] = True ,
118
+ Y : npt .NDArray | None = None ,
119
+ var_idx : list [int ] | None = None ,
120
+ var_discrete : list [int ] | None = None ,
121
+ func : Callable | None = None ,
122
+ centered : bool | None = True ,
122
123
samples : int = 100 ,
123
124
instances : int = 30 ,
124
- random_seed : Optional [ int ] = None ,
125
+ random_seed : int | None = None ,
125
126
sharey : bool = True ,
126
127
smooth : bool = True ,
127
128
grid : str = "long" ,
128
129
color = "C0" ,
129
130
color_mean : str = "C0" ,
130
131
alpha : float = 0.1 ,
131
- figsize : Optional [ tuple [float , float ]] = None ,
132
- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
133
- ax : Optional [ plt .Axes ] = None ,
132
+ figsize : tuple [float , float ] | None = None ,
133
+ smooth_kwargs : dict [str , Any ] | None = None ,
134
+ ax : plt .Axes | None = None ,
134
135
) -> list [plt .Axes ]:
135
136
"""
136
137
Individual conditional expectation plot.
@@ -258,24 +259,24 @@ def identity(x):
258
259
def plot_pdp (
259
260
bartrv : Variable ,
260
261
X : npt .NDArray ,
261
- Y : Optional [ npt .NDArray ] = None ,
262
+ Y : npt .NDArray | None = None ,
262
263
xs_interval : str = "quantiles" ,
263
- xs_values : Optional [ Union [ int , list [float ]]] = None ,
264
- var_idx : Optional [ list [int ]] = None ,
265
- var_discrete : Optional [ list [int ]] = None ,
266
- func : Optional [ Callable ] = None ,
264
+ xs_values : int | list [float ] | None = None ,
265
+ var_idx : list [int ] | None = None ,
266
+ var_discrete : list [int ] | None = None ,
267
+ func : Callable | None = None ,
267
268
samples : int = 200 ,
268
269
ref_line : bool = True ,
269
- random_seed : Optional [ int ] = None ,
270
+ random_seed : int | None = None ,
270
271
sharey : bool = True ,
271
272
smooth : bool = True ,
272
273
grid : str = "long" ,
273
274
color = "C0" ,
274
275
color_mean : str = "C0" ,
275
276
alpha : float = 0.1 ,
276
- figsize : Optional [ tuple [float , float ]] = None ,
277
- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
278
- ax : Optional [ plt .Axes ] = None ,
277
+ figsize : tuple [float , float ] | None = None ,
278
+ smooth_kwargs : dict [str , Any ] | None = None ,
279
+ ax : plt .Axes = None ,
279
280
) -> list [plt .Axes ]:
280
281
"""
281
282
Partial dependence plot.
@@ -425,8 +426,8 @@ def _create_figure_axes(
425
426
var_idx : list [int ],
426
427
grid : str = "long" ,
427
428
sharey : bool = True ,
428
- figsize : Optional [ tuple [float , float ]] = None ,
429
- ax : Optional [ plt .Axes ] = None ,
429
+ figsize : tuple [float , float ] | None = None ,
430
+ ax : plt .Axes | None = None ,
430
431
) -> tuple [plt .Figure , list [plt .Axes ], int ]:
431
432
"""
432
433
Create and return the figure and axes objects for plotting the variables.
@@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
506
507
507
508
def _prepare_plot_data (
508
509
X : npt .NDArray ,
509
- Y : Optional [ npt .NDArray ] = None ,
510
+ Y : npt .NDArray | None = None ,
510
511
xs_interval : str = "quantiles" ,
511
- xs_values : Optional [ Union [ int , list [float ]]] = None ,
512
- var_idx : Optional [ list [int ]] = None ,
513
- var_discrete : Optional [ list [int ]] = None ,
512
+ xs_values : int | list [float ] | None = None ,
513
+ var_idx : list [int ] | None = None ,
514
+ var_discrete : list [int ] | None = None ,
514
515
) -> tuple [
515
516
npt .NDArray ,
516
517
list [str ],
@@ -519,7 +520,7 @@ def _prepare_plot_data(
519
520
list [int ],
520
521
list [int ],
521
522
str ,
522
- Union [ int , None , list [float ] ],
523
+ int | None | list [float ],
523
524
]:
524
525
"""
525
526
Prepare data for plotting.
@@ -600,7 +601,7 @@ def _prepare_plot_data(
600
601
def _create_pdp_data (
601
602
X : npt .NDArray ,
602
603
xs_interval : str ,
603
- xs_values : Optional [ Union [ int , list [float ]]] = None ,
604
+ xs_values : int | list [float ] | None = None ,
604
605
) -> npt .NDArray :
605
606
"""
606
607
Create data for partial dependence plot.
@@ -636,7 +637,7 @@ def _smooth_mean(
636
637
new_x : npt .NDArray ,
637
638
p_di : npt .NDArray ,
638
639
kind : str = "neutral" ,
639
- smooth_kwargs : Optional [ dict [str , Any ]] = None ,
640
+ smooth_kwargs : dict [str , Any ] | None = None ,
640
641
) -> tuple [np .ndarray , np .ndarray ]:
641
642
"""
642
643
Smooth the mean data for plotting.
@@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
805
806
fixed : int = 0 ,
806
807
samples : int = 50 ,
807
808
random_seed : int | None = None ,
808
- ) -> dict [str , object ]:
809
+ ) -> dict [str , npt . NDArray ]:
809
810
"""
810
811
Estimates variable importance from the BART-posterior.
811
812
@@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]:
1026
1027
1027
1028
def plot_variable_importance (
1028
1029
vi_results : dict ,
1029
- submodels : Optional [ Union [ list [int ], np .ndarray , tuple [int , ...]]] = None ,
1030
- labels : Optional [ list [str ]] = None ,
1031
- figsize : Optional [ tuple [float , float ]] = None ,
1032
- plot_kwargs : Optional [ dict [str , Any ]] = None ,
1033
- ax : Optional [ plt .Axes ] = None ,
1030
+ submodels : list [int ] | np .ndarray | tuple [int , ...] | None = None ,
1031
+ labels : list [str ] | None = None ,
1032
+ figsize : tuple [float , float ] | None = None ,
1033
+ plot_kwargs : dict [str , Any ] | None = None ,
1034
+ ax : plt .Axes | None = None ,
1034
1035
):
1035
1036
"""
1036
1037
Estimates variable importance from the BART-posterior.
@@ -1128,13 +1129,13 @@ def plot_variable_importance(
1128
1129
1129
1130
def plot_scatter_submodels (
1130
1131
vi_results : dict ,
1131
- func : Optional [ Callable ] = None ,
1132
- submodels : Optional [ Union [ list [int ], np .ndarray ]] = None ,
1132
+ func : Callable | None = None ,
1133
+ submodels : list [int ] | np .ndarray | None = None ,
1133
1134
grid : str = "long" ,
1134
- labels : Optional [ list [str ]] = None ,
1135
- figsize : Optional [ tuple [float , float ]] = None ,
1136
- plot_kwargs : Optional [ dict [str , Any ]] = None ,
1137
- ax : Optional [ plt .Axes ] = None ,
1135
+ labels : list [str ] | None = None ,
1136
+ figsize : tuple [float , float ] | None = None ,
1137
+ plot_kwargs : dict [str , Any ] | None = None ,
1138
+ ax : plt .Axes | None = None ,
1138
1139
) -> list [plt .Axes ]:
1139
1140
"""
1140
1141
Plot submodel's predictions against reference-model's predictions.
0 commit comments