Skip to content

Commit 05c4fd7

Browse files
committed
fixed lints
1 parent daba88d commit 05c4fd7

File tree

1 file changed

+102
-8
lines changed

1 file changed

+102
-8
lines changed

src/svm/svc.rs

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
//! let svc = SVC::fit(&x, &y, parameters, None).unwrap();
6363
//!
6464
//! let y_hat = svc.predict(&x).unwrap();
65+
//!
6566
//! ```
6667
//!
6768
//! ## References:
@@ -92,20 +93,43 @@ use crate::svm::Kernel;
9293

9394
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9495
#[derive(Debug)]
96+
/// Configuration for one binary (pairwise) sub-problem of a multi-class Support Vector Machine (SVM) classifier.
97+
///
98+
/// This struct holds the indices of the data points relevant to a specific binary
99+
/// classification problem within a multi-class context, and the two classes
100+
/// being discriminated.
95101
pub struct MultiClassConfig<TY: Number + Ord> {
102+
/// The indices of the data points from the original dataset that belong to the two `classes`.
96103
indices: Vec<usize>,
104+
/// A tuple representing the two classes that this configuration is designed to distinguish.
97105
classes: (TY, TY),
98106
}
99107

100108
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
101109
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>>
102110
for MultiClassSVC<'a, TX, TY, X, Y>
103111
{
112+
/// Creates a new, empty `MultiClassSVC` instance.
113+
///
114+
/// The `classifiers` field is initialized to `Option::None`, indicating that
115+
/// the model has not yet been fitted.
104116
fn new() -> Self {
105117
Self {
106118
classifiers: Option::None,
107119
}
108120
}
121+
122+
/// Fits the `MultiClassSVC` model to the provided data and parameters.
123+
///
124+
/// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
125+
///
126+
/// # Arguments
127+
/// * `x` - A reference to the input features (2D array).
128+
/// * `y` - A reference to the target labels (1D array).
129+
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
130+
///
131+
/// # Returns
132+
/// A `Result` indicating success (`Self`) or failure (`Failed`).
109133
fn fit(
110134
x: &'a X,
111135
y: &'a Y,
@@ -118,50 +142,95 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
118142
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
119143
PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y>
120144
{
145+
/// Predicts the class labels for new data points.
146+
///
147+
/// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
148+
/// It unwraps the inner `Result` from `MultiClassSVC::predict`, assuming that
149+
/// the prediction will always succeed after a successful fit.
150+
///
151+
/// # Arguments
152+
/// * `x` - A reference to the input features (2D array) for which to make predictions.
153+
///
154+
/// # Returns
155+
/// `Ok` containing a `Vec` of predicted class labels (as `TX`). Because the inner result is unwrapped, a prediction failure panics rather than being returned as `Failed`.
121156
fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
122157
Ok(self.predict(x).unwrap())
123158
}
124159
}
125160

161+
/// A multi-class Support Vector Machine (SVM) classifier.
162+
///
163+
/// This struct implements a multi-class SVM using the "one-vs-one" strategy,
164+
/// where a separate binary SVC classifier is trained for every pair of classes.
165+
///
166+
/// # Type Parameters
167+
/// * `'a` - Lifetime parameter for borrowed data.
168+
/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
169+
/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
170+
/// * `X` - The type representing the 2D array of input features (e.g., a matrix).
171+
/// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
126172
pub struct MultiClassSVC<
127173
'a,
128174
TX: Number + RealNumber,
129175
TY: Number + Ord,
130176
X: Array2<TX>,
131177
Y: Array1<TY>,
132178
> {
179+
/// An optional vector of binary `SVC` classifiers.
180+
///
181+
/// This will be `Some` after the model has been fitted, containing one `SVC`
182+
/// for each pair of unique classes.
133183
classifiers: Option<Vec<SVC<'a, TX, TY, X, Y>>>,
134184
}
135185

136186
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
137187
MultiClassSVC<'a, TX, TY, X, Y>
138188
{
189+
/// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
190+
///
191+
/// This method identifies all unique classes in the target labels `y` and then
192+
/// trains a binary `SVC` for every unique pair of classes. For each pair, it
193+
/// extracts the relevant data points and their labels, and then trains a
194+
/// specialized `SVC` for that binary classification task.
195+
///
196+
/// # Arguments
197+
/// * `x` - A reference to the input features (2D array).
198+
/// * `y` - A reference to the target labels (1D array).
199+
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
200+
///
201+
///
202+
/// # Returns
203+
/// A `Result` holding the fitted `MultiClassSVC` on success. Note: a failure in any pairwise `SVC::fit` call panics via `.unwrap()` rather than being returned as `Failed`.
139204
pub fn fit(
140205
x: &'a X,
141206
y: &'a Y,
142207
parameters: &'a SVCParameters<TX, TY, X, Y>,
143208
) -> Result<MultiClassSVC<'a, TX, TY, X, Y>, Failed> {
144209
let unique_classes = y.unique();
145210
let mut classifiers = Vec::new();
211+
// Iterate through all unique pairs of classes (one-vs-one strategy)
146212
for i in 0..unique_classes.len() {
147213
for j in i..unique_classes.len() {
148214
if i == j {
149-
continue;
215+
continue; // Skip comparing a class to itself
150216
}
151217
let class0 = unique_classes[j];
152218
let class1 = unique_classes[i];
219+
153220
let mut indices = Vec::new();
221+
// Collect indices of data points belonging to the current pair of classes
154222
for (index, v) in y.iterator(0).enumerate() {
155223
if *v == class0 || *v == class1 {
156224
indices.push(index)
157225
}
158226
}
159227
let classes = (class0, class1);
160228
let multiclass_config = MultiClassConfig {
161-
classes: classes.clone(),
229+
classes,
162230
indices,
163231
};
164-
let svc = SVC::fit(x, y, parameters, Some(multiclass_config)).unwrap();
232+
// Fit a binary SVC for the current pair of classes
233+
let svc = SVC::fit(x, y, parameters, Some(multiclass_config)).unwrap(); // .unwrap() might panic if SVC::fit fails
165234
classifiers.push(svc);
166235
}
167236
}
@@ -170,25 +239,50 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
170239
})
171240
}
172241

242+
/// Predicts the class labels for new data points using the trained multi-class SVM.
243+
///
244+
/// This method uses a "voting" scheme (majority vote) among all the binary
245+
/// classifiers to determine the final prediction for each data point. Ties are broken arbitrarily, since the winner depends on `HashMap` iteration order.
246+
///
247+
/// # Arguments
248+
/// * `x` - A reference to the input features (2D array) for which to make predictions.
249+
///
250+
/// # Returns
251+
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
252+
///
253+
/// # Panics
254+
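/// Panics if the model has not been fitted (`self.classifiers` is `None`), if a binary classifier's `predict` fails, if a predicted label cannot be converted to `i32` (or back to `TX`), or if no votes were recorded for a data point.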
173255
pub fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
256+
// Initialize a HashMap for each data point to store votes for each class
174257
let mut polls = vec![HashMap::new(); x.shape().0];
258+
// Retrieve the trained binary classifiers; panics if not fitted
175259
let classifiers = self.classifiers.as_ref().unwrap();
260+
261+
// Iterate through each binary classifier
176262
for i in 0..classifiers.len() {
177-
let svc = classifiers.get(i).unwrap();
178-
let predictions = svc.predict(x).unwrap();
263+
let svc = classifiers.get(i).unwrap(); // .unwrap() might panic if index is out of bounds
264+
let predictions = svc.predict(x).unwrap(); // .unwrap() might panic if SVC::predict fails
265+
266+
// For each prediction from the current binary classifier
179267
for (j, prediction) in predictions.iter().enumerate() {
180-
let prediction = prediction.to_i32().unwrap();
181-
let poll = polls.get_mut(j).unwrap();
268+
let prediction = prediction.to_i32().unwrap(); // Convert prediction to i32 for HashMap key
269+
let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point
270+
// Increment the vote for the predicted class
182271
if let Some(count) = poll.get_mut(&prediction) {
183272
*count += 1
184273
} else {
185274
poll.insert(prediction, 1);
186275
}
187276
}
188277
}
278+
279+
// Determine the final prediction for each data point based on majority vote
189280
Ok(polls
190281
.iter()
191-
.map(|v| TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap())
282+
.map(|v| {
283+
// Find the class with the maximum votes for each data point
284+
TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap()
285+
})
192286
.collect())
193287
}
194288
}

0 commit comments

Comments
 (0)