|
14 | 14 | # See the License for the specific language governing permissions and |
15 | 15 | # limitations under the License. |
16 | 16 |
|
| 17 | +import warnings |
17 | 18 | from multiprocessing import Manager |
18 | 19 | from typing import List, Optional, Tuple |
19 | | -import warnings |
20 | 20 |
|
21 | 21 | import numpy as np |
22 | 22 | import numpy.typing as npt |
|
26 | 26 | from pymc.logprob.abstract import _logprob |
27 | 27 | from pytensor.tensor.random.op import RandomVariable |
28 | 28 |
|
| 29 | +from .split_rules import SplitRule |
29 | 30 | from .tree import Tree |
30 | 31 | from .utils import TensorLike, _sample_posterior |
31 | | -from .split_rules import SplitRule |
32 | 32 |
|
33 | 33 | __all__ = ["BART"] |
34 | 34 |
|
@@ -93,7 +93,7 @@ class BART(Distribution): |
93 | 93 | Each element of split_prior should be in the [0, 1] interval and the elements should sum to |
94 | 94 | 1. Otherwise they will be normalized. |
95 | 95 | Defaults to 0, i.e. all covariates have the same prior probability to be selected. |
96 | | - split_rules : Optional[SplitRule], default None |
| 96 | + split_rules : Optional[List[SplitRule]], default None |
97 | 97 | List of SplitRule objects, one per column in input data. |
98 | 98 | Allows using different split rules for different columns. Default is ContinuousSplitRule. |
99 | 99 | Other options are OneHotSplitRule and SubsetSplitRule, both meant for categorical variables. |
@@ -127,7 +127,7 @@ def __new__( |
127 | 127 | beta: float = 2.0, |
128 | 128 | response: str = "constant", |
129 | 129 | split_prior: Optional[List[float]] = None, |
130 | | - split_rules: Optional[SplitRule] = None, |
| 130 | + split_rules: Optional[List[SplitRule]] = None, |
131 | 131 | separate_trees: Optional[bool] = False, |
132 | 132 | **kwargs, |
133 | 133 | ): |
|
0 commit comments