Skip to content

SVC multiclass #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: development
Choose a base branch
from

Conversation

DanielLacina
Copy link

Implemented the multiclass feature for SVC using a one for all approach.

@DanielLacina DanielLacina requested a review from Mec-iS as a code owner June 3, 2025 21:19
@DanielLacina
Copy link
Author

Had to close the previous pr because I sent it on the wrong branch.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Jun 4, 2025

#305 (comment)

@DanielLacina
Copy link
Author

DanielLacina commented Jun 4, 2025

The issue with Array2 is that the type within the Array is inferred and static. If a Vec<u 32> is passed in, then the application won't be able to transform the labels to: {1, -1} in order for them to be used for binary classification. Originally, you had an assert statement to validate the data passed in had labels: -1 and 1, but now the SVC, with it being multiclass, can accept a wide variety of labels.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Jun 9, 2025

The issue with Array2 is that the type within the Array is inferred and static. If a Vec<u 32> is passed in, then the application won't be able to transform the labels to: {1, -1} in order for them to be used for binary classification.

Array2 is an abstraction for a 2D vector, it can be used for any instance supported by Vec, it just need to be implemented.

Test Driven Development (TDD) should be followed. Every time you implement something new you have to be sure to support existing behaviour. If there is no test for the operations you are changing, you should add it. Also please add one or more tests when you implement something new.

When I try to run the tests in the module I get:

running 3 tests
test svm::svc::tests::svc_fit_predict ... ok
test svm::svc::tests::svc_fit_predict_rbf ... ok
test svm::svc::tests::svc_fit_decision_function ... FAILED

successes:

successes:
    svm::svc::tests::svc_fit_predict
    svm::svc::tests::svc_fit_predict_rbf

failures:

---- svm::svc::tests::svc_fit_decision_function stdout ----
thread 'svm::svc::tests::svc_fit_decision_function' panicked at src/svm/svc.rs:1033:9:
assertion failed: y_hat[1] < y_hat[2]

This is a reference implementation as generated by my LLM, you can start from this as it is not fully implemented (ie. it is an example and needs to be implemented using generic types like TX and TY). Please note that you should not modify the existing struct but instead create new structs to handle the multiclass possibility. This implementation suggests to use a 1D Vec but you should check if it correct (if it is right there is no need to use a 2D Vec):


To implement a multiclass Support Vector Classification (SVC) in Rust using smartcore, we can adopt the one-vs-one (OvO) strategy, which trains binary classifiers for each pair of classes. Here's a complete implementation:

use smartcore::svm::svc::{SVC, SVCParameters};
use smartcore::linalg::{BaseVector, Matrix, MatrixTrait};
use smartcore::metrics::accuracy;
use smartcore::dataset::iris::load_dataset;

// Multiclass SVC using One-vs-One strategy
struct MulticlassSVC {
    classifiers: Vec<SVC<f64, DenseMatrix<f64>, Vec<f64>>>,
    classes: Vec<u32>,
}

impl MulticlassSVC {
    pub fn fit(
        x: &DenseMatrix<f64>,
        y: &Vec<u32>,
        params: &SVCParameters<f64>,
    ) -> Result<Self, Failed> {
        let classes = y.iter().unique().sorted().collect::<Vec<_>>();
        let mut classifiers = Vec::new();

        // Generate all class pairs
        for (i, &class1) in classes.iter().enumerate() {
            for &class2 in classes.iter().skip(i + 1) {
                // Filter samples for current class pair
                let indices: Vec<usize> = y.iter()
                    .enumerate()
                    .filter(|(_, &c)| c == class1 || c == class2)
                    .map(|(i, _)| i)
                    .collect();

                let x_filtered = x.select_rows(&indices);
                let y_filtered: Vec<f64> = indices.iter()
                    .map(|&i| if y[i] == class1 { 1.0 } else { -1.0 })
                    .collect();

                // Train binary classifier
                let mut clf = SVC::fit(&x_filtered, &y_filtered, params.clone())?;
                classifiers.push((class1, class2, clf));
            }
        }

        Ok(Self { classifiers, classes })
    }

    pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<u32> {
        let mut votes = vec![HashMap::new(); x.shape().0];
        
        for (class1, class2, clf) in &self.classifiers {
            let preds = clf.predict(x).unwrap();
            
            for (i, &p) in preds.iter().enumerate() {
                let vote = if p > 0.0 { *class1 } else { *class2 };
                *votes[i].entry(vote).or_insert(0) += 1;
            }
        }

        votes.iter()
            .map(|v| *v.iter().max_by_key(|(_, &count)| count).unwrap().0)
            .collect()
    }
}

// Example usage with Iris dataset
fn main() -> Result<(), Failed> {
    let iris = load_dataset();
    let (x_train, x_test, y_train, y_test) = train_test_split(
        &iris.data,
        &iris.target,
        0.8,
        true,
        Some(42),
    );

    let params = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(smartcore::svm::Kernel::linear());

    let clf = MulticlassSVC::fit(&x_train, &y_train, &params)?;
    let preds = clf.predict(&x_test);
    
    println!("Accuracy: {}", accuracy(&y_test, &preds));
    Ok(())
}

Key implementation details:

  1. OvO Strategy:
    • Creates n_classes * (n_classes - 1) / 2 binary classifiers12
    • Uses filtered subsets of data for each class pair3
  2. smartcore Integration:
    • Uses SVC with configurable parameters (C, kernel)4
    • Handles DenseMatrix input as per smartcore's data requirements[^1]
  3. Prediction Aggregation:
    • Implements voting system across all binary classifiers23
    • Handles ties by selecting first majority class

Advantages over naive implementation:

  • Maintains smartcore's API conventions[^1]
  • Avoids unsafe code and complex lifetimes
  • Compatible with smartcore's linear algebra traits
  • Supports all kernel types available in smartcore4

This implementation follows smartcore's design principles by:

  • Using pure Rust without external dependencies[^1]
  • Maintaining a Pythonic/scikit-learn-like API[^1]
  • Supporting TDD through clear input/output contracts
  • Working efficiently with smartcore's matrix types[^1]

For production use, you'd want to add:

  • Model persistence (serialization/deserialization)
  • Class weighting support
  • Parallel training of binary classifiers (using rayon)
  • More sophisticated tie-breaking strategies

Footnotes

  1. https://scikit-learn.org/stable/modules/svm.html

  2. https://www.nb-data.com/p/one-vs-all-vs-one-vs-one-which-multi 2

  3. https://www.baeldung.com/cs/svm-multiclass-classification 2

  4. http://smartcorelib.org/user_guide/supervised.html 2

@DanielLacina
Copy link
Author

Appreciate the guide.

@DanielLacina
Copy link
Author

It doesn't look like the DenseMatrix has the select_rows method.

@DanielLacina
Copy link
Author

I assume you want the multiclass struct to take in any generic Array2 object as x.

@Mec-iS
Copy link
Collaborator

Mec-iS commented Jun 9, 2025

It doesn't look like the DenseMatrix has the select_rows method.

it has the get_row method, see src/linalg/basic/matrix.rs. Please read the API for DenseMatrix.

I assume you want the multiclass struct to take in any generic Array2 object as x.

it depends what you are going to do with the y parameter, in this case using &Vec it is OK I guess. Just try and see.

@DanielLacina
Copy link
Author

Alright. That helps a ton.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants