MAINT Improve source generation using Tempita (scikit-learn#20481)
Co-authored-by: Joan Massich <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
5 people authored Aug 9, 2021
1 parent 3c732b9 commit 4b8cd88
Showing 14 changed files with 222 additions and 121 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -78,4 +78,6 @@ _configtest.o.d
# files generated from a template
sklearn/utils/_seq_dataset.pyx
sklearn/utils/_seq_dataset.pxd
+sklearn/utils/_weight_vector.pyx
+sklearn/utils/_weight_vector.pxd
sklearn/linear_model/_sag_fast.pyx
4 changes: 3 additions & 1 deletion setup.cfg
@@ -67,5 +67,7 @@ allow_redefinition = True
# ignore files missing in VCS
ignore =
sklearn/linear_model/_sag_fast.pyx
-sklearn/utils/_seq_dataset.pxd
sklearn/utils/_seq_dataset.pyx
+sklearn/utils/_seq_dataset.pxd
+sklearn/utils/_weight_vector.pyx
+sklearn/utils/_weight_vector.pxd
2 changes: 1 addition & 1 deletion sklearn/_build_utils/__init__.py
@@ -80,7 +80,7 @@ def cythonize_extensions(top_path, config):
)


-def gen_from_templates(templates, top_path):
+def gen_from_templates(templates):
"""Generate cython files from a list of templates"""
# Lazy import because cython is not a runtime dependency.
from Cython import Tempita
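For reference, here is a minimal sketch of what the simplified `gen_from_templates` plausibly looks like after this change; the body is truncated in the diff above, so treat this as an approximation rather than the verbatim function. It substitutes each `.tp` template with `Cython.Tempita` and writes the generated file next to it, skipping templates that are older than their output:

```python
import os


def gen_from_templates(templates):
    """Generate cython files from a list of templates."""
    # Lazy import because Cython is not a runtime dependency.
    from Cython import Tempita

    for template in templates:
        outfile = template.replace(".tp", "")
        # Regenerate only when the template is newer than the output.
        if not (os.path.exists(outfile)
                and os.stat(template).st_mtime < os.stat(outfile).st_mtime):
            with open(template, "r") as f:
                content = f.read()
            with open(outfile, "w") as f:
                f.write(Tempita.sub(content))
```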
78 changes: 37 additions & 41 deletions sklearn/linear_model/_sag_fast.pyx.tp
@@ -19,14 +19,10 @@ Authors: Danny Sullivan <[email protected]>
License: BSD 3 clause
"""

-# name, c_type
+# name_suffix, c_type, np_type
dtypes = [('64', 'double', 'np.float64'),
('32', 'float', 'np.float32')]

-def get_dispatch(dtypes):
-    for name, c_type, np_type in dtypes:
-        yield name, c_type, np_type

}}

#------------------------------------------------------------------------------
@@ -61,27 +57,27 @@ from libc.stdio cimport printf
np.import_array()


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

cdef extern from "_sgd_fast_helpers.h":
-bint skl_isfinite{{name}}({{c_type}}) nogil
+bint skl_isfinite{{name_suffix}}({{c_type}}) nogil


{{endfor}}

-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef inline {{c_type}} fmax{{name}}({{c_type}} x, {{c_type}} y) nogil:
+cdef inline {{c_type}} fmax{{name_suffix}}({{c_type}} x, {{c_type}} y) nogil:
if x > y:
return x
return y

{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef {{c_type}} _logsumexp{{name}}({{c_type}}* arr, int n_classes) nogil:
+cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) nogil:
"""Computes the sum of arr assuming arr is in the log domain.

Returns log(sum(exp(arr))) while minimizing the possibility of
@@ -105,9 +101,9 @@ cdef {{c_type}} _logsumexp{{name}}({{c_type}}* arr, int n_classes) nogil:
{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef class MultinomialLogLoss{{name}}:
+cdef class MultinomialLogLoss{{name_suffix}}:
cdef {{c_type}} _loss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes,
{{c_type}} sample_weight) nogil:
r"""Multinomial Logistic regression loss.
@@ -145,7 +141,7 @@ cdef class MultinomialLogLoss{{name}}:
Bishop, C. M. (2006). Pattern recognition and machine learning.
Springer. (Chapter 4.3.4)
"""
-cdef {{c_type}} logsumexp_prediction = _logsumexp{{name}}(prediction, n_classes)
+cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes)
cdef {{c_type}} loss

# y is the index of the correct class of the current sample.
@@ -191,7 +187,7 @@ cdef class MultinomialLogLoss{{name}}:
Bishop, C. M. (2006). Pattern recognition and machine learning.
Springer. (Chapter 4.3.4)
"""
-cdef {{c_type}} logsumexp_prediction = _logsumexp{{name}}(prediction, n_classes)
+cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes)
cdef int class_ind

for class_ind in range(n_classes):
@@ -205,21 +201,21 @@ cdef class MultinomialLogLoss{{name}}:
gradient_ptr[class_ind] *= sample_weight

def __reduce__(self):
-return MultinomialLogLoss{{name}}, ()
+return MultinomialLogLoss{{name_suffix}}, ()

{{endfor}}

-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef inline {{c_type}} _soft_thresholding{{name}}({{c_type}} x, {{c_type}} shrinkage) nogil:
-    return fmax{{name}}(x - shrinkage, 0) - fmax{{name}}(- x - shrinkage, 0)
+cdef inline {{c_type}} _soft_thresholding{{name_suffix}}({{c_type}} x, {{c_type}} shrinkage) nogil:
+    return fmax{{name_suffix}}(x - shrinkage, 0) - fmax{{name_suffix}}(- x - shrinkage, 0)

{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-def sag{{name}}(SequentialDataset{{name}} dataset,
+def sag{{name_suffix}}(SequentialDataset{{name_suffix}} dataset,
np.ndarray[{{c_type}}, ndim=2, mode='c'] weights_array,
np.ndarray[{{c_type}}, ndim=1, mode='c'] intercept_array,
int n_samples,
@@ -360,11 +356,11 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
# Whether the loss function is multinomial
cdef bint multinomial = False
# Multinomial loss function
-cdef MultinomialLogLoss{{name}} multiloss
+cdef MultinomialLogLoss{{name_suffix}} multiloss

if loss_function == "multinomial":
multinomial = True
-multiloss = MultinomialLogLoss{{name}}()
+multiloss = MultinomialLogLoss{{name_suffix}}()
elif loss_function == "log":
loss = Log()
elif loss_function == "squared":
@@ -399,7 +395,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

# make the weight updates
if sample_itr > 0:
-status = lagged_update{{name}}(weights, wscale, xnnz,
+status = lagged_update{{name_suffix}}(weights, wscale, xnnz,
n_samples, n_classes,
sample_itr,
cumulative_sums,
Expand All @@ -414,7 +410,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
break

# find the current prediction
-predict_sample{{name}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
+predict_sample{{name_suffix}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,
intercept, prediction, n_classes)

# compute the gradient for this sample, given the prediction
@@ -459,7 +455,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
num_seen * intercept_decay)

# check to see that the intercept is not inf or NaN
-if not skl_isfinite{{name}}(intercept[class_ind]):
+if not skl_isfinite{{name_suffix}}(intercept[class_ind]):
status = -1
break
# Break from the n_samples outer loop if an error happened
@@ -488,7 +484,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
if verbose:
with gil:
print("rescaling...")
-status = scale_weights{{name}}(
+status = scale_weights{{name_suffix}}(
weights, &wscale, n_features, n_samples, n_classes,
sample_itr, cumulative_sums,
cumulative_sums_prox,
@@ -504,7 +500,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset,

# we scale the weights every n_samples iterations and reset the
# just-in-time update system for numerical stability.
-status = scale_weights{{name}}(weights, &wscale, n_features,
+status = scale_weights{{name_suffix}}(weights, &wscale, n_features,
n_samples,
n_classes, n_samples - 1,
cumulative_sums,
Expand All @@ -518,8 +514,8 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
max_change = 0.0
max_weight = 0.0
for idx in range(n_features * n_classes):
-max_weight = fmax{{name}}(max_weight, fabs(weights[idx]))
-max_change = fmax{{name}}(max_change,
+max_weight = fmax{{name_suffix}}(max_weight, fabs(weights[idx]))
+max_change = fmax{{name_suffix}}(max_change,
fabs(weights[idx] -
previous_weights[idx]))
previous_weights[idx] = weights[idx]
@@ -553,9 +549,9 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
+cdef int scale_weights{{name_suffix}}({{c_type}}* weights, {{c_type}}* wscale,
int n_features,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
@@ -574,7 +570,7 @@ cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
"""

cdef int status
-status = lagged_update{{name}}(weights, wscale[0], n_features,
+status = lagged_update{{name_suffix}}(weights, wscale[0], n_features,
n_samples, n_classes, sample_itr + 1,
cumulative_sums,
cumulative_sums_prox,
Expand All @@ -592,9 +588,9 @@ cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
+cdef int lagged_update{{name_suffix}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
int n_samples, int n_classes, int sample_itr,
{{c_type}}* cumulative_sums,
{{c_type}}* cumulative_sums_prox,
@@ -630,7 +626,7 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
weights[idx] -= cum_sum * sum_gradient[idx]
if reset:
weights[idx] *= wscale
-if not skl_isfinite{{name}}(weights[idx]):
+if not skl_isfinite{{name_suffix}}(weights[idx]):
# returning here does not require the gil as the return
# type is a C integer
return -1
@@ -643,7 +639,7 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
# efficient than unrolling all the lagged updates.
# Idea taken from scikit-learn-contrib/lightning.
weights[idx] -= cum_sum * sum_gradient[idx]
-weights[idx] = _soft_thresholding{{name}}(weights[idx],
+weights[idx] = _soft_thresholding{{name_suffix}}(weights[idx],
cum_sum_prox)
else:
last_update_ind = feature_hist[feature_ind]
@@ -660,13 +656,13 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
grad_step = cumulative_sums[lagged_ind]
prox_step = cumulative_sums_prox[lagged_ind]
weights[idx] -= sum_gradient[idx] * grad_step
-weights[idx] = _soft_thresholding{{name}}(weights[idx],
+weights[idx] = _soft_thresholding{{name_suffix}}(weights[idx],
prox_step)

if reset:
weights[idx] *= wscale
# check to see that the weight is not inf or NaN
-if not skl_isfinite{{name}}(weights[idx]):
+if not skl_isfinite{{name_suffix}}(weights[idx]):
return -1
if reset:
feature_hist[feature_ind] = sample_itr % n_samples
@@ -683,9 +679,9 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
{{endfor}}


-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

-cdef void predict_sample{{name}}({{c_type}}* x_data_ptr, int* x_ind_ptr, int xnnz,
+cdef void predict_sample{{name_suffix}}({{c_type}}* x_data_ptr, int* x_ind_ptr, int xnnz,
{{c_type}}* w_data_ptr, {{c_type}} wscale,
{{c_type}}* intercept, {{c_type}}* prediction,
int n_classes) nogil:
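The recurring edit in this template is the removal of the `get_dispatch` indirection — Tempita can iterate over the `dtypes` list directly — together with renaming the loop variable `name` to the more descriptive `name_suffix`. Below is a self-contained sketch of how Tempita expands such a loop; the template string is a hypothetical reduction of `_sag_fast.pyx.tp`, not its actual contents:

```python
from Cython import Tempita

template = """\
{{py:
# name_suffix, c_type
dtypes = [('64', 'double'),
          ('32', 'float')]
}}
{{for name_suffix, c_type in dtypes}}
cdef inline {{c_type}} fmax{{name_suffix}}({{c_type}} x, {{c_type}} y) nogil:
    return x if x > y else y
{{endfor}}
"""

# Emits one definition per dtype: fmax64(double, ...) and fmax32(float, ...).
print(Tempita.sub(template))
```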
2 changes: 1 addition & 1 deletion sklearn/linear_model/_sgd_fast.pyx
@@ -21,7 +21,7 @@ from numpy.math cimport INFINITY
cdef extern from "_sgd_fast_helpers.h":
bint skl_isfinite(double) nogil

-from ..utils._weight_vector cimport WeightVector
+from ..utils._weight_vector cimport WeightVector64 as WeightVector
from ..utils._seq_dataset cimport SequentialDataset64 as SequentialDataset

np.import_array()
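`_sgd_fast` itself still only implements the double-precision path, so the now-suffixed 64-bit class is aliased back to its former name at `cimport` time and the rest of the module compiles unchanged. The same idea at the Python level, for illustration only (these are private scikit-learn modules, so this is an assumption about the generated names, not supported API):

```python
# Bind the generated 64-bit specialization to the old, unsuffixed name.
from sklearn.utils._seq_dataset import SequentialDataset64 as SequentialDataset
```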
2 changes: 1 addition & 1 deletion sklearn/linear_model/setup.py
@@ -29,7 +29,7 @@ def configuration(parent_package="", top_path=None):

# generate sag_fast from template
templates = ["sklearn/linear_model/_sag_fast.pyx.tp"]
-gen_from_templates(templates, top_path)
+gen_from_templates(templates)

config.add_extension(
"_sag_fast", sources=["_sag_fast.pyx"], include_dirs=numpy.get_include()
14 changes: 5 additions & 9 deletions sklearn/utils/_seq_dataset.pxd.tp
@@ -12,16 +12,12 @@ Each class is duplicated for all dtypes (float and double). The keywords
between double braces are substituted in setup.py.
"""

-# name, c_type
+# name_suffix, c_type
dtypes = [('64', 'double'),
('32', 'float')]

-def get_dispatch(dtypes):
-    for name, c_type in dtypes:
-        yield name, c_type

}}
-{{for name, c_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type in dtypes}}

#------------------------------------------------------------------------------

@@ -36,7 +32,7 @@ cimport numpy as np
# iterators over the rows of a matrix X and corresponding target values y.


-cdef class SequentialDataset{{name}}:
+cdef class SequentialDataset{{name_suffix}}:
cdef int current_index
cdef np.ndarray index
cdef int *index_data_ptr
@@ -56,7 +52,7 @@ cdef class SequentialDataset{{name}}:
int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) nogil


-cdef class ArrayDataset{{name}}(SequentialDataset{{name}}):
+cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
cdef np.ndarray X
cdef np.ndarray Y
cdef np.ndarray sample_weights
@@ -69,7 +65,7 @@ cdef class ArrayDataset{{name}}(SequentialDataset{{name}}):
cdef {{c_type}} *sample_weight_data


-cdef class CSRDataset{{name}}(SequentialDataset{{name}}):
+cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
cdef np.ndarray X_data
cdef np.ndarray X_indptr
cdef np.ndarray X_indices
14 changes: 5 additions & 9 deletions sklearn/utils/_seq_dataset.pyx.tp
@@ -20,16 +20,12 @@ Author: Peter Prettenhofer <[email protected]>
License: BSD 3 clause
"""

-# name, c_type, np_type
+# name_suffix, c_type, np_type
dtypes = [('64', 'double', 'np.float64'),
('32', 'float', 'np.float32')]

-def get_dispatch(dtypes):
-    for name, c_type, np_type in dtypes:
-        yield name, c_type, np_type

}}
-{{for name, c_type, np_type in get_dispatch(dtypes)}}
+{{for name_suffix, c_type, np_type in dtypes}}

#------------------------------------------------------------------------------

@@ -47,7 +43,7 @@ np.import_array()

from ._random cimport our_rand_r

-cdef class SequentialDataset{{name}}:
+cdef class SequentialDataset{{name_suffix}}:
"""Base class for datasets with sequential data access.

SequentialDataset is used to iterate over the rows of a matrix X and
@@ -219,7 +215,7 @@ cdef class SequentialDataset{{name}}:
return (x_data, x_indices, x_indptr), y, sample_weight, sample_idx


-cdef class ArrayDataset{{name}}(SequentialDataset{{name}}):
+cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
"""Dataset backed by a two-dimensional numpy array.

The dtype of the numpy array is expected to be ``{{np_type}}`` ({{c_type}})
@@ -288,7 +284,7 @@ cdef class ArrayDataset{{name}}(SequentialDataset{{name}}):
sample_weight[0] = self.sample_weight_data[sample_idx]


-cdef class CSRDataset{{name}}(SequentialDataset{{name}}):
+cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}):
"""A ``SequentialDataset`` backed by a scipy sparse CSR matrix. """

def __cinit__(self, np.ndarray[{{c_type}}, ndim=1, mode='c'] X_data,
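Once this template is rendered and compiled, it yields one concrete dataset class per dtype (`ArrayDataset64`/`ArrayDataset32`, and likewise for the CSR variants). A hedged usage sketch — these classes live in a private module and the constructor signature is inferred from the `__cinit__` declarations above, so treat it as illustrative:

```python
import numpy as np
from sklearn.utils._seq_dataset import ArrayDataset32, ArrayDataset64

X = np.ascontiguousarray(np.random.rand(5, 3))  # float64, C-contiguous
y = np.ascontiguousarray(np.random.rand(5))
sample_weight = np.ones(5)

# The 64-bit variant expects float64 inputs...
ds64 = ArrayDataset64(X, y, sample_weight, seed=42)

# ...while the 32-bit variant expects float32.
ds32 = ArrayDataset32(X.astype(np.float32), y.astype(np.float32),
                      sample_weight.astype(np.float32), seed=42)

# Per the _next_py shown above, each item unpacks as
# (x_data, x_indices, x_indptr), y, sample_weight, sample_idx.
xi, yi, swi, idx = ds64._next_py()
```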
[Diff for the remaining 6 changed files not shown.]
