@@ -157,7 +157,7 @@ def plot_ice(
157157 bartrv : Variable ,
158158 X : npt .NDArray [np .float_ ],
159159 Y : Optional [npt .NDArray [np .float_ ]] = None ,
160- xs_interval : str = "linear " ,
160+ xs_interval : str = "quantiles " ,
161161 xs_values : Optional [Union [int , List [float ]]] = None ,
162162 var_idx : Optional [List [int ]] = None ,
163163 var_discrete : Optional [List [int ]] = None ,
@@ -303,7 +303,7 @@ def identity(x):
303303 idx = np .argsort (new_x )
304304 axes [count ].plot (new_x [idx ], p_di .mean (0 )[idx ], color = color_mean )
305305 axes [count ].plot (new_x [idx ], p_di .T [idx ], color = color , alpha = alpha )
306- axes [count ].set_xlabel (x_labels [var ])
306+ axes [count ].set_xlabel (x_labels [var ])
307307
308308 count += 1
309309
@@ -316,7 +316,7 @@ def plot_pdp(
316316 bartrv : Variable ,
317317 X : npt .NDArray [np .float_ ],
318318 Y : Optional [npt .NDArray [np .float_ ]] = None ,
319- xs_interval : str = "linear " ,
319+ xs_interval : str = "quantiles " ,
320320 xs_values : Optional [Union [int , List [float ]]] = None ,
321321 var_idx : Optional [List [int ]] = None ,
322322 var_discrete : Optional [List [int ]] = None ,
@@ -423,35 +423,39 @@ def identity(x):
423423 p_d = _sample_posterior (
424424 all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
425425 )
426- new_x = fake_X [:, var ]
427- for s_i in range (shape ):
428- p_di = func (p_d [:, :, s_i ])
429- if var in var_discrete :
430- y_means = p_di .mean (0 )
431- hdi = az .hdi (p_di )
432- axes [count ].errorbar (
433- new_x ,
434- y_means ,
435- (y_means - hdi [:, 0 ], hdi [:, 1 ] - y_means ),
436- fmt = "." ,
437- color = color ,
438- )
439- else :
440- az .plot_hdi (
441- new_x ,
442- p_di ,
443- smooth = smooth ,
444- fill_kwargs = {"alpha" : alpha , "color" : color },
445- ax = axes [count ],
446- )
447- if smooth :
448- x_data , y_data = _smooth_mean (new_x , p_di , "pdp" , smooth_kwargs )
449- axes [count ].plot (x_data , y_data , color = color_mean )
426+ with warnings .catch_warnings ():
427+ warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
428+ new_x = fake_X [:, var ]
429+ for s_i in range (shape ):
430+ p_di = func (p_d [:, :, s_i ])
431+ if var in var_discrete :
432+ _ , idx_uni = np .unique (new_x , return_index = True )
433+ y_means = p_di .mean (0 )[idx_uni ]
434+ hdi = az .hdi (p_di )[idx_uni ]
435+ axes [count ].errorbar (
436+ new_x [idx_uni ],
437+ y_means ,
438+ (y_means - hdi [:, 0 ], hdi [:, 1 ] - y_means ),
439+ fmt = "." ,
440+ color = color ,
441+ )
442+ axes [count ].set_xticks (new_x [idx_uni ])
450443 else :
451- axes [count ].plot (new_x , p_di .mean (0 ), color = color_mean )
444+ az .plot_hdi (
445+ new_x ,
446+ p_di ,
447+ smooth = smooth ,
448+ fill_kwargs = {"alpha" : alpha , "color" : color },
449+ ax = axes [count ],
450+ )
451+ if smooth :
452+ x_data , y_data = _smooth_mean (new_x , p_di , "pdp" , smooth_kwargs )
453+ axes [count ].plot (x_data , y_data , color = color_mean )
454+ else :
455+ axes [count ].plot (new_x , p_di .mean (0 ), color = color_mean )
452456 axes [count ].set_xlabel (x_labels [var ])
453457
454- count += 1
458+ count += 1
455459
456460 fig .text (- 0.05 , 0.5 , y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
457461
@@ -527,16 +531,20 @@ def _get_axes(
527531 fig .delaxes (axes [i ])
528532 axes = axes [:n_plots ]
529533 else :
530- axes = [ax ]
531- fig = ax .get_figure ()
534+ if isinstance (ax , np .ndarray ):
535+ axes = ax
536+ fig = ax [0 ].get_figure ()
537+ else :
538+ axes = [ax ]
539+ fig = ax .get_figure () # type: ignore
532540
533541 return fig , axes , shape
534542
535543
536544def _prepare_plot_data (
537545 X : npt .NDArray [np .float_ ],
538546 Y : Optional [npt .NDArray [np .float_ ]] = None ,
539- xs_interval : str = "linear " ,
547+ xs_interval : str = "quantiles " ,
540548 xs_values : Optional [Union [int , List [float ]]] = None ,
541549 var_idx : Optional [List [int ]] = None ,
542550 var_discrete : Optional [List [int ]] = None ,
@@ -710,7 +718,7 @@ def plot_variable_importance(
710718 figsize : Optional [Tuple [float , float ]] = None ,
711719 samples : int = 100 ,
712720 random_seed : Optional [int ] = None ,
713- ) -> Tuple [npt .NDArray [np .int_ ], List [plt .axes ]]:
721+ ) -> Tuple [npt .NDArray [np .int_ ], List [plt .Axes ]]:
714722 """
715723 Estimates variable importance from the BART-posterior.
716724
0 commit comments