@@ -155,10 +155,9 @@ def __init__(
155155 else :
156156 self .split_rules = [ContinuousSplitRule ] * self .X .shape [1 ]
157157
158- jittered = np .random .normal (self .X , self .X .std (axis = 0 ) / 12 )
159- min_values = np .min (self .X , axis = 0 )
160- max_values = np .max (self .X , axis = 0 )
161- self .X = np .clip (jittered , min_values , max_values )
158+ for idx , rule in enumerate (self .split_rules ):
159+ if rule is ContinuousSplitRule :
160+ self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .std (self .X [:, idx ]))
162161
163162 init_mean = self .bart .Y .mean ()
164163 self .num_observations = self .X .shape [0 ]
@@ -693,6 +692,21 @@ def inverse_cdf(
693692 return new_indices
694693
695694
@njit
def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
    """
    Jitter duplicated values.

    The first occurrence of every value is left untouched. Each later
    occurrence of an already-seen value is replaced by that value plus
    Gaussian noise drawn from ``Normal(0, std / 12)``.

    Parameters
    ----------
    array : npt.NDArray[np.float64]
        One-dimensional array of values. It is modified **in place**.
    std : float
        Reference standard deviation; the noise scale is ``std / 12``.
        With ``std == 0`` the added noise is exactly zero, so duplicates
        stay as they are.

    Returns
    -------
    npt.NDArray[np.float64]
        The same ``array`` that was passed in, after jittering.
    """
    # Loop-invariant noise scale, computed once.
    scale = std / 12
    # A set gives O(1) membership tests; a list here makes the scan O(n^2).
    seen = set()
    for idx, num in enumerate(array):
        if num in seen:
            # Only the original (un-jittered) values are tracked in `seen`,
            # so a jittered value is never itself treated as a "first" hit.
            array[idx] = num + np.random.normal(0, scale)
        else:
            seen.add(num)

    return array
708+
709+
696710def logp (point , out_vars , vars , shared ): # pylint: disable=redefined-builtin
697711 """Compile PyTensor function of the model and the input and output variables.
698712
0 commit comments