Skip to content

Commit db1cbb8

Browse files
author
morgsmss7
committed
Generate data once
1 parent 9fcc1a8 commit db1cbb8

File tree

1 file changed

+47
-21
lines changed

1 file changed

+47
-21
lines changed

examples/ensemble/plot_random_forest_regression_criteria_comparison.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,43 @@ def _test_forest(X, y, regr):
5858
y_pred = regr.predict(X)
5959
return mean_squared_error(y, y_pred)
6060

61+
def _prep_data(sim_dict, simulation_name, max_n_samples, n_dimensions):
    """Generate training data for one simulation and store it in `sim_dict`.

    Parameters
    ----------
    sim_dict : dict
        Dictionary that is updated in place with an entry for
        `simulation_name`.
    simulation_name : str
        Key from `simulations` dictionary.
    max_n_samples : int
        Number of training samples to generate (converted with ``int``).
    n_dimensions : int
        Number of dimensions passed to the simulation (converted with
        ``int``).

    Returns
    -------
    sim_dict : dict
        The same dictionary, where ``sim_dict[simulation_name]`` is the
        tuple ``(X_train, y_train, X_test, y_test)``.
    """
    # Look up the generator, its noise setting and the validation dataset
    generate, noise_level, test_set = simulations[simulation_name]
    X_test, y_test = test_set

    # Build the call once; `noise` is only forwarded when the simulation
    # defines one
    sim_kwargs = {
        "n_samples": int(max_n_samples),
        "n_dimensions": int(n_dimensions),
    }
    if noise_level is not None:
        sim_kwargs["noise"] = noise_level
    sim_kwargs["random_state"] = random_state

    X_train, y_train = generate(**sim_kwargs)
    sim_dict[simulation_name] = (X_train, y_train, X_test, y_test)
    return sim_dict
6180

6281
###############################################################################
63-
def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
82+
def main(simulation_name, sim_data, n_samples, criterion, n_dimensions, n_iter):
6483
"""Measure the performance of RandomForest under simulation conditions.
6584
Parameters
6685
----------
6786
simulation_name : str
6887
Key from `simulations` dictionary.
88+
sim_data: dict
89+
Contains X_train, y_train, X_test, and y_test for each simulation_name
90+
X_train : np.array #TODO check this
91+
All X training data for given simulation
92+
y_train : np.array # TODO
93+
All y training data for given simulation
94+
X_test : np.array #TODO check this
95+
All X testing data for given simulation
96+
y_test : np.array # TODO
97+
All y testing data for given simulation
6998
n_samples : int
7099
Number of training samples.
71100
criterion : string
@@ -93,25 +122,16 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
93122
"""
94123
print(simulation_name, n_samples, criterion, n_dimensions, n_iter)
95124

96-
# Get simulation parameters and validation dataset
97-
sim, noise, (X_test, y_test) = simulations[simulation_name]
98-
n_samples = int(n_samples)
99-
n_dimensions = int(n_dimensions)
100-
101-
# Sample training data
102-
if noise is not None:
103-
X_train, y_train = sim(n_samples=n_samples,
104-
n_dimensions=n_dimensions,
105-
noise=noise,
106-
random_state=random_state)
107-
else:
108-
X_train, y_train = sim(n_samples=n_samples,
109-
n_dimensions=n_dimensions,
110-
random_state=random_state)
125+
# Unpack training and testing data
126+
X_train, y_train, X_test, y_test = sim_data
111127

128+
# Get subset of training data
129+
curr_X_train = X_train[0:n_samples]
130+
curr_y_train = y_train[0:n_samples]
131+
112132
# Train forest
113133
start = time.time()
114-
regr = _train_forest(X_train, y_train, criterion)
134+
regr = _train_forest(curr_X_train, curr_y_train, criterion)
115135
stop = time.time()
116136

117137
# Evaluate on testing data and record runtime
@@ -131,11 +151,10 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
131151
criteria = ["mae", "mse", "friedman_mse"]
132152

133153
# Number of times to repeat each simulation setting
134-
n_repeats = 10
154+
n_repeats = 30
135155

136156
# Create the parameter space
137-
params = product(simulation_names, sample_sizes, criteria,
138-
[n_dimensions], range(n_repeats))
157+
params = product(simulation_names, sample_sizes, criteria, range(n_repeats))
139158

140159

141160
###############################################################################
@@ -156,8 +175,15 @@ def main(simulation_name, n_samples, criterion, n_dimensions, n_iter):
156175
###############################################################################
157176
print("Running simulations...")
158177

178+
# Generate training and test data for simulations
179+
sim_data = {}
180+
for sim in simulation_names:
181+
sim_data = _prep_data(sim_data, sim, sample_sizes[-1], n_dimensions)
182+
159183
# Run the simulations in parallel
160-
data = Parallel(n_jobs=4)(delayed(main)(sim, n, crit, n_dim, n_iter) for sim, n, crit, n_dim, n_iter in params)
184+
# Each job must receive the dataset of its own simulation, so index
# `sim_data` with the loop variable `sim` (there is no `simulation_name`
# bound in this scope). n_jobs=-2 asks joblib for all CPUs but one.
data = Parallel(n_jobs=-2)(
    delayed(main)(sim, sim_data[sim], n, crit, n_dimensions, n_iter)
    for sim, n, crit, n_iter in params
)
161187

162188
# Save results as a DataFrame
163189
columns = ["simulation", "n_samples", "criterion",

0 commit comments

Comments
 (0)