Skip to content

Commit 5d0723c

Browse files
authored
fix mypy issues (#238)
s
1 parent d87eb24 commit 5d0723c

File tree

2 files changed

+47
-47
lines changed

2 files changed

+47
-47
lines changed

pymc_bart/utils.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
"""Utility function for variable selection and bart interpretability."""
33

44
import warnings
5-
from typing import Any, Callable, Optional, Union
5+
from collections.abc import Callable
6+
from typing import Any, TypeVar
67

78
import matplotlib.pyplot as plt
89
import numpy as np
@@ -18,15 +19,15 @@
1819

1920
from .tree import Tree
2021

21-
TensorLike = Union[npt.NDArray, pt.TensorVariable]
22+
TensorLike = TypeVar("TensorLike", npt.NDArray, pt.TensorVariable)
2223

2324

2425
def _sample_posterior(
2526
all_trees: list[list[Tree]],
2627
X: TensorLike,
2728
rng: np.random.Generator,
28-
size: Optional[Union[int, tuple[int, ...]]] = None,
29-
excluded: Optional[list[int]] = None,
29+
size: int | tuple[int, ...] | None = None,
30+
excluded: list[int] | None = None,
3031
shape: int = 1,
3132
) -> npt.NDArray:
3233
"""
@@ -51,7 +52,7 @@ def _sample_posterior(
5152
X = X.eval()
5253

5354
if size is None:
54-
size_iter: Union[list, tuple] = (1,)
55+
size_iter: list | tuple = (1,)
5556
elif isinstance(size, int):
5657
size_iter = [size]
5758
else:
@@ -78,9 +79,9 @@ def _sample_posterior(
7879

7980
def plot_convergence(
8081
idata: Any,
81-
var_name: Optional[str] = None,
82+
var_name: str | None = None,
8283
kind: str = "ecdf",
83-
figsize: Optional[tuple[float, float]] = None,
84+
figsize: tuple[float, float] | None = None,
8485
ax=None,
8586
) -> None:
8687
"""
@@ -114,23 +115,23 @@ def plot_convergence(
114115
def plot_ice(
115116
bartrv: Variable,
116117
X: npt.NDArray,
117-
Y: Optional[npt.NDArray] = None,
118-
var_idx: Optional[list[int]] = None,
119-
var_discrete: Optional[list[int]] = None,
120-
func: Optional[Callable] = None,
121-
centered: Optional[bool] = True,
118+
Y: npt.NDArray | None = None,
119+
var_idx: list[int] | None = None,
120+
var_discrete: list[int] | None = None,
121+
func: Callable | None = None,
122+
centered: bool | None = True,
122123
samples: int = 100,
123124
instances: int = 30,
124-
random_seed: Optional[int] = None,
125+
random_seed: int | None = None,
125126
sharey: bool = True,
126127
smooth: bool = True,
127128
grid: str = "long",
128129
color="C0",
129130
color_mean: str = "C0",
130131
alpha: float = 0.1,
131-
figsize: Optional[tuple[float, float]] = None,
132-
smooth_kwargs: Optional[dict[str, Any]] = None,
133-
ax: Optional[plt.Axes] = None,
132+
figsize: tuple[float, float] | None = None,
133+
smooth_kwargs: dict[str, Any] | None = None,
134+
ax: plt.Axes | None = None,
134135
) -> list[plt.Axes]:
135136
"""
136137
Individual conditional expectation plot.
@@ -258,24 +259,24 @@ def identity(x):
258259
def plot_pdp(
259260
bartrv: Variable,
260261
X: npt.NDArray,
261-
Y: Optional[npt.NDArray] = None,
262+
Y: npt.NDArray | None = None,
262263
xs_interval: str = "quantiles",
263-
xs_values: Optional[Union[int, list[float]]] = None,
264-
var_idx: Optional[list[int]] = None,
265-
var_discrete: Optional[list[int]] = None,
266-
func: Optional[Callable] = None,
264+
xs_values: int | list[float] | None = None,
265+
var_idx: list[int] | None = None,
266+
var_discrete: list[int] | None = None,
267+
func: Callable | None = None,
267268
samples: int = 200,
268269
ref_line: bool = True,
269-
random_seed: Optional[int] = None,
270+
random_seed: int | None = None,
270271
sharey: bool = True,
271272
smooth: bool = True,
272273
grid: str = "long",
273274
color="C0",
274275
color_mean: str = "C0",
275276
alpha: float = 0.1,
276-
figsize: Optional[tuple[float, float]] = None,
277-
smooth_kwargs: Optional[dict[str, Any]] = None,
278-
ax: Optional[plt.Axes] = None,
277+
figsize: tuple[float, float] | None = None,
278+
smooth_kwargs: dict[str, Any] | None = None,
279+
ax: plt.Axes = None,
279280
) -> list[plt.Axes]:
280281
"""
281282
Partial dependence plot.
@@ -425,8 +426,8 @@ def _create_figure_axes(
425426
var_idx: list[int],
426427
grid: str = "long",
427428
sharey: bool = True,
428-
figsize: Optional[tuple[float, float]] = None,
429-
ax: Optional[plt.Axes] = None,
429+
figsize: tuple[float, float] | None = None,
430+
ax: plt.Axes | None = None,
430431
) -> tuple[plt.Figure, list[plt.Axes], int]:
431432
"""
432433
Create and return the figure and axes objects for plotting the variables.
@@ -506,11 +507,11 @@ def _get_axes(grid, n_plots, sharex, sharey, figsize):
506507

507508
def _prepare_plot_data(
508509
X: npt.NDArray,
509-
Y: Optional[npt.NDArray] = None,
510+
Y: npt.NDArray | None = None,
510511
xs_interval: str = "quantiles",
511-
xs_values: Optional[Union[int, list[float]]] = None,
512-
var_idx: Optional[list[int]] = None,
513-
var_discrete: Optional[list[int]] = None,
512+
xs_values: int | list[float] | None = None,
513+
var_idx: list[int] | None = None,
514+
var_discrete: list[int] | None = None,
514515
) -> tuple[
515516
npt.NDArray,
516517
list[str],
@@ -519,7 +520,7 @@ def _prepare_plot_data(
519520
list[int],
520521
list[int],
521522
str,
522-
Union[int, None, list[float]],
523+
int | None | list[float],
523524
]:
524525
"""
525526
Prepare data for plotting.
@@ -600,7 +601,7 @@ def _prepare_plot_data(
600601
def _create_pdp_data(
601602
X: npt.NDArray,
602603
xs_interval: str,
603-
xs_values: Optional[Union[int, list[float]]] = None,
604+
xs_values: int | list[float] | None = None,
604605
) -> npt.NDArray:
605606
"""
606607
Create data for partial dependence plot.
@@ -636,7 +637,7 @@ def _smooth_mean(
636637
new_x: npt.NDArray,
637638
p_di: npt.NDArray,
638639
kind: str = "neutral",
639-
smooth_kwargs: Optional[dict[str, Any]] = None,
640+
smooth_kwargs: dict[str, Any] | None = None,
640641
) -> tuple[np.ndarray, np.ndarray]:
641642
"""
642643
Smooth the mean data for plotting.
@@ -805,7 +806,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
805806
fixed: int = 0,
806807
samples: int = 50,
807808
random_seed: int | None = None,
808-
) -> dict[str, object]:
809+
) -> dict[str, npt.NDArray]:
809810
"""
810811
Estimates variable importance from the BART-posterior.
811812
@@ -1026,11 +1027,11 @@ def vi_to_kulprit(vi_results: dict) -> list[list[str]]:
10261027

10271028
def plot_variable_importance(
10281029
vi_results: dict,
1029-
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,
1030-
labels: Optional[list[str]] = None,
1031-
figsize: Optional[tuple[float, float]] = None,
1032-
plot_kwargs: Optional[dict[str, Any]] = None,
1033-
ax: Optional[plt.Axes] = None,
1030+
submodels: list[int] | np.ndarray | tuple[int, ...] | None = None,
1031+
labels: list[str] | None = None,
1032+
figsize: tuple[float, float] | None = None,
1033+
plot_kwargs: dict[str, Any] | None = None,
1034+
ax: plt.Axes | None = None,
10341035
):
10351036
"""
10361037
Estimates variable importance from the BART-posterior.
@@ -1128,13 +1129,13 @@ def plot_variable_importance(
11281129

11291130
def plot_scatter_submodels(
11301131
vi_results: dict,
1131-
func: Optional[Callable] = None,
1132-
submodels: Optional[Union[list[int], np.ndarray]] = None,
1132+
func: Callable | None = None,
1133+
submodels: list[int] | np.ndarray | None = None,
11331134
grid: str = "long",
1134-
labels: Optional[list[str]] = None,
1135-
figsize: Optional[tuple[float, float]] = None,
1136-
plot_kwargs: Optional[dict[str, Any]] = None,
1137-
ax: Optional[plt.Axes] = None,
1135+
labels: list[str] | None = None,
1136+
figsize: tuple[float, float] | None = None,
1137+
plot_kwargs: dict[str, Any] | None = None,
1138+
ax: plt.Axes | None = None,
11381139
) -> list[plt.Axes]:
11391140
"""
11401141
Plot submodel's predictions against reference-model's predictions.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ pyupgrade = 1
3737

3838
[tool.mypy]
3939
files = "pymc_bart/*.py"
40-
plugins = "numpy.typing.mypy_plugin"
4140

4241
[tool.mypy-matplotlib]
4342
ignore_missing_imports = true

0 commit comments

Comments
 (0)