Skip to content

Commit 9fcc1a8

Browse files
authored
remove new split criteria from plots & change to joblib
1 parent 253d03b commit 9fcc1a8

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

examples/ensemble/plot_random_forest_regression_criteria_comparison.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
===============================================================================
33
Comparing different split criteria for random forest regression on toy datasets
44
===============================================================================
5-
65
An example to compare the different split criteria available for
76
:class:`sklearn.ensemble.RandomForestRegressor`.
8-
97
Metrics used to evaluate these split criteria include Mean Squared Error (MSE), a
108
measure of distance between the true target (`y_true`) and the predicted output
119
(`y_pred`), and runtime.
12-
1310
For visual examples of these datasets, see
1411
:ref:`sphx_glr_auto_examples_datasets_plot_nonlinear_regression_datasets.py`.
1512
"""
@@ -19,7 +16,7 @@
1916

2017
import time
2118
from itertools import product
22-
from multiprocessing import Pool
19+
from joblib import Parallel, delayed
2320

2421
import matplotlib.pyplot as plt
2522
import numpy as np
@@ -65,7 +62,6 @@ def _test_forest(X, y, regr):
6562
###############################################################################
6663
def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
6764
"""Measure the performance of RandomForest under simulation conditions.
68-
6965
Parameters
7066
----------
7167
simulation_name : str
@@ -74,12 +70,11 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
7470
Number of training samples.
7571
criterion : str
7672
Split criterion used to train forest. Choose from
77-
("mse", "mae", "friedman_mse", "axis", "oblique").
73+
("mse", "mae", "friedman_mse").
7874
n_dimensions : int
7975
Number of features and targets to sample.
8076
n_iter : int
8177
Index of the current repeat for this simulation setting. Ignored.
82-
8378
Returns
8479
-------
8580
simulation_name : str
@@ -96,7 +91,7 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
9691
runtime : float
9792
Runtime (in seconds).
9893
"""
99-
print(simulation_name, n_samples)
94+
print(simulation_name, n_samples, criterion, n_dimensions, n_iter)
10095

10196
# Get simulation parameters and validation dataset
10297
sim, noise, (X_test, y_test) = simulations[simulation_name]
@@ -133,7 +128,7 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
133128
n_dimensions = 10
134129
simulation_names = simulations.keys()
135130
sample_sizes = np.arange(5, 51, 3)
136-
criteria = ["mae", "mse", "friedman_mse", "axis", "oblique"]
131+
criteria = ["mae", "mse", "friedman_mse"]
137132

138133
# Number of times to repeat each simulation setting
139134
n_repeats = 10
@@ -161,22 +156,20 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
161156
###############################################################################
162157
print("Running simulations...")
163158

164-
with Pool() as pool:
165-
166-
# Run the simulations in parallel
167-
data = pool.starmap(main, params)
168-
169-
# Save results as a DataFrame
170-
columns = ["simulation", "n_samples", "criterion",
171-
"n_dimensions", "mse", "runtime"]
172-
df = pd.DataFrame(data, columns=columns)
173-
174-
# Plot the results
175-
sns.relplot(x="n_samples",
176-
y="mse",
177-
hue="criterion",
178-
col="simulation",
179-
kind="line",
180-
data=df,
181-
facet_kws={'sharey': False, 'sharex': True})
182-
plt.show()
159+
# Run the simulations in parallel
160+
data = Parallel(n_jobs=4)(delayed(main)(sim, n, crit, n_dim, n_iter) for sim, n, crit, n_dim, n_iter in params)
161+
162+
# Save results as a DataFrame
163+
columns = ["simulation", "n_samples", "criterion",
164+
"n_dimensions", "mse", "runtime"]
165+
df = pd.DataFrame(data, columns=columns)
166+
167+
# Plot the results
168+
sns.relplot(x="n_samples",
169+
y="mse",
170+
hue="criterion",
171+
col="simulation",
172+
kind="line",
173+
data=df,
174+
facet_kws={'sharey': False, 'sharex': True})
175+
plt.show()

0 commit comments

Comments
 (0)