62
62
//! let svc = SVC::fit(&x, &y, parameters, None).unwrap();
63
63
//!
64
64
//! let y_hat = svc.predict(&x).unwrap();
65
+ //!
65
66
//! ```
66
67
//!
67
68
//! ## References:
@@ -92,20 +93,43 @@ use crate::svm::Kernel;
92
93
93
94
#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
94
95
#[ derive( Debug ) ]
96
+ /// Configuration for a multi-class Support Vector Machine (SVM) classifier.
97
+ ///
98
+ /// This struct holds the indices of the data points relevant to a specific binary
99
+ /// classification problem within a multi-class context, and the two classes
100
+ /// being discriminated.
95
101
pub struct MultiClassConfig < TY : Number + Ord > {
102
+ /// The indices of the data points from the original dataset that belong to the two `classes`.
96
103
indices : Vec < usize > ,
104
+ /// A tuple representing the two classes that this configuration is designed to distinguish.
97
105
classes : ( TY , TY ) ,
98
106
}
99
107
100
108
impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
101
109
SupervisedEstimatorBorrow < ' a , X , Y , SVCParameters < TX , TY , X , Y > >
102
110
for MultiClassSVC < ' a , TX , TY , X , Y >
103
111
{
112
+ /// Creates a new, empty `MultiClassSVC` instance.
113
+ ///
114
+ /// The `classifiers` field is initialized to `Option::None`, indicating that
115
+ /// the model has not yet been fitted.
104
116
fn new ( ) -> Self {
105
117
Self {
106
118
classifiers : Option :: None ,
107
119
}
108
120
}
121
+
122
+ /// Fits the `MultiClassSVC` model to the provided data and parameters.
123
+ ///
124
+ /// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
125
+ ///
126
+ /// # Arguments
127
+ /// * `x` - A reference to the input features (2D array).
128
+ /// * `y` - A reference to the target labels (1D array).
129
+ /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
130
+ ///
131
+ /// # Returns
132
+ /// A `Result` indicating success (`Self`) or failure (`Failed`).
109
133
fn fit (
110
134
x : & ' a X ,
111
135
y : & ' a Y ,
@@ -118,50 +142,95 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
118
142
impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
119
143
PredictorBorrow < ' a , X , TX > for MultiClassSVC < ' a , TX , TY , X , Y >
120
144
{
145
+ /// Predicts the class labels for new data points.
146
+ ///
147
+ /// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
148
+ /// It unwraps the inner `Result` from `MultiClassSVC::predict`, assuming that
149
+ /// the prediction will always succeed after a successful fit.
150
+ ///
151
+ /// # Arguments
152
+ /// * `x` - A reference to the input features (2D array) for which to make predictions.
153
+ ///
154
+ /// # Returns
155
+ /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
121
156
fn predict ( & self , x : & ' a X ) -> Result < Vec < TX > , Failed > {
122
157
Ok ( self . predict ( x) . unwrap ( ) )
123
158
}
124
159
}
125
160
161
+ /// A multi-class Support Vector Machine (SVM) classifier.
162
+ ///
163
+ /// This struct implements a multi-class SVM using the "one-vs-one" strategy,
164
+ /// where a separate binary SVC classifier is trained for every pair of classes.
165
+ ///
166
+ /// # Type Parameters
167
+ /// * `'a` - Lifetime parameter for borrowed data.
168
+ /// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
169
+ /// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
170
+ /// * `X` - The type representing the 2D array of input features (e.g., a matrix).
171
+ /// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
126
172
pub struct MultiClassSVC <
127
173
' a ,
128
174
TX : Number + RealNumber ,
129
175
TY : Number + Ord ,
130
176
X : Array2 < TX > ,
131
177
Y : Array1 < TY > ,
132
178
> {
179
+ /// An optional vector of binary `SVC` classifiers.
180
+ ///
181
+ /// This will be `Some` after the model has been fitted, containing one `SVC`
182
+ /// for each pair of unique classes.
133
183
classifiers : Option < Vec < SVC < ' a , TX , TY , X , Y > > > ,
134
184
}
135
185
136
186
impl < ' a , TX : Number + RealNumber , TY : Number + Ord , X : Array2 < TX > , Y : Array1 < TY > >
137
187
MultiClassSVC < ' a , TX , TY , X , Y >
138
188
{
189
+ /// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
190
+ ///
191
+ /// This method identifies all unique classes in the target labels `y` and then
192
+ /// trains a binary `SVC` for every unique pair of classes. For each pair, it
193
+ /// extracts the relevant data points and their labels, and then trains a
194
+ /// specialized `SVC` for that binary classification task.
195
+ ///
196
+ /// # Arguments
197
+ /// * `x` - A reference to the input features (2D array).
198
+ /// * `y` - A reference to the target labels (1D array).
199
+ /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
200
+ ///
201
+ ///
202
+ /// # Returns
203
+ /// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
139
204
pub fn fit (
140
205
x : & ' a X ,
141
206
y : & ' a Y ,
142
207
parameters : & ' a SVCParameters < TX , TY , X , Y > ,
143
208
) -> Result < MultiClassSVC < ' a , TX , TY , X , Y > , Failed > {
144
209
let unique_classes = y. unique ( ) ;
145
210
let mut classifiers = Vec :: new ( ) ;
211
+ // Iterate through all unique pairs of classes (one-vs-one strategy)
146
212
for i in 0 ..unique_classes. len ( ) {
147
213
for j in i..unique_classes. len ( ) {
148
214
if i == j {
149
- continue ;
215
+ continue ; // Skip comparing a class to itself
150
216
}
151
217
let class0 = unique_classes[ j] ;
152
218
let class1 = unique_classes[ i] ;
219
+
153
220
let mut indices = Vec :: new ( ) ;
221
+ // Collect indices of data points belonging to the current pair of classes
154
222
for ( index, v) in y. iterator ( 0 ) . enumerate ( ) {
155
223
if * v == class0 || * v == class1 {
156
224
indices. push ( index)
157
225
}
158
226
}
159
227
let classes = ( class0, class1) ;
160
228
let multiclass_config = MultiClassConfig {
161
- classes : classes . clone ( ) ,
229
+ classes,
162
230
indices,
163
231
} ;
164
- let svc = SVC :: fit ( x, y, parameters, Some ( multiclass_config) ) . unwrap ( ) ;
232
+ // Fit a binary SVC for the current pair of classes
233
+ let svc = SVC :: fit ( x, y, parameters, Some ( multiclass_config) ) . unwrap ( ) ; // .unwrap() might panic if SVC::fit fails
165
234
classifiers. push ( svc) ;
166
235
}
167
236
}
@@ -170,25 +239,50 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
170
239
} )
171
240
}
172
241
242
+ /// Predicts the class labels for new data points using the trained multi-class SVM.
243
+ ///
244
+ /// This method uses a "voting" scheme (majority vote) among all the binary
245
+ /// classifiers to determine the final prediction for each data point.
246
+ ///
247
+ /// # Arguments
248
+ /// * `x` - A reference to the input features (2D array) for which to make predictions.
249
+ ///
250
+ /// # Returns
251
+ /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
252
+ ///
253
+ /// # Panics
254
+ /// Panics if the model has not been fitted (`self.classifiers` is `None`).
173
255
pub fn predict ( & self , x : & X ) -> Result < Vec < TX > , Failed > {
256
+ // Initialize a HashMap for each data point to store votes for each class
174
257
let mut polls = vec ! [ HashMap :: new( ) ; x. shape( ) . 0 ] ;
258
+ // Retrieve the trained binary classifiers; panics if not fitted
175
259
let classifiers = self . classifiers . as_ref ( ) . unwrap ( ) ;
260
+
261
+ // Iterate through each binary classifier
176
262
for i in 0 ..classifiers. len ( ) {
177
- let svc = classifiers. get ( i) . unwrap ( ) ;
178
- let predictions = svc. predict ( x) . unwrap ( ) ;
263
+ let svc = classifiers. get ( i) . unwrap ( ) ; // .unwrap() might panic if index is out of bounds
264
+ let predictions = svc. predict ( x) . unwrap ( ) ; // .unwrap() might panic if SVC::predict fails
265
+
266
+ // For each prediction from the current binary classifier
179
267
for ( j, prediction) in predictions. iter ( ) . enumerate ( ) {
180
- let prediction = prediction. to_i32 ( ) . unwrap ( ) ;
181
- let poll = polls. get_mut ( j) . unwrap ( ) ;
268
+ let prediction = prediction. to_i32 ( ) . unwrap ( ) ; // Convert prediction to i32 for HashMap key
269
+ let poll = polls. get_mut ( j) . unwrap ( ) ; // Get the poll for the current data point
270
+ // Increment the vote for the predicted class
182
271
if let Some ( count) = poll. get_mut ( & prediction) {
183
272
* count += 1
184
273
} else {
185
274
poll. insert ( prediction, 1 ) ;
186
275
}
187
276
}
188
277
}
278
+
279
+ // Determine the final prediction for each data point based on majority vote
189
280
Ok ( polls
190
281
. iter ( )
191
- . map ( |v| TX :: from ( * v. iter ( ) . max_by_key ( |( _, class) | * class) . unwrap ( ) . 0 ) . unwrap ( ) )
282
+ . map ( |v| {
283
+ // Find the class with the maximum votes for each data point
284
+ TX :: from ( * v. iter ( ) . max_by_key ( |( _, class) | * class) . unwrap ( ) . 0 ) . unwrap ( )
285
+ } )
192
286
. collect ( ) )
193
287
}
194
288
}
0 commit comments