@@ -41,17 +41,14 @@ eval(:(const RFE{M} =
4141
4242# Common keyword constructor for both model types
4343"""
44- RecursiveFeatureElimination(model, n_features, step)
44+ RecursiveFeatureElimination(model; n_features=0, step=1)
4545
4646This model implements a recursive feature elimination algorithm for feature selection.
4747It recursively removes features, training a base model on the remaining features and
4848evaluating their importance until the desired number of features is selected.
49-
50- Construct an instance with default hyper-parameters using the syntax
51- `rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
52- hyper-parameter defaults.
5349
5450# Training data
51+
5552In MLJ or MLJBase, bind an instance `rfe_model` to data with
5653
5754 mach = machine(rfe_model, X, y)
@@ -92,53 +89,62 @@ Train the machine using `fit!(mach, rows=...)`.
9289# Operations
9390
9491- `transform(mach, X)`: transform the input table `X` into a new table containing only
95- columns corresponding to features gotten from the RFE algorithm.
92+ columns corresponding to features selected by the RFE algorithm.
9693
9794- `predict(mach, X)`: transform the input table `X` into a new table as in
98-
99- - `transform(mach, X)` above and predict using the fitted base model on the
100- transformed table.
95+ `transform(mach, X)` above, then predict using the fitted base model on the transformed
96+ table.
10197
10298# Fitted parameters
99+
103100The fields of `fitted_params(mach)` are:
101+
104102- `features_left`: names of features remaining after recursive feature elimination.
105103
106104- `model_fitresult`: fitted parameters of the base model.
107105
108106# Report
107+
109108The fields of `report(mach)` are:
109+
110110- `scores`: dictionary of scores for each feature in the training dataset.
111- The model deems highly scored variables more significant.
111+ The model deems highly scored variables more significant.
112112
113113- `model_report`: report for the fitted base model.
114114
115115
116116# Examples
117+
118+ The following example assumes you have MLJDecisionTreeInterface in the active package
119+ environment.
120+
117121```
118- using FeatureSelection, MLJ, StableRNGs
122+ using MLJ
119123
120124RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
121125
122126# Creates a dataset where the target only depends on the first 5 columns of the input table.
123- A = rand(rng, 50, 10);
127+ A = rand(50, 10);
124128y = 10 .* sin.(
125129 pi .* A[:, 1] .* A[:, 2]
126- ) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]) ;
130+ ) + 20 .* (A[:, 3] .- 0.5).^2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5];
127131X = MLJ.table(A);
128132
129- # fit a rfe model
133+ # fit an RFE model:
130134rf = RandomForestRegressor()
131- selector = RecursiveFeatureElimination(model = rf )
135+ selector = RecursiveFeatureElimination(rf, n_features=2)
132136mach = machine(selector, X, y)
133137fit!(mach)
134138
135139# view the feature importances
136140feature_importances(mach)
137141
138- # predict using the base model
139- Xnew = MLJ.table(rand(rng, 50, 10));
142+ # predict using the base model trained on the reduced feature set:
143+ Xnew = MLJ.table(rand(50, 10));
140144predict(mach, Xnew)
141145
146+ # transform data with all features to the reduced feature set:
147+ transform(mach, Xnew)
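+
+ # inspect the per-feature scores accumulated during elimination (a sketch; the
+ # `scores` field is documented under "Report" above):
+ report(mach).scores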
142148```
143149"""
144150function RecursiveFeatureElimination(
@@ -173,7 +179,7 @@ function RecursiveFeatureElimination(
173179 # This branch is hit just in case there are any models that support feature
174180 # importance but aren't `<:Probabilistic` or `<:Deterministic`,
175181 # which is rare.
176- throw (ERR_MODEL_TYPE)
182+ throw(ERR_MODEL_TYPE)
177183 end
178184 message = MMI.clean!(selector)
179185 isempty(message) || @warn(message)
@@ -214,22 +220,30 @@ abs_last(x::Pair{<:Any, <:Real}) = abs(last(x))
214220"""
215221 score_features!(scores_dict, features, importances, n_features_to_score)
216222
217- Internal method that updates the `scores_dict` by increasing the score for each feature based on their
223+ **Private method.**
224+
225+ Update the `scores_dict` by increasing the score for each feature based on their
218226importance and store the features in the `features` array.
219227
220228# Arguments
221- - `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
229+
230+ - `scores_dict::Dict{Symbol, Int}`: A dictionary where the keys are features and
222231 the values are their corresponding scores.
232+
223233- `features::Vector{Symbol}`: An array to store the top features based on importance.
224- - `importances::Vector{Pair(Symbol, <:Real)}}`: An array of tuples where each tuple
225- contains a feature and its importance score.
234+
235+ - `importances::Vector{<:Pair{Symbol, <:Real}}`: An array of pairs where each pair
236+   contains a feature and its importance score.
237+
226238- `n_features_to_score::Int`: The number of top features to score and store.
227239
228240# Notes
229- Ensure that `n_features_to_score` is less than or equal to the minimum of the
241+
242+ Ensure that `n_features_to_score` is less than or equal to the minimum of the
230243lengths of `features` and `importances`.
231244
232245# Example
246+
233247```julia
234248scores_dict = Dict(:feature1 => 0, :feature2 => 0, :feature3 => 0)
235249features = [:x1, :x1, :x1]
@@ -244,7 +258,7 @@ features == [:feature1, :feature2, :x1]
244258function score_features!(scores_dict, features, importances, n_features_to_score)
245259 for i in Base.OneTo(n_features_to_score)
246260 ftr = first(importances[i])
247- features[i] = ftr
261+ features[i] = ftr
248262 scores_dict[ftr] += 1
249263 end
250264end
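
# A quick usage sketch for `score_features!` (the names below are illustrative only):
#
#   scores = Dict(:a => 0, :b => 0, :c => 0)
#   slots = [:a, :a, :a]
#   imps = [:b => 0.7, :a => 0.3, :c => 0.1]
#   score_features!(scores, slots, imps, 2)
#   # now scores == Dict(:a => 1, :b => 1, :c => 0) and slots == [:b, :a, :a]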
@@ -273,7 +287,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
273287 " n_features > number of features in training data, " *
274288 " hence no feature will be eliminated."
275289 )
276- end
290+ end
277291 end
278292
279293 _step = selector.step
@@ -296,17 +310,17 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
296310 verbosity > 0 && @info("Fitting estimator with $(n_features_to_keep) features.")
297311 data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)
298312 fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
299- # Note that the MLJ feature importance API does not impose any restrictions on the
300- # ordering of `feature => score` pairs in the `importances` vector.
313+ # Note that the MLJ feature importance API does not impose any restrictions on the
314+ # ordering of `feature => score` pairs in the `importances` vector.
301315 # Therefore, the order of `feature => score` pairs in the `importances` vector
302- # might differ from the order of features in the `features` vector, which is
316+ # might differ from the order of features in the `features` vector, which is
303317 # extracted from the feature matrix `X` above. Hence the need for a dictionary
304318 # implementation.
305319 importances = MMI.feature_importances(
306320 selector.model,
307321 fitresult,
308322 report
309- )
323+ )
310324
311325 # Eliminate the worst features and increase the scores of the remaining features
312326 sort!(importances, by=abs_last, rev=true)
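    # For example, with importances `[:a => -0.9, :b => 0.1, :c => 0.5]`, sorting by
    # `abs_last` with `rev=true` gives `[:a => -0.9, :c => 0.5, :b => 0.1]`, so the
    # features with the smallest absolute importance fall to the end of the vector
    # and are the candidates for elimination in this round.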
396410MMI.load_path(::Type{<:RFE}) = "FeatureSelection.RecursiveFeatureElimination"
397411MMI.constructor(::Type{<:RFE}) = RecursiveFeatureElimination
398412MMI.package_name(::Type{<:RFE}) = "FeatureSelection"
413+ MMI.is_wrapper(::Type{<:RFE}) = true
399414
400415for trait in [
401416 :supports_weights,