From 4b8cd880397f279200b8faf9c75df13801cb45b7 Mon Sep 17 00:00:00 2001 From: Mathis Batoul Date: Mon, 9 Aug 2021 21:43:56 +0200 Subject: [PATCH] MAINT Improve source generation using `Tempita` (#20481) Co-authored-by: Joan Massich Co-authored-by: Julien Jerphanion Co-authored-by: Olivier Grisel Co-authored-by: Guillaume Lemaitre --- .gitignore | 2 + setup.cfg | 4 +- sklearn/_build_utils/__init__.py | 2 +- sklearn/linear_model/_sag_fast.pyx.tp | 78 +++++++------ sklearn/linear_model/_sgd_fast.pyx | 2 +- sklearn/linear_model/setup.py | 2 +- sklearn/utils/_seq_dataset.pxd.tp | 14 +-- sklearn/utils/_seq_dataset.pyx.tp | 14 +-- sklearn/utils/_weight_vector.pxd | 20 ---- sklearn/utils/_weight_vector.pxd.tp | 50 +++++++++ ...eight_vector.pyx => _weight_vector.pyx.tp} | 104 ++++++++++++------ sklearn/utils/setup.py | 7 +- sklearn/utils/tests/test_cython_templating.py | 20 ++++ sklearn/utils/tests/test_weight_vector.py | 24 ++++ 14 files changed, 222 insertions(+), 121 deletions(-) delete mode 100644 sklearn/utils/_weight_vector.pxd create mode 100644 sklearn/utils/_weight_vector.pxd.tp rename sklearn/utils/{_weight_vector.pyx => _weight_vector.pyx.tp} (67%) create mode 100644 sklearn/utils/tests/test_cython_templating.py create mode 100644 sklearn/utils/tests/test_weight_vector.py diff --git a/.gitignore b/.gitignore index 3ebd8e2bb1699..2c3dd0c4794c1 100644 --- a/.gitignore +++ b/.gitignore @@ -78,4 +78,6 @@ _configtest.o.d # files generated from a template sklearn/utils/_seq_dataset.pyx sklearn/utils/_seq_dataset.pxd +sklearn/utils/_weight_vector.pyx +sklearn/utils/_weight_vector.pxd sklearn/linear_model/_sag_fast.pyx diff --git a/setup.cfg b/setup.cfg index 94ec0190337dc..050045072f428 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,5 +67,7 @@ allow_redefinition = True # ignore files missing in VCS ignore = sklearn/linear_model/_sag_fast.pyx - sklearn/utils/_seq_dataset.pxd sklearn/utils/_seq_dataset.pyx + sklearn/utils/_seq_dataset.pxd + sklearn/utils/_weight_vector.pyx 
+ sklearn/utils/_weight_vector.pxd diff --git a/sklearn/_build_utils/__init__.py b/sklearn/_build_utils/__init__.py index 670297dab3d22..be83b4c4d8baf 100644 --- a/sklearn/_build_utils/__init__.py +++ b/sklearn/_build_utils/__init__.py @@ -80,7 +80,7 @@ def cythonize_extensions(top_path, config): ) -def gen_from_templates(templates, top_path): +def gen_from_templates(templates): """Generate cython files from a list of templates""" # Lazy import because cython is not a runtime dependency. from Cython import Tempita diff --git a/sklearn/linear_model/_sag_fast.pyx.tp b/sklearn/linear_model/_sag_fast.pyx.tp index 8508340e3b329..6ca141fe99305 100644 --- a/sklearn/linear_model/_sag_fast.pyx.tp +++ b/sklearn/linear_model/_sag_fast.pyx.tp @@ -19,14 +19,10 @@ Authors: Danny Sullivan License: BSD 3 clause """ -# name, c_type +# name_suffix, c_type, np_type dtypes = [('64', 'double', 'np.float64'), ('32', 'float', 'np.float32')] -def get_dispatch(dtypes): - for name, c_type, np_type in dtypes: - yield name, c_type, np_type - }} #------------------------------------------------------------------------------ @@ -61,17 +57,17 @@ from libc.stdio cimport printf np.import_array() -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} cdef extern from "_sgd_fast_helpers.h": - bint skl_isfinite{{name}}({{c_type}}) nogil + bint skl_isfinite{{name_suffix}}({{c_type}}) nogil {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef inline {{c_type}} fmax{{name}}({{c_type}} x, {{c_type}} y) nogil: +cdef inline {{c_type}} fmax{{name_suffix}}({{c_type}} x, {{c_type}} y) nogil: if x > y: return x return y @@ -79,9 +75,9 @@ cdef inline {{c_type}} fmax{{name}}({{c_type}} x, {{c_type}} y) nogil: {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef {{c_type}} _logsumexp{{name}}({{c_type}}* arr, int n_classes) nogil: 
+cdef {{c_type}} _logsumexp{{name_suffix}}({{c_type}}* arr, int n_classes) nogil: """Computes the sum of arr assuming arr is in the log domain. Returns log(sum(exp(arr))) while minimizing the possibility of @@ -105,9 +101,9 @@ cdef {{c_type}} _logsumexp{{name}}({{c_type}}* arr, int n_classes) nogil: {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef class MultinomialLogLoss{{name}}: +cdef class MultinomialLogLoss{{name_suffix}}: cdef {{c_type}} _loss(self, {{c_type}}* prediction, {{c_type}} y, int n_classes, {{c_type}} sample_weight) nogil: r"""Multinomial Logistic regression loss. @@ -145,7 +141,7 @@ cdef class MultinomialLogLoss{{name}}: Bishop, C. M. (2006). Pattern recognition and machine learning. Springer. (Chapter 4.3.4) """ - cdef {{c_type}} logsumexp_prediction = _logsumexp{{name}}(prediction, n_classes) + cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes) cdef {{c_type}} loss # y is the indice of the correct class of current sample. @@ -191,7 +187,7 @@ cdef class MultinomialLogLoss{{name}}: Bishop, C. M. (2006). Pattern recognition and machine learning. Springer. 
(Chapter 4.3.4) """ - cdef {{c_type}} logsumexp_prediction = _logsumexp{{name}}(prediction, n_classes) + cdef {{c_type}} logsumexp_prediction = _logsumexp{{name_suffix}}(prediction, n_classes) cdef int class_ind for class_ind in range(n_classes): @@ -205,21 +201,21 @@ cdef class MultinomialLogLoss{{name}}: gradient_ptr[class_ind] *= sample_weight def __reduce__(self): - return MultinomialLogLoss{{name}}, () + return MultinomialLogLoss{{name_suffix}}, () {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef inline {{c_type}} _soft_thresholding{{name}}({{c_type}} x, {{c_type}} shrinkage) nogil: - return fmax{{name}}(x - shrinkage, 0) - fmax{{name}}(- x - shrinkage, 0) +cdef inline {{c_type}} _soft_thresholding{{name_suffix}}({{c_type}} x, {{c_type}} shrinkage) nogil: + return fmax{{name_suffix}}(x - shrinkage, 0) - fmax{{name_suffix}}(- x - shrinkage, 0) {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -def sag{{name}}(SequentialDataset{{name}} dataset, +def sag{{name_suffix}}(SequentialDataset{{name_suffix}} dataset, np.ndarray[{{c_type}}, ndim=2, mode='c'] weights_array, np.ndarray[{{c_type}}, ndim=1, mode='c'] intercept_array, int n_samples, @@ -360,11 +356,11 @@ def sag{{name}}(SequentialDataset{{name}} dataset, # Wether the loss function is multinomial cdef bint multinomial = False # Multinomial loss function - cdef MultinomialLogLoss{{name}} multiloss + cdef MultinomialLogLoss{{name_suffix}} multiloss if loss_function == "multinomial": multinomial = True - multiloss = MultinomialLogLoss{{name}}() + multiloss = MultinomialLogLoss{{name_suffix}}() elif loss_function == "log": loss = Log() elif loss_function == "squared": @@ -399,7 +395,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset, # make the weight updates if sample_itr > 0: - status = lagged_update{{name}}(weights, wscale, xnnz, + status = 
lagged_update{{name_suffix}}(weights, wscale, xnnz, n_samples, n_classes, sample_itr, cumulative_sums, @@ -414,7 +410,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset, break # find the current prediction - predict_sample{{name}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale, + predict_sample{{name_suffix}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale, intercept, prediction, n_classes) # compute the gradient for this sample, given the prediction @@ -459,7 +455,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset, num_seen * intercept_decay) # check to see that the intercept is not inf or NaN - if not skl_isfinite{{name}}(intercept[class_ind]): + if not skl_isfinite{{name_suffix}}(intercept[class_ind]): status = -1 break # Break from the n_samples outer loop if an error happened @@ -488,7 +484,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset, if verbose: with gil: print("rescaling...") - status = scale_weights{{name}}( + status = scale_weights{{name_suffix}}( weights, &wscale, n_features, n_samples, n_classes, sample_itr, cumulative_sums, cumulative_sums_prox, @@ -504,7 +500,7 @@ def sag{{name}}(SequentialDataset{{name}} dataset, # we scale the weights every n_samples iterations and reset the # just-in-time update system for numerical stability. 
- status = scale_weights{{name}}(weights, &wscale, n_features, + status = scale_weights{{name_suffix}}(weights, &wscale, n_features, n_samples, n_classes, n_samples - 1, cumulative_sums, @@ -518,8 +514,8 @@ def sag{{name}}(SequentialDataset{{name}} dataset, max_change = 0.0 max_weight = 0.0 for idx in range(n_features * n_classes): - max_weight = fmax{{name}}(max_weight, fabs(weights[idx])) - max_change = fmax{{name}}(max_change, + max_weight = fmax{{name_suffix}}(max_weight, fabs(weights[idx])) + max_change = fmax{{name_suffix}}(max_change, fabs(weights[idx] - previous_weights[idx])) previous_weights[idx] = weights[idx] @@ -553,9 +549,9 @@ def sag{{name}}(SequentialDataset{{name}} dataset, {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale, +cdef int scale_weights{{name_suffix}}({{c_type}}* weights, {{c_type}}* wscale, int n_features, int n_samples, int n_classes, int sample_itr, {{c_type}}* cumulative_sums, @@ -574,7 +570,7 @@ cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale, """ cdef int status - status = lagged_update{{name}}(weights, wscale[0], n_features, + status = lagged_update{{name_suffix}}(weights, wscale[0], n_features, n_samples, n_classes, sample_itr + 1, cumulative_sums, cumulative_sums_prox, @@ -592,9 +588,9 @@ cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale, {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, +cdef int lagged_update{{name_suffix}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, int n_samples, int n_classes, int sample_itr, {{c_type}}* cumulative_sums, {{c_type}}* cumulative_sums_prox, @@ -630,7 +626,7 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, weights[idx] -= cum_sum * 
sum_gradient[idx] if reset: weights[idx] *= wscale - if not skl_isfinite{{name}}(weights[idx]): + if not skl_isfinite{{name_suffix}}(weights[idx]): # returning here does not require the gil as the return # type is a C integer return -1 @@ -643,7 +639,7 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, # efficient than unrolling all the lagged updates. # Idea taken from scikit-learn-contrib/lightning. weights[idx] -= cum_sum * sum_gradient[idx] - weights[idx] = _soft_thresholding{{name}}(weights[idx], + weights[idx] = _soft_thresholding{{name_suffix}}(weights[idx], cum_sum_prox) else: last_update_ind = feature_hist[feature_ind] @@ -660,13 +656,13 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, grad_step = cumulative_sums[lagged_ind] prox_step = cumulative_sums_prox[lagged_ind] weights[idx] -= sum_gradient[idx] * grad_step - weights[idx] = _soft_thresholding{{name}}(weights[idx], + weights[idx] = _soft_thresholding{{name_suffix}}(weights[idx], prox_step) if reset: weights[idx] *= wscale # check to see that the weight is not inf or NaN - if not skl_isfinite{{name}}(weights[idx]): + if not skl_isfinite{{name_suffix}}(weights[idx]): return -1 if reset: feature_hist[feature_ind] = sample_itr % n_samples @@ -683,9 +679,9 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz, {{endfor}} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} -cdef void predict_sample{{name}}({{c_type}}* x_data_ptr, int* x_ind_ptr, int xnnz, +cdef void predict_sample{{name_suffix}}({{c_type}}* x_data_ptr, int* x_ind_ptr, int xnnz, {{c_type}}* w_data_ptr, {{c_type}} wscale, {{c_type}}* intercept, {{c_type}}* prediction, int n_classes) nogil: diff --git a/sklearn/linear_model/_sgd_fast.pyx b/sklearn/linear_model/_sgd_fast.pyx index edf421e537a2f..3b54a34adc80b 100644 --- a/sklearn/linear_model/_sgd_fast.pyx +++ b/sklearn/linear_model/_sgd_fast.pyx 
@@ -21,7 +21,7 @@ from numpy.math cimport INFINITY cdef extern from "_sgd_fast_helpers.h": bint skl_isfinite(double) nogil -from ..utils._weight_vector cimport WeightVector +from ..utils._weight_vector cimport WeightVector64 as WeightVector from ..utils._seq_dataset cimport SequentialDataset64 as SequentialDataset np.import_array() diff --git a/sklearn/linear_model/setup.py b/sklearn/linear_model/setup.py index cc5d277e13502..74d7d9e2b05ea 100644 --- a/sklearn/linear_model/setup.py +++ b/sklearn/linear_model/setup.py @@ -29,7 +29,7 @@ def configuration(parent_package="", top_path=None): # generate sag_fast from template templates = ["sklearn/linear_model/_sag_fast.pyx.tp"] - gen_from_templates(templates, top_path) + gen_from_templates(templates) config.add_extension( "_sag_fast", sources=["_sag_fast.pyx"], include_dirs=numpy.get_include() diff --git a/sklearn/utils/_seq_dataset.pxd.tp b/sklearn/utils/_seq_dataset.pxd.tp index be2d94a05b015..428f44a2c0358 100644 --- a/sklearn/utils/_seq_dataset.pxd.tp +++ b/sklearn/utils/_seq_dataset.pxd.tp @@ -12,16 +12,12 @@ Each class is duplicated for all dtypes (float and double). The keywords between double braces are substituted in setup.py. """ -# name, c_type +# name_suffix, c_type dtypes = [('64', 'double'), ('32', 'float')] -def get_dispatch(dtypes): - for name, c_type in dtypes: - yield name, c_type - }} -{{for name, c_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type in dtypes}} #------------------------------------------------------------------------------ @@ -36,7 +32,7 @@ cimport numpy as np # iterators over the rows of a matrix X and corresponding target values y. 
-cdef class SequentialDataset{{name}}: +cdef class SequentialDataset{{name_suffix}}: cdef int current_index cdef np.ndarray index cdef int *index_data_ptr @@ -56,7 +52,7 @@ cdef class SequentialDataset{{name}}: int *nnz, {{c_type}} *y, {{c_type}} *sample_weight) nogil -cdef class ArrayDataset{{name}}(SequentialDataset{{name}}): +cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}): cdef np.ndarray X cdef np.ndarray Y cdef np.ndarray sample_weights @@ -69,7 +65,7 @@ cdef class ArrayDataset{{name}}(SequentialDataset{{name}}): cdef {{c_type}} *sample_weight_data -cdef class CSRDataset{{name}}(SequentialDataset{{name}}): +cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}): cdef np.ndarray X_data cdef np.ndarray X_indptr cdef np.ndarray X_indices diff --git a/sklearn/utils/_seq_dataset.pyx.tp b/sklearn/utils/_seq_dataset.pyx.tp index 44edb9216dc62..8bc901194a24e 100644 --- a/sklearn/utils/_seq_dataset.pyx.tp +++ b/sklearn/utils/_seq_dataset.pyx.tp @@ -20,16 +20,12 @@ Author: Peter Prettenhofer License: BSD 3 clause """ -# name, c_type, np_type +# name_suffix, c_type, np_type dtypes = [('64', 'double', 'np.float64'), ('32', 'float', 'np.float32')] -def get_dispatch(dtypes): - for name, c_type, np_type in dtypes: - yield name, c_type, np_type - }} -{{for name, c_type, np_type in get_dispatch(dtypes)}} +{{for name_suffix, c_type, np_type in dtypes}} #------------------------------------------------------------------------------ @@ -47,7 +43,7 @@ np.import_array() from ._random cimport our_rand_r -cdef class SequentialDataset{{name}}: +cdef class SequentialDataset{{name_suffix}}: """Base class for datasets with sequential data access. 
SequentialDataset is used to iterate over the rows of a matrix X and @@ -219,7 +215,7 @@ cdef class SequentialDataset{{name}}: return (x_data, x_indices, x_indptr), y, sample_weight, sample_idx -cdef class ArrayDataset{{name}}(SequentialDataset{{name}}): +cdef class ArrayDataset{{name_suffix}}(SequentialDataset{{name_suffix}}): """Dataset backed by a two-dimensional numpy array. The dtype of the numpy array is expected to be ``{{np_type}}`` ({{c_type}}) @@ -288,7 +284,7 @@ cdef class ArrayDataset{{name}}(SequentialDataset{{name}}): sample_weight[0] = self.sample_weight_data[sample_idx] -cdef class CSRDataset{{name}}(SequentialDataset{{name}}): +cdef class CSRDataset{{name_suffix}}(SequentialDataset{{name_suffix}}): """A ``SequentialDataset`` backed by a scipy sparse CSR matrix. """ def __cinit__(self, np.ndarray[{{c_type}}, ndim=1, mode='c'] X_data, diff --git a/sklearn/utils/_weight_vector.pxd b/sklearn/utils/_weight_vector.pxd deleted file mode 100644 index fc1b47a50ef1f..0000000000000 --- a/sklearn/utils/_weight_vector.pxd +++ /dev/null @@ -1,20 +0,0 @@ -"""Efficient (dense) parameter vector implementation for linear models. 
""" - -cdef class WeightVector(object): - cdef double *w_data_ptr - cdef double *aw_data_ptr - cdef double wscale - cdef double average_a - cdef double average_b - cdef int n_features - cdef double sq_norm - - cdef void add(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz, double c) nogil - cdef void add_average(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz, double c, double num_iter) nogil - cdef double dot(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz) nogil - cdef void scale(self, double c) nogil - cdef void reset_wscale(self) nogil - cdef double norm(self) nogil diff --git a/sklearn/utils/_weight_vector.pxd.tp b/sklearn/utils/_weight_vector.pxd.tp new file mode 100644 index 0000000000000..8b0fc234713fb --- /dev/null +++ b/sklearn/utils/_weight_vector.pxd.tp @@ -0,0 +1,50 @@ +{{py: + +""" +Efficient (dense) parameter vector implementation for linear models. + +Template file to easily generate fused types consistent code using Tempita +(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py). + +Generated file: _weight_vector.pxd + +Each class is duplicated for all dtypes (float and double). The keywords +between double braces are substituted in setup.py. 
+""" + +# name_suffix, c_type +dtypes = [('64', 'double'), + ('32', 'float')] + +}} + +cimport numpy as np + +{{for name_suffix, c_type in dtypes}} + +cdef extern from "math.h": + cdef extern {{c_type}} sqrt({{c_type}} x) + + +cdef class WeightVector{{name_suffix}}(object): + cdef readonly {{c_type}}[::1] w + cdef readonly {{c_type}}[::1] aw + cdef {{c_type}} *w_data_ptr + cdef {{c_type}} *aw_data_ptr + cdef {{c_type}} wscale + cdef {{c_type}} average_a + cdef {{c_type}} average_b + cdef int n_features + cdef {{c_type}} sq_norm + + cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz, {{c_type}} c) nogil + cdef void add_average(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz, {{c_type}} c, {{c_type}} num_iter) nogil + cdef {{c_type}} dot(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz) nogil + cdef void scale(self, {{c_type}} c) nogil + cdef void reset_wscale(self) nogil + cdef {{c_type}} norm(self) nogil + +{{endfor}} diff --git a/sklearn/utils/_weight_vector.pyx b/sklearn/utils/_weight_vector.pyx.tp similarity index 67% rename from sklearn/utils/_weight_vector.pyx rename to sklearn/utils/_weight_vector.pyx.tp index 936c836a193e8..0e8ec45121438 100644 --- a/sklearn/utils/_weight_vector.pyx +++ b/sklearn/utils/_weight_vector.pyx.tp @@ -1,3 +1,23 @@ +{{py: + +""" +Efficient (dense) parameter vector implementation for linear models. + +Template file to easily generate fused types consistent code using Tempita +(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py). + +Generated file: _weight_vector.pyx + +Each class is duplicated for all dtypes (float and double). The keywords +between double braces are substituted in setup.py. 
+""" + +# name_suffix, c_type, reset_wscale_threshold +dtypes = [('64', 'double', 1e-9), + ('32', 'float', 1e-6)] + +}} + # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -19,8 +39,9 @@ from ._cython_blas cimport _dot, _scal, _axpy np.import_array() +{{for name_suffix, c_type, reset_wscale_threshold in dtypes}} -cdef class WeightVector(object): +cdef class WeightVector{{name_suffix}}(object): """Dense vector represented by a scalar and a numpy array. The class provides methods to ``add`` a sparse vector @@ -30,57 +51,65 @@ cdef class WeightVector(object): Attributes ---------- - w : ndarray, dtype=double, order='C' + w : ndarray, dtype={{c_type}}, order='C' The numpy array which backs the weight vector. - aw : ndarray, dtype=double, order='C' + aw : ndarray, dtype={{c_type}}, order='C' The numpy array which backs the average_weight vector. - wscale : double + w_data_ptr : {{c_type}}* + A pointer to the data of the numpy array. + wscale : {{c_type}} The scale of the vector. n_features : int The number of features (= dimensionality of ``w``). - sq_norm : double + sq_norm : {{c_type}} The squared norm of ``w``. """ - def __cinit__(self, double [::1] w, double [::1] aw): + + def __cinit__(self, + {{c_type}}[::1] w, + {{c_type}}[::1] aw): + if w.shape[0] > INT_MAX: raise ValueError("More than %d features not supported; got %d." 
% (INT_MAX, w.shape[0])) + self.w = w + self.w_data_ptr = &w[0] self.wscale = 1.0 self.n_features = w.shape[0] - self.sq_norm = _dot(w.shape[0], &w[0], 1, &w[0], 1) + self.sq_norm = _dot(self.n_features, self.w_data_ptr, 1, self.w_data_ptr, 1) - self.w_data_ptr = &w[0] - if aw is not None: + self.aw = aw + if self.aw is not None: self.aw_data_ptr = &aw[0] self.average_a = 0.0 self.average_b = 1.0 - cdef void add(self, double *x_data_ptr, int *x_ind_ptr, int xnnz, - double c) nogil: + cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz, + {{c_type}} c) nogil: """Scales sample x by constant c and adds it to the weight vector. This operation updates ``sq_norm``. Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. xnnz : int The number of non-zero features of ``x``. - c : double + c : {{c_type}} The scaling constant for the example. """ cdef int j cdef int idx - cdef double val - cdef double innerprod = 0.0 - cdef double xsqnorm = 0.0 + cdef {{c_type}} val + cdef {{c_type}} innerprod = 0.0 + cdef {{c_type}} xsqnorm = 0.0 # the next two lines save a factor of 2! - cdef double wscale = self.wscale - cdef double* w_data_ptr = self.w_data_ptr + cdef {{c_type}} wscale = self.wscale + cdef {{c_type}}* w_data_ptr = self.w_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] @@ -94,30 +123,30 @@ cdef class WeightVector(object): # Update the average weights according to the sparse trick defined # here: https://research.microsoft.com/pubs/192769/tricks-2012.pdf # by Leon Bottou - cdef void add_average(self, double *x_data_ptr, int *x_ind_ptr, int xnnz, - double c, double num_iter) nogil: + cdef void add_average(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz, + {{c_type}} c, {{c_type}} num_iter) nogil: """Updates the average weight vector. 
Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. xnnz : int The number of non-zero features of ``x``. - c : double + c : {{c_type}} The scaling constant for the example. - num_iter : double + num_iter : {{c_type}} The total number of iterations. """ cdef int j cdef int idx - cdef double val - cdef double mu = 1.0 / num_iter - cdef double average_a = self.average_a - cdef double wscale = self.wscale - cdef double* aw_data_ptr = self.aw_data_ptr + cdef {{c_type}} val + cdef {{c_type}} mu = 1.0 / num_iter + cdef {{c_type}} average_a = self.average_a + cdef {{c_type}} wscale = self.wscale + cdef {{c_type}}* aw_data_ptr = self.aw_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] @@ -130,13 +159,13 @@ cdef class WeightVector(object): self.average_b /= (1.0 - mu) self.average_a += mu * self.average_b * wscale - cdef double dot(self, double *x_data_ptr, int *x_ind_ptr, + cdef {{c_type}} dot(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz) nogil: """Computes the dot product of a sample x and the weight vector. Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. @@ -145,27 +174,28 @@ cdef class WeightVector(object): Returns ------- - innerprod : double + innerprod : {{c_type}} The inner product of ``x`` and ``w``. """ cdef int j cdef int idx - cdef double innerprod = 0.0 - cdef double* w_data_ptr = self.w_data_ptr + cdef {{c_type}} innerprod = 0.0 + cdef {{c_type}}* w_data_ptr = self.w_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] innerprod += w_data_ptr[idx] * x_data_ptr[j] innerprod *= self.wscale return innerprod - cdef void scale(self, double c) nogil: + cdef void scale(self, {{c_type}} c) nogil: """Scales the weight vector by a constant ``c``. 
It updates ``wscale`` and ``sq_norm``. If ``wscale`` gets too small we call ``reset_swcale``.""" self.wscale *= c self.sq_norm *= (c * c) - if self.wscale < 1e-9: + + if self.wscale < {{reset_wscale_threshold}}: self.reset_wscale() cdef void reset_wscale(self) nogil: @@ -180,6 +210,8 @@ cdef class WeightVector(object): _scal(self.n_features, self.wscale, self.w_data_ptr, 1) self.wscale = 1.0 - cdef double norm(self) nogil: + cdef {{c_type}} norm(self) nogil: """The L2 norm of the weight vector. """ return sqrt(self.sq_norm) + +{{endfor}} diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 35674f8b44287..b06da9777be09 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -47,12 +47,15 @@ def configuration(parent_package="", top_path=None): "_openmp_helpers", sources=["_openmp_helpers.pyx"], libraries=libraries ) - # generate _seq_dataset from template + # generate files from a template templates = [ "sklearn/utils/_seq_dataset.pyx.tp", "sklearn/utils/_seq_dataset.pxd.tp", + "sklearn/utils/_weight_vector.pyx.tp", + "sklearn/utils/_weight_vector.pxd.tp", ] - gen_from_templates(templates, top_path) + + gen_from_templates(templates) config.add_extension( "_seq_dataset", sources=["_seq_dataset.pyx"], include_dirs=[numpy.get_include()] diff --git a/sklearn/utils/tests/test_cython_templating.py b/sklearn/utils/tests/test_cython_templating.py new file mode 100644 index 0000000000000..572d1db523cf8 --- /dev/null +++ b/sklearn/utils/tests/test_cython_templating.py @@ -0,0 +1,20 @@ +import pathlib +import pytest +import sklearn + + +def test_files_generated_by_templates_are_git_ignored(): + """Check the consistence of the files generated from template files.""" + gitignore_file = pathlib.Path(sklearn.__file__).parent.parent / ".gitignore" + if not gitignore_file.exists(): + pytest.skip("Tests are not run from the source folder") + + base_dir = pathlib.Path(sklearn.__file__).parent + ignored_files = open(gitignore_file, "r").readlines() + 
ignored_files = list(map(lambda line: line.strip("\n"), ignored_files)) + + for filename in base_dir.glob("**/*.tp"): + filename = filename.relative_to(base_dir.parent) + # From "path/to/template.p??.tp" to "path/to/template.p??" + filename_wo_tempita_suffix = filename.with_suffix("") + assert str(filename_wo_tempita_suffix) in ignored_files diff --git a/sklearn/utils/tests/test_weight_vector.py b/sklearn/utils/tests/test_weight_vector.py new file mode 100644 index 0000000000000..627d46d1fda06 --- /dev/null +++ b/sklearn/utils/tests/test_weight_vector.py @@ -0,0 +1,24 @@ +import numpy as np +import pytest +from sklearn.utils._weight_vector import ( + WeightVector32, + WeightVector64, +) + + +@pytest.mark.parametrize( + "dtype, WeightVector", + [ + (np.float32, WeightVector32), + (np.float64, WeightVector64), + ], +) +def test_type_invariance(dtype, WeightVector): + """Check the `dtype` consistency of `WeightVector`.""" + weights = np.random.rand(100).astype(dtype) + average_weights = np.random.rand(100).astype(dtype) + + weight_vector = WeightVector(weights, average_weights) + + assert np.asarray(weight_vector.w).dtype is np.dtype(dtype) + assert np.asarray(weight_vector.aw).dtype is np.dtype(dtype)