Skip to content

Commit 9710d7f

Browse files
authored
add function to export the vi results to kulprit (#237)
1 parent 941d258 commit 9710d7f

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

pymc_bart/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
plot_scatter_submodels,
2626
plot_variable_importance,
2727
plot_variable_inclusion,
28+
vi_to_kulprit,
2829
)
2930

3031
__all__ = [
@@ -41,6 +42,7 @@
4142
"plot_scatter_submodels",
4243
"plot_variable_importance",
4344
"plot_variable_inclusion",
45+
"vi_to_kulprit",
4446
]
4547
__version__ = "0.10.0"
4648

pymc_bart/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,24 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
10061006
return vi_results
10071007

10081008

1009+
def vi_to_kulprit(vi_results: dict) -> list[list[str]]:
1010+
"""
1011+
Export variable importance results to Kulprit format.
1012+
1013+
Parameters
1014+
----------
1015+
vi_results : dict
1016+
Dictionary computed with `compute_variable_importance`
1017+
1018+
Returns
1019+
-------
1020+
list[list[str]]
1021+
A list of lists containing variable names for each submodel.
1022+
"""
1023+
clean_labels = [label.strip("+ ") for label in vi_results["labels"]]
1024+
return [clean_labels[:idx] for idx in range(len(clean_labels))]
1025+
1026+
10091027
def plot_variable_importance(
10101028
vi_results: dict,
10111029
submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None,

tests/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def test_vi(self, kwargs):
7979
pmb.plot_variable_importance(vi_results, **kwargs)
8080
pmb.plot_scatter_submodels(vi_results, **kwargs)
8181

82+
user_terms = pmb.vi_to_kulprit(vi_results)
83+
assert len(user_terms) == 3
84+
assert all("+" not in term for terms in user_terms[1:] for term in terms)
85+
8286
def test_pdp_pandas_labels(self):
8387
pd = pytest.importorskip("pandas")
8488

0 commit comments

Comments
 (0)