function warn_double_spec(arg, model)
    return "Using `model=$arg`. Ignoring keyword specification `model=$model`."
end

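# A sketch of the intended use (hypothetical call): if a user supplies both a
# positional model and a `model=...` keyword, the constructor is expected to warn
# with `warn_double_spec(arg, model)` and proceed with the positional one.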
const ERR_SPECIFY_MODEL = ArgumentError(
    "You need to specify model as a positional argument or specify `model=...`."
)
@@ -36,66 +36,67 @@ for (ModelType, ModelSuperType) in MODELTYPE_GIVEN_SUPERTYPES
    eval(ex)
end

-eval(:(const RFE{M} = Union{$((Expr(:curly, modeltype, :M) for modeltype in MODEL_TYPES)...)}))
+eval(:(const RFE{M} =
+    Union{$((Expr(:curly, modeltype, :M) for modeltype in MODEL_TYPES)...)}))
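# For reference (an illustrative expansion, not part of the source), with the two
# model types defined above this alias is expected to be equivalent to:
#
#     const RFE{M} = Union{
#         DeterministicRecursiveFeatureElimination{M},
#         ProbabilisticRecursiveFeatureElimination{M},
#     }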

# Common keyword constructor for both model types
"""
    RecursiveFeatureElimination(model, n_features, step)

This model implements a recursive feature elimination algorithm for feature selection.
It recursively removes features, training a base model on the remaining features and
evaluating their importance until the desired number of features is selected.

-Construct an instance with default hyper-parameters using the syntax
-`model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
-hyper-parameter defaults.
-
+Construct an instance with default hyper-parameters using the syntax
+`rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
+hyper-parameter defaults.
+
# Training data
-In MLJ or MLJBase, bind an instance `model` to data with
+In MLJ or MLJBase, bind an instance `rfe_model` to data with

-    mach = machine(model, X, y)
+    mach = machine(rfe_model, X, y)

OR, if the base model supports weights, as

-    mach = machine(model, X, y, w)
+    mach = machine(rfe_model, X, y, w)

Here:

- `X` is any table of input features (eg, a `DataFrame`) whose columns are of the scitype
  required by the base model; check column scitypes with `schema(X)` and the column
  scitypes required by the base model with `input_scitype(basemodel)`.

- `y` is the target, which can be any table of responses whose element scitype is
  `Continuous` or `Finite`, depending on the `target_scitype` required by the base model;
  check the scitype with `scitype(y)`.

- `w` is the observation weights, which can either be `nothing` (default) or an
  `AbstractVector` whose element scitype is `Count` or `Continuous`. This is different
  from the `weights` hyper-parameter of the base model; see below.

Train the machine using `fit!(mach, rows=...)`.

# Hyper-parameters
- model: A base model with a `fit` method that provides information on feature
  importance (i.e. `reports_feature_importances(model) == true`)

- n_features::Real = 0: The number of features to select. If `0`, half of the
  features are selected. If a positive integer, the parameter is the absolute number
  of features to select. If a real number strictly between 0 and 1, it is the fraction
  of features to select.

- step::Real=1: If the value of `step` is at least 1, it signifies the number of
  features to eliminate at each iteration. Conversely, if `step` falls strictly between
  0.0 and 1.0, it denotes the proportion (rounded) of features to remove at each
  iteration. See the sketch below.

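For example, a sketch of how these two hyper-parameters resolve (assumes training
data with ten features; the arithmetic mirrors `MMI.fit` below):

    # n_features = 0    selects div(10, 2) == 5 features (half)
    # n_features = 3    selects 3 features (absolute number)
    # n_features = 0.3  selects round(Int, 0.3 * 10) == 3 features (fraction)
    # step = 2          eliminates 2 features per iteration
    # step = 0.5        eliminates round(Int, max(1, 0.5 * 5)) == 2 features per iteration
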
# Operations

- `transform(mach, X)`: transform the input table `X` into a new table containing only
  the columns corresponding to features selected by the RFE algorithm.

- `predict(mach, X)`: transform the input table `X` into a new table as in
  `transform(mach, X)` above, then predict using the fitted base model on the
  transformed table.

# Fitted parameters
@@ -106,11 +107,11 @@ The fields of `fitted_params(mach)` are:

# Report
The fields of `report(mach)` are:
- `ranking`: The ranking of each feature in the training dataset.

- `model_report`: report for the fitted base model.

- `features`: names of features seen during the training process.

# Examples
```
@@ -131,10 +132,10 @@ selector = RecursiveFeatureElimination(model = rf)
mach = machine(selector, X, y)
fit!(mach)

# view the feature importances
feature_importances(mach)

# predict using the base model
Xnew = MLJ.table(rand(rng, 50, 10));
predict(mach, Xnew)

@@ -160,7 +161,7 @@ function RecursiveFeatureElimination(
    # TODO: Check that the specified model implements the predict method.
    # probably add a trait to check this
    MMI.reports_feature_importances(model) || throw(ERR_FEATURE_IMPORTANCE_SUPPORT)
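    # dispatch on the base model's prediction type to pick the concrete wrapper type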
    if model isa Deterministic
        selector = DeterministicRecursiveFeatureElimination{typeof(model)}(
            model, Float64(n_features), Float64(step)
        )
@@ -170,7 +171,7 @@ function RecursiveFeatureElimination(
        )
    else
        throw(ERR_MODEL_TYPE)
    end
    message = MMI.clean!(selector)
    isempty(message) || @warn(message)
    return selector
@@ -204,21 +205,21 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
    n_features_select = selector.n_features
    ## zero indicates that half of the features will be selected
    if n_features_select == 0
        n_features_select = div(nfeatures, 2)
    elseif 0 < n_features_select < 1
        n_features_select = round(Int, n_features_select * nfeatures)
    else
        n_features_select = round(Int, n_features_select)
    end
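    # e.g. (illustrative) with nfeatures == 10: 0 becomes 5, 0.3 becomes 3, 3.0 becomes 3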

    step = selector.step

    if 0 < step < 1
        step = round(Int, max(1, step * n_features_select))
    else
        step = round(Int, step)
    end
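    # e.g. (illustrative) step = 0.5 with n_features_select == 4 removes
    # round(Int, max(1, 0.5 * 4)) == 2 features per iteration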

    support = trues(nfeatures)
    ranking = ones(Int, nfeatures) # every feature has equal rank initially
    mask = trues(nfeatures) # for boolean indexing of ranking vector in while loop below
@@ -230,7 +231,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
        # Rank the remaining features
        model = selector.model
        verbosity > 0 && @info("Fitting estimator with $(n_features_left) features.")

        data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)

        fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
@@ -263,14 +264,14 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
    data = MMI.reformat(selector.model, MMI.selectcols(X, features_left), args...)
    verbosity > 0 && @info("Fitting estimator with $(n_features_left) features.")
    model_fitresult, _, model_report = MMI.fit(selector.model, verbosity - 1, data...)

    fitresult = (
        support = support,
        model_fitresult = model_fitresult,
        features_left = features_left,
        features = features
    )
    report = (
        ranking = ranking,
        model_report = model_report
    )

function MMI.transform(::RFE, fitresult, X)
    sch = Tables.schema(Tables.columns(X))
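    # guard: X has as many columns as the training table but lacks some training features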
    if (length(fitresult.features) == length(sch.names) &&
        !all(e -> e in sch.names, fitresult.features))
        throw(
            ERR_FEATURES_SEEN
@@ -312,7 +313,7 @@ function MMI.save(model::RFE, fitresult)
    atomic_fitresult = fitresult.model_fitresult
    features_left = fitresult.features_left
    features = fitresult.features

    atom = model.model
    return (
        support = copy(support),
@@ -337,14 +338,12 @@ function MMI.restore(model::RFE, serializable_fitresult)
    )
end

-## Traits definitions
-function MMI.load_path(::Type{<:DeterministicRecursiveFeatureElimination})
-    return "FeatureSelection.DeterministicRecursiveFeatureElimination"
-end
+## Trait definitions

-function MMI.load_path(::Type{<:ProbabilisticRecursiveFeatureElimination})
-    return "FeatureSelection.ProbabilisticRecursiveFeatureElimination"
-end
+# load path points to the constructor, not the type:
+MMI.load_path(::Type{<:RFE}) = "FeatureSelection.RecursiveFeatureElimination"
+MMI.constructor(::Type{<:RFE}) = RecursiveFeatureElimination
+MMI.package_name(::Type{<:RFE}) = "FeatureSelection"
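+# (Both concrete types share the single constructor above, so these traits are
+# defined once on the `RFE` union rather than on each concrete type.)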

for trait in [
    :supports_weights,
## TRAINING LOSSES SUPPORT
function MMI.training_losses(model::RFE, rfe_report)
    return MMI.training_losses(model.model, rfe_report.model_report)
end
+
+## Pkg Traits
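+# `metadata_pkg` defines the package-level traits (`package_name`, `package_uuid`,
+# `package_url`, `is_pure_julia`, `package_license`) for both types in one broadcast call.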
+MMI.metadata_pkg.(
+    (
+        DeterministicRecursiveFeatureElimination,
+        ProbabilisticRecursiveFeatureElimination,
+    ),
+    package_name = "FeatureSelection",
+    package_uuid = "33837fe5-dbff-4c9e-8c2f-c5612fe2b8b6",
+    package_url = "https://github.com/JuliaAI/FeatureSelection.jl",
+    is_pure_julia = true,
+    package_license = "MIT"
+)