4 changes: 0 additions & 4 deletions src/jabs/classifier/__init__.py
@@ -5,12 +5,8 @@
Gradient Boosting, and XGBoost), utilities for feature management, data splitting, model evaluation, and serialization.
"""

import pathlib

from .classifier import Classifier

HYPERPARAMETER_PATH = pathlib.Path(__file__).parent / "hyperparameters.json"

__all__ = [
"Classifier",
]
143 changes: 77 additions & 66 deletions src/jabs/classifier/classifier.py
@@ -1,3 +1,4 @@
import logging
import random
import re
import typing
@@ -8,7 +9,7 @@
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.metrics import (
accuracy_score,
@@ -21,28 +22,53 @@
from jabs.types import ClassifierType
from jabs.utils import hash_file

_VERSION = 9

_classifier_choices = [ClassifierType.RANDOM_FOREST, ClassifierType.GRADIENT_BOOSTING]
_VERSION = 10

try:
_xgboost = import_module("xgboost")
# we were able to import xgboost, make it available as an option:
_classifier_choices.append(ClassifierType.XGBOOST)
except Exception:
except ImportError:
# we were unable to import the xgboost module. It's either not
# installed (it should be if the user used our requirements-old.txt)
# installed (it should be if the user installed JABS as a package)
# or the import may have failed due to a missing libomp. Either way,
# we won't add it to the available choices and we can otherwise
# ignore this exception.
_xgboost = None
logging.warning(
"Unable to import xgboost. XGBoost support will be unavailable. "
"You may need to install xgboost and/or libomp."
)


# Classifier factory helpers and mapping
def _make_random_forest(n_jobs: int, random_seed: int | None):
"""Factory function to construct a RandomForest classifier."""
return RandomForestClassifier(n_jobs=n_jobs, random_state=random_seed)


def _make_xgboost(n_jobs: int, random_seed: int | None):
"""Factory function to construct an XGBoost classifier."""
if _xgboost is None:
raise RuntimeError(
"XGBoost classifier requested but 'xgboost' is not available in this environment."
)
return _xgboost.XGBClassifier(n_jobs=n_jobs, random_state=random_seed)


# _CLASSIFIER_FACTORIES serves both as the single source of truth for the
# classifiers supported by the current JABS environment and as the mapping
# from ClassifierType to a factory function that produces an instantiated
# classifier of that type
_CLASSIFIER_FACTORIES: dict[ClassifierType, typing.Callable[[int, int | None], typing.Any]] = {
ClassifierType.RANDOM_FOREST: _make_random_forest,
}
if _xgboost is not None:
_CLASSIFIER_FACTORIES[ClassifierType.XGBOOST] = _make_xgboost


class Classifier:
"""A machine learning classifier for behavior classification tasks.

This class supports training, evaluating, saving, and loading classifiers
for behavioral data using Random Forest, Gradient Boosting, or XGBoost algorithms.
for behavioral data using Random Forest or XGBoost algorithms.
It provides utilities for data splitting, balancing, augmentation, and feature management.

Attributes:
@@ -51,13 +77,7 @@ class Classifier:

LABEL_THRESHOLD = 20

_CLASSIFIER_NAMES: typing.ClassVar[dict] = {
ClassifierType.RANDOM_FOREST: "Random Forest",
ClassifierType.GRADIENT_BOOSTING: "Gradient Boosting",
ClassifierType.XGBOOST: "XGBoost",
}

def __init__(self, classifier=ClassifierType.RANDOM_FOREST, n_jobs=1):
def __init__(self, classifier: ClassifierType = ClassifierType.RANDOM_FOREST, n_jobs: int = 1):
self._classifier_type = classifier
self._classifier = None
self._project_settings = None
@@ -69,9 +89,10 @@ def __init__(self, classifier=ClassifierType.RANDOM_FOREST, n_jobs=1):
self._classifier_file = None
self._classifier_hash = None
self._classifier_source = None
self._supported_classifiers = self._supported_classifier_choices()

# make sure the value passed for the classifier parameter is valid
if classifier not in _classifier_choices:
if classifier not in self._supported_classifiers:
raise ValueError("Invalid classifier type")

@classmethod
@@ -93,10 +114,10 @@ def from_training_file(cls, path: Path):
classifier.behavior_name = behavior
classifier.set_dict_settings(loaded_training_data["settings"])
classifier_type = ClassifierType(loaded_training_data["classifier_type"])
if classifier_type in classifier.classifier_choices():
if classifier_type in classifier._supported_classifiers:
classifier.set_classifier(classifier_type)
else:
print(
logging.warning(
f"Specified classifier type {classifier_type.name} is unavailable, using default: {classifier.classifier_type.name}"
)
training_features = classifier.combine_data(
@@ -119,7 +140,7 @@ def from_training_file(cls, path: Path):
@property
def classifier_name(self) -> str:
"""return the name of the classifier used as a string"""
return self._CLASSIFIER_NAMES[self._classifier_type]
return self._classifier_type.value

@property
def classifier_type(self) -> ClassifierType:
@@ -148,7 +169,7 @@ def project_settings(self) -> dict:
return {}

@property
def behavior_name(self) -> str:
def behavior_name(self) -> str | None:
"""return the behavior name property"""
return self._behavior

@@ -163,7 +184,7 @@ def version(self) -> int:
return self._version

@property
def feature_names(self) -> list:
def feature_names(self) -> list[str] | None:
"""returns the list of feature names used when training this classifier"""
return self._feature_names

@@ -299,9 +320,8 @@ def downsample_balance(features, labels, random_seed=None):
selected_samples = []
for cur_label in label_states:
idxs = np.where(labels == cur_label)[0]
if random_seed is not None:
np.random.seed(random_seed)
sampled_idxs = np.random.choice(idxs, max_examples_per_class, replace=False)
rng = np.random.default_rng(random_seed)
sampled_idxs = rng.choice(idxs, max_examples_per_class, replace=False)
selected_samples.append(sampled_idxs)
selected_samples = np.sort(np.concatenate(selected_samples))
features = features.iloc[selected_samples]
@@ -346,9 +366,9 @@ def augment_symmetric(features, labels, random_str="ASygRQDZJD"):
# print(str(lowercase_features[idx]) + ' -> ' + str(reflected_feature_names[idx]))
return features, labels

def set_classifier(self, classifier):
def set_classifier(self, classifier: ClassifierType):
"""change the type of the classifier being used"""
if classifier not in _classifier_choices:
if classifier not in self._supported_classifiers:
raise ValueError("Invalid Classifier Type")
self._classifier_type = classifier

@@ -380,15 +400,17 @@ def classifier_choices(self):

Returns:
dict where keys are ClassifierType enum values, and the
values are string names for the classifiers. example:

{
<ClassifierType.RANDOM_FOREST: 1>: 'Random Forest',
<ClassifierType.GRADIENT_BOOSTING: 2>: 'Gradient Boosting',
<ClassifierType.XGBOOST: 3>: 'XGBoost'
}
values are string names for the classifiers.
"""
return {d: self._CLASSIFIER_NAMES[d] for d in _classifier_choices}
return {t: t.value for t in sorted(self._supported_classifiers, key=lambda t: t.value)}

def _create_classifier(self, random_seed: int | None = None):
"""Instantiate the underlying classifier for the current classifier type."""
try:
factory = _CLASSIFIER_FACTORIES[self._classifier_type]
except KeyError:
raise ValueError(f"Unsupported classifier type: {self._classifier_type!r}") from None
return factory(self._n_jobs, random_seed)

def train(self, data, random_seed: int | None = None):
"""train the classifier
@@ -421,16 +443,23 @@
if self._project_settings.get("balance_labels", False):
features, labels = self.downsample_balance(features, labels, random_seed)

if self._classifier_type == ClassifierType.RANDOM_FOREST:
self._classifier = self._fit_random_forest(features, labels, random_seed=random_seed)
elif self._classifier_type == ClassifierType.GRADIENT_BOOSTING:
self._classifier = self._fit_gradient_boost(features, labels, random_seed=random_seed)
elif _xgboost is not None and self._classifier_type == ClassifierType.XGBOOST:
classifier = self._create_classifier(random_seed=random_seed)

if self._classifier_type == ClassifierType.XGBOOST:
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
self._classifier = self._fit_xgboost(features, labels, random_seed=random_seed)
# XGBoost natively supports NaN as a marker for missing values and handles them
# during tree construction. For XGBoost we therefore convert infinite values to NaN
# and leave them as missing, instead of imputing them with 0. This differs from the
# Random Forest (and other sklearn) path below, where both infinities and NaN are
# replaced with 0.
cleaned_features = features.replace([np.inf, -np.inf], np.nan)
self._classifier = classifier.fit(cleaned_features, labels)
else:
raise ValueError("Unsupported classifier")
# RandomForestClassifier (and most other sklearn estimators) do not natively support NaN
# values, so here we replace infinite values and NaNs with 0 before fitting.
cleaned_features = features.replace([np.inf, -np.inf], 0).fillna(0)
self._classifier = classifier.fit(cleaned_features, labels)

# Classifier may have been re-used from a prior training, blank the logging attributes
self._classifier_file = None
@@ -513,7 +542,7 @@ def load(self, path: Path):
)

# make sure the value passed for the classifier parameter is valid
if c._classifier_type not in _classifier_choices:
if c._classifier_type not in self._supported_classifiers:
raise ValueError("Invalid classifier type")

self._classifier = c._classifier
@@ -529,16 +558,6 @@ def load(self, path: Path):
self._classifier_hash = hash_file(Path(path))
self._classifier_source = "pickle"

def _update_classifier_type(self):
# we may need to update the classifier type based
# on the type of the loaded object
if isinstance(self._classifier, RandomForestClassifier):
self._classifier_type = ClassifierType.RANDOM_FOREST
elif isinstance(self._classifier, GradientBoostingClassifier):
self._classifier_type = ClassifierType.GRADIENT_BOOSTING
else:
self._classifier_type = ClassifierType.XGBOOST

@staticmethod
def accuracy_score(truth, predictions):
"""return accuracy score"""
@@ -567,19 +586,6 @@ def combine_data(per_frame, window):
"""
return pd.concat([per_frame, window], axis=1)

def _fit_random_forest(self, features, labels, random_seed: int | None = None):
classifier = RandomForestClassifier(n_jobs=self._n_jobs, random_state=random_seed)
return classifier.fit(features.replace([np.inf, -np.inf], 0).fillna(0), labels)

def _fit_gradient_boost(self, features, labels, random_seed: int | None = None):
classifier = GradientBoostingClassifier(random_state=random_seed)
return classifier.fit(features.replace([np.inf, -np.inf], 0).fillna(0), labels)

def _fit_xgboost(self, features, labels, random_seed: int | None = None):
classifier = _xgboost.XGBClassifier(n_jobs=self._n_jobs, random_state=random_seed)
classifier.fit(features.replace([np.inf, -np.inf]), labels)
return classifier

def print_feature_importance(self, feature_list, limit=20):
"""print the most important features and their importance

@@ -665,3 +671,8 @@ def label_threshold_met(all_counts: dict, min_groups: int):
"""
group_count = Classifier.count_label_threshold(all_counts)
return 1 < group_count >= min_groups

@staticmethod
def _supported_classifier_choices() -> set[ClassifierType]:
"""Determine the list of supported classifier types in the current JABS environment."""
return set(_CLASSIFIER_FACTORIES.keys())
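
The comment block in `train()` above describes two feature-cleaning paths. A minimal sketch of the difference on a throwaway DataFrame (column name and values are illustrative):

```python
import numpy as np
import pandas as pd

features = pd.DataFrame({"f1": [1.0, np.inf, -np.inf, np.nan]})

# XGBoost path: convert infinities to NaN and leave them missing, so the
# booster can route them natively during tree construction
xgb_ready = features.replace([np.inf, -np.inf], np.nan)

# sklearn path: RandomForestClassifier does not accept NaN, so impute both
# infinities and NaN with 0 before fitting
sklearn_ready = features.replace([np.inf, -np.inf], 0).fillna(0)
```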
2 changes: 1 addition & 1 deletion src/jabs/project/export_training.py
@@ -62,7 +62,7 @@ def export_training_data(
out_h5.attrs["min_pose_version"] = pose_version
out_h5.attrs["behavior"] = behavior
write_project_settings(out_h5, project.settings_manager.get_behavior(behavior), "settings")
out_h5.attrs["classifier_type"] = classifier_type.value
out_h5.attrs["classifier_type"] = str(classifier_type)
out_h5.attrs["training_seed"] = training_seed
feature_group = out_h5.create_group("features")
for feature, data in features["per_frame"].items():
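For context, `from_training_file` (shown earlier) reconstructs the enum from this attribute by value. A hypothetical read-back, assuming the stored string matches a `ClassifierType` value (the file path is illustrative):

```python
import h5py

from jabs.types import ClassifierType

with h5py.File("training.h5", "r") as in_h5:  # illustrative path
    stored = in_h5.attrs["classifier_type"]
    classifier_type = ClassifierType(stored)  # e.g. "XGBoost" -> ClassifierType.XGBOOST
```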
6 changes: 2 additions & 4 deletions src/jabs/resources/docs/user_guide/user_guide.md
@@ -643,7 +643,7 @@ See `jabs-classify COMMAND --help` for information on a specific command.
```

```text
usage: jabs-classify classify [-h] [--random-forest | --gradient-boosting | --xgboost]
usage: jabs-classify classify [-h] [--random-forest | --xgboost]
(--training TRAINING | --classifier CLASSIFIER) --input-pose
INPUT_POSE --out-dir OUT_DIR [--fps FPS]
[--feature-dir FEATURE_DIR]
@@ -664,7 +664,6 @@ optionally override the classifier specified in the training file:
Ignored if trained classifier passed with --classifier option.
(the following options are mutually exclusive):
--random-forest Use Random Forest
--gradient-boosting Use Gradient Boosting
--xgboost Use XGBoost

Classifier Input (one of the following is required):
@@ -674,7 +673,7 @@
```

```text
usage: jabs-classify train [-h] [--random-forest | --gradient-boosting | --xgboost]
usage: jabs-classify train [-h] [--random-forest | --xgboost]
training_file out_file

positional arguments:
@@ -687,7 +686,6 @@ optional arguments:
optionally override the classifier specified in the training file:
(the following options are mutually exclusive):
--random-forest Use Random Forest
--gradient-boosting Use Gradient Boosting
--xgboost Use XGBoost
```

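Hypothetical invocations based on the usage text above (file and directory names are placeholders):

```text
# train an XGBoost classifier from an exported training file
jabs-classify train --xgboost training.h5 classifier.pickle

# classify a pose file using the trained classifier
jabs-classify classify --classifier classifier.pickle --input-pose pose.h5 --out-dir results
```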
2 changes: 1 addition & 1 deletion src/jabs/scripts/classify.py
@@ -263,7 +263,7 @@ def classify_main():

required_args.add_argument(
"--input-pose",
help="input HDF5 pose file (v2, v3, v4, or v5).",
help="input HDF5 pose file.",
required=True,
)
required_args.add_argument(
9 changes: 4 additions & 5 deletions src/jabs/types/classifier_types.py
@@ -1,9 +1,8 @@
from enum import IntEnum
from enum import Enum


class ClassifierType(IntEnum):
class ClassifierType(str, Enum):
"""Classifier type for the project."""

RANDOM_FOREST = 1
GRADIENT_BOOSTING = 2
XGBOOST = 3
RANDOM_FOREST = "Random Forest"
XGBOOST = "XGBoost"
8 changes: 7 additions & 1 deletion src/jabs/ui/main_control_widget.py
@@ -16,6 +16,7 @@
from PySide6.QtGui import QIcon, QPainter, QPixmap

from jabs.classifier import Classifier
from jabs.types import ClassifierType
from jabs.ui.ear_tag_icons import EarTagIconManager

from .colors import (
@@ -335,7 +336,12 @@ def classify_button_enabled(self, enabled: bool):
@property
def classifier_type(self):
"""return the selected classifier type"""
return self._classifier_selection.currentData()
data = self._classifier_selection.currentData()
# QComboBox may return a string instead of the enum due to serialization
# so we need to convert it back to ClassifierType if it's a string
if isinstance(data, str):
return ClassifierType(data)
return data

@property
def use_balance_labels(self):
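A sketch of how the selector and this coercion might fit together. The population loop is an assumption (it is not part of this diff); it stores the enum member as the item's userData:

```python
import sys

from PySide6.QtWidgets import QApplication, QComboBox

from jabs.classifier import Classifier
from jabs.types import ClassifierType

app = QApplication(sys.argv)
combo = QComboBox()
# hypothetical population: display name as the item text, enum as userData
for classifier_type, name in Classifier().classifier_choices().items():
    combo.addItem(name, userData=classifier_type)

data = combo.currentData()
# mirror the property above: coerce a serialized string back to the enum
selected = ClassifierType(data) if isinstance(data, str) else data
```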