
Commit 7cb5daf

stsievert authored and glemaitre committed
FEA Implements log-uniform random variable (scikit-learn#11232)
1 parent 4ca6ee4 commit 7cb5daf

File tree

5 files changed: +118 -20 lines changed


doc/modules/grid_search.rst

Lines changed: 17 additions & 0 deletions
@@ -121,6 +121,7 @@ discrete choices (which will be sampled uniformly) can be specified::
 This example uses the ``scipy.stats`` module, which contains many useful
 distributions for sampling parameters, such as ``expon``, ``gamma``,
 ``uniform`` or ``randint``.
+
 In principle, any function can be passed that provides a ``rvs`` (random
 variate sample) method to sample a value. A call to the ``rvs`` function should
 provide independent random samples from possible parameter values on
@@ -139,6 +140,22 @@ For continuous parameters, such as ``C`` above, it is important to specify
 a continuous distribution to take full advantage of the randomization. This way,
 increasing ``n_iter`` will always lead to a finer search.
 
+A continuous log-uniform random variable is available through
+:class:`~sklearn.utils.fixes.loguniform`. This is a continuous version of
+log-spaced parameters. For example, to specify ``C`` above, ``loguniform(1,
+100)`` can be used instead of ``[1, 10, 100]`` or ``np.logspace(0, 2,
+num=1000)``. This is an alias to SciPy's `stats.reciprocal
+<https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.
+
+Mirroring the example above in grid search, we can specify a continuous random
+variable that is log-uniformly distributed between ``1e0`` and ``1e3``::
+
+    from sklearn.utils.fixes import loguniform
+    {'C': loguniform(1e0, 1e3),
+     'gamma': loguniform(1e-4, 1e-3),
+     'kernel': ['rbf'],
+     'class_weight': ['balanced', None]}
+
 .. topic:: Examples:
 
     * :ref:`sphx_glr_auto_examples_model_selection_plot_randomized_search.py` compares the usage and efficiency
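For readers who want to try the documented distributions end to end, here is a minimal sketch of plugging them into ``RandomizedSearchCV``. The ``SVC`` estimator, the iris data, and ``n_iter=20`` are illustrative assumptions, not part of this commit:

    # Hedged sketch: the estimator, dataset and n_iter below are assumptions.
    from sklearn.datasets import load_iris
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.svm import SVC
    from sklearn.utils.fixes import loguniform

    X, y = load_iris(return_X_y=True)

    # Parameter distributions mirror the documentation snippet above.
    param_distributions = {'C': loguniform(1e0, 1e3),
                           'gamma': loguniform(1e-4, 1e-3),
                           'kernel': ['rbf'],
                           'class_weight': ['balanced', None]}

    search = RandomizedSearchCV(SVC(), param_distributions,
                                n_iter=20, random_state=0)
    search.fit(X, y)
    print(search.best_params_)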

doc/whats_new/v0.22.rst

Lines changed: 7 additions & 0 deletions
@@ -575,6 +575,13 @@ Changelog
   :func:`~utils.estimator_checks.parametrize_with_checks`, to parametrize
   estimator checks for a list of estimators. :pr:`14381` by `Thomas Fan`_.
 
+- A new random variable, :class:`utils.fixes.loguniform`, implements a
+  log-uniform random variable (e.g., for use in RandomizedSearchCV).
+  For example, the outcomes ``1``, ``10`` and ``100`` are all equally likely
+  for ``loguniform(1, 100)``. See :issue:`11232` by
+  :user:`Scott Sievert <stsievert>` and :user:`Nathaniel Saul <sauln>`,
+  and `SciPy PR 10815 <https://github.com/scipy/scipy/pull/10815>`_.
+
 - |API| The following utils have been deprecated and are now private:
   - ``choose_check_classifiers_labels``
   - ``enforce_estimator_tags_y``
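The changelog's claim that ``1``, ``10`` and ``100`` are equally likely under ``loguniform(1, 100)`` can be checked empirically. A small sketch, where the sample size and the expected fraction of roughly 0.5 are my own illustration rather than part of the commit:

    # Samples from loguniform(1, 100) are uniform on a log scale, so the
    # decades [1, 10) and [10, 100] should each receive about half the draws.
    import numpy as np
    from sklearn.utils.fixes import loguniform

    samples = loguniform(1, 100).rvs(size=10000, random_state=0)
    print(np.mean(samples < 10))  # expected to be close to 0.5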

examples/model_selection/plot_randomized_search.py

Lines changed: 17 additions & 20 deletions
@@ -12,8 +12,8 @@
 parameters. The result in parameter settings is quite similar, while the run
 time for randomized search is drastically lower.
 
-The performance is slightly worse for the randomized search, though this
-is most likely a noise effect and would not carry over to a held-out test set.
+The performance may be slightly worse for the randomized search; this is likely
+due to a noise effect and would not carry over to a held-out test set.
 
 Note that in practice, one would not search over this many different parameters
 simultaneously using grid search, but pick only the ones deemed most important.
@@ -23,18 +23,19 @@
 import numpy as np
 
 from time import time
-from scipy.stats import randint as sp_randint
+import scipy.stats as stats
+from sklearn.utils.fixes import loguniform
 
-from sklearn.model_selection import GridSearchCV
-from sklearn.model_selection import RandomizedSearchCV
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 from sklearn.datasets import load_digits
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import SGDClassifier
 
 # get some data
 X, y = load_digits(return_X_y=True)
 
 # build a classifier
-clf = RandomForestClassifier(n_estimators=20)
+clf = SGDClassifier(loss='hinge', penalty='elasticnet',
+                    fit_intercept=True)
 
 
 # Utility function to report best scores
@@ -43,19 +44,17 @@ def report(results, n_top=3):
         candidates = np.flatnonzero(results['rank_test_score'] == i)
         for candidate in candidates:
             print("Model with rank: {0}".format(i))
-            print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
-                  results['mean_test_score'][candidate],
-                  results['std_test_score'][candidate]))
+            print("Mean validation score: {0:.3f} (std: {1:.3f})"
+                  .format(results['mean_test_score'][candidate],
+                          results['std_test_score'][candidate]))
             print("Parameters: {0}".format(results['params'][candidate]))
             print("")
 
 
 # specify parameters and distributions to sample from
-param_dist = {"max_depth": [3, None],
-              "max_features": sp_randint(1, 11),
-              "min_samples_split": sp_randint(2, 11),
-              "bootstrap": [True, False],
-              "criterion": ["gini", "entropy"]}
+param_dist = {'average': [True, False],
+              'l1_ratio': stats.uniform(0, 1),
+              'alpha': loguniform(1e-4, 1e0)}
 
 # run randomized search
 n_iter_search = 20
@@ -69,11 +68,9 @@ def report(results, n_top=3):
 report(random_search.cv_results_)
 
 # use a full grid over all parameters
-param_grid = {"max_depth": [3, None],
-              "max_features": [1, 3, 10],
-              "min_samples_split": [2, 3, 10],
-              "bootstrap": [True, False],
-              "criterion": ["gini", "entropy"]}
+param_grid = {'average': [True, False],
+              'l1_ratio': np.linspace(0, 1, num=10),
+              'alpha': np.power(10, np.arange(-4, 1, dtype=float))}
 
 # run grid search
 grid_search = GridSearchCV(clf, param_grid=param_grid)
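To see how the new ``param_dist`` is consumed, here is a hedged reconstruction of the randomized-search half of the example; the unchanged lines around this diff are inferred from context and may differ slightly from the actual file:

    # Sketch of the randomized-search portion of the example after this change.
    from time import time

    import scipy.stats as stats
    from sklearn.datasets import load_digits
    from sklearn.linear_model import SGDClassifier
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.utils.fixes import loguniform

    X, y = load_digits(return_X_y=True)
    clf = SGDClassifier(loss='hinge', penalty='elasticnet', fit_intercept=True)

    param_dist = {'average': [True, False],
                  'l1_ratio': stats.uniform(0, 1),
                  'alpha': loguniform(1e-4, 1e0)}

    n_iter_search = 20
    random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                       n_iter=n_iter_search)
    start = time()
    random_search.fit(X, y)
    print("RandomizedSearchCV took %.2f seconds for %d candidates."
          % (time() - start, n_iter_search))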

sklearn/utils/fixes.py

Lines changed: 50 additions & 0 deletions
@@ -15,6 +15,7 @@
 import numpy as np
 import scipy.sparse as sp
 import scipy
+import scipy.stats
 from scipy.sparse.linalg import lsqr as sparse_lsqr  # noqa
 
 
@@ -256,3 +257,52 @@ def _joblib_parallel_args(**kwargs):
     if require == 'sharedmem':
         args['backend'] = 'threading'
     return args
+
+
+class loguniform(scipy.stats.reciprocal):
+    """A class supporting log-uniform random variables.
+
+    Parameters
+    ----------
+    low : float
+        The minimum value
+    high : float
+        The maximum value
+
+    Methods
+    -------
+    rvs(self, size=None, random_state=None)
+        Generate log-uniform random variables
+
+        The most useful method for Scikit-learn usage is highlighted here.
+        For a full list, see
+        `scipy.stats.reciprocal
+        <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.
+        This list includes all functions of ``scipy.stats`` continuous
+        distributions such as ``pdf``.
+
+    Notes
+    -----
+    This class generates values between ``low`` and ``high`` such that
+
+        low <= loguniform(low, high).rvs() <= high
+
+    The logarithmic probability density function (PDF) is uniform. When
+    ``x`` is a uniformly distributed random variable between 0 and 1, ``10**x``
+    are random variables that are equally likely to be returned.
+
+    This class is an alias to ``scipy.stats.reciprocal``, which uses the
+    reciprocal distribution:
+    https://en.wikipedia.org/wiki/Reciprocal_distribution
+
+    Examples
+    --------
+
+    >>> from sklearn.utils.fixes import loguniform
+    >>> rv = loguniform(1e-3, 1e1)
+    >>> rvs = rv.rvs(random_state=42, size=1000)
+    >>> rvs.min()  # doctest: +SKIP
+    0.0010435856341129003
+    >>> rvs.max()  # doctest: +SKIP
+    9.97403052786026
+    """

sklearn/utils/tests/test_fixes.py

Lines changed: 27 additions & 0 deletions
@@ -3,16 +3,19 @@
 # Lars Buitinck
 # License: BSD 3 clause
 
+import math
 import pickle
 
 import numpy as np
 import pytest
+import scipy.stats
 
 from sklearn.utils.testing import assert_array_equal
 
 from sklearn.utils.fixes import MaskedArray
 from sklearn.utils.fixes import _joblib_parallel_args
 from sklearn.utils.fixes import _object_dtype_isnan
+from sklearn.utils.fixes import loguniform
 
 
 def test_masked_array_obj_dtype_pickleable():
@@ -68,3 +71,27 @@ def test_object_dtype_isnan(dtype, val):
     mask = _object_dtype_isnan(X)
 
     assert_array_equal(mask, expected_mask)
+
+
+@pytest.mark.parametrize("low,high,base",
+                         [(-1, 0, 10), (0, 2, np.exp(1)), (-1, 1, 2)])
+def test_loguniform(low, high, base):
+    rv = loguniform(base ** low, base ** high)
+    assert isinstance(rv, scipy.stats._distn_infrastructure.rv_frozen)
+    rvs = rv.rvs(size=2000, random_state=0)
+
+    # Test the basics; right bounds, right size
+    assert (base ** low <= rvs).all() and (rvs <= base ** high).all()
+    assert len(rvs) == 2000
+
+    # Test that it's actually (fairly) uniform
+    log_rvs = np.array([math.log(x, base) for x in rvs])
+    counts, _ = np.histogram(log_rvs)
+    assert counts.mean() == 200
+    assert np.abs(counts - counts.mean()).max() <= 40
+
+    # Test that random_state works
+    assert (
+        loguniform(base ** low, base ** high).rvs(random_state=0)
+        == loguniform(base ** low, base ** high).rvs(random_state=0)
+    )
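The uniformity assertion relies on ``np.histogram`` defaulting to 10 bins, so 2000 log-scale samples give a mean count of exactly 200 per bin, and the ``<= 40`` bound allows roughly a 20% deviation in each bin. To run only this test from a source checkout (the path is taken from the file header above), one option is pytest's Python API:

    # Invoke just the new loguniform test; assumes a scikit-learn source tree.
    import pytest
    pytest.main(["sklearn/utils/tests/test_fixes.py", "-k", "loguniform"])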
