Commit 6d968a7 (initial commit)

Add files via upload

Clarence-CV authored Jun 12, 2024

Showing 41 changed files with 4,339 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Xin Wen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
86 changes: 86 additions & 0 deletions README.md
@@ -0,0 +1,86 @@
# Parametric Classification for Generalized Category Discovery: A Baseline Study


<p align="center">
<a href="https://openaccess.thecvf.com/content/ICCV2023/html/Wen_Parametric_Classification_for_Generalized_Category_Discovery_A_Baseline_Study_ICCV_2023_paper.html"><img src="https://img.shields.io/badge/-ICCV%202023-68488b"></a>
<a href="https://arxiv.org/abs/2211.11727"><img src="https://img.shields.io/badge/arXiv-2211.11727-b31b1b"></a>
<a href="https://wen-xin.info/simgcd"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
<a href="https://github.com/CVMI-Lab/SlotCon/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-MIT-blue.svg"></a>
</p>
<p align="center">
Parametric Classification for Generalized Category Discovery: A Baseline Study (ICCV 2023)<br>
By
<a href="https://wen-xin.info">Xin Wen</a>*,
<a href="https://bzhao.me/">Bingchen Zhao</a>*, and
<a href="https://xjqi.github.io/">Xiaojuan Qi</a>.
</p>

![teaser](assets/teaser.jpg)

Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples.
Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means.

However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and produces an imbalanced distribution across seen and novel categories.
Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks, and shows strong robustness to unknown class numbers.
We hope the investigation and proposed simple framework can serve as a strong baseline to facilitate future studies in this field.
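The regulariser mentioned above encourages a balanced marginal prediction over classes. Below is a minimal sketch of a mean-entropy maximisation term of that kind; the function name, the `1e-8` stabiliser, and the `epsilon` weighting are illustrative assumptions, not the repo's exact code.

```python
import torch
import torch.nn.functional as F

def mean_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Entropy of the batch-averaged class distribution."""
    probs = F.softmax(logits, dim=1)   # (B, C) per-sample probabilities
    mean_probs = probs.mean(dim=0)     # (C,) marginal prediction over the batch
    return -(mean_probs * torch.log(mean_probs + 1e-8)).sum()

# Maximising this entropy (i.e. subtracting it from the loss with some
# weight epsilon) pushes the marginal prediction towards uniform, which
# counteracts the bias towards predicting seen classes:
# loss = supervised_loss + unsupervised_loss - epsilon * mean_entropy(logits)
```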

## Running

### Dependencies

```
pip install -r requirements.txt
```

### Config

Set paths to datasets and desired log directories in `config.py`.


### Datasets

We use fine-grained benchmarks in this paper, including:

* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)

We also use generic object recognition datasets, including:

* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php)


### Scripts

**Train the model**:

```
bash scripts/run_${DATASET_NAME}.sh
```
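For example, assuming the script names mirror the dataset roots in `config.py` (the concrete name `run_cifar100.sh` here is an assumption):

```
bash scripts/run_cifar100.sh
```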

We found that selecting the model by 'Old' class performance can lead to over-fitting, and since 'New' class labels on the held-out validation set should be assumed unavailable, we recommend not performing model selection and simply using the last-epoch model.

## Results
Our results:

<table>
<thead>
<tr><th>Source</th><th colspan="3">Paper (3 runs)</th><th colspan="3">Current GitHub (5 runs)</th></tr>
</thead>
<tbody>
<tr><td>Dataset</td><td>All</td><td>Old</td><td>New</td><td>All</td><td>Old</td><td>New</td></tr>
<tr><td>CIFAR10</td><td>97.1±0.0</td><td>95.1±0.1</td><td>98.1±0.1</td><td>97.0±0.1</td><td>93.9±0.1</td><td>98.5±0.1</td></tr>
<tr><td>CIFAR100</td><td>80.1±0.9</td><td>81.2±0.4</td><td>77.8±2.0</td><td>79.8±0.6</td><td>81.1±0.5</td><td>77.4±2.5</td></tr>
<tr><td>ImageNet-100</td><td>83.0±1.2</td><td>93.1±0.2</td><td>77.9±1.9</td><td>83.6±1.4</td><td>92.4±0.1</td><td>79.1±2.2</td></tr>
<tr><td>ImageNet-1K</td><td>57.1±0.1</td><td>77.3±0.1</td><td>46.9±0.2</td><td>57.0±0.4</td><td>77.1±0.1</td><td>46.9±0.5</td></tr>
<tr><td>CUB</td><td>60.3±0.1</td><td>65.6±0.9</td><td>57.7±0.4</td><td>61.5±0.5</td><td>65.7±0.5</td><td>59.4±0.8</td></tr>
<tr><td>Stanford Cars</td><td>53.8±2.2</td><td>71.9±1.7</td><td>45.0±2.4</td><td>53.4±1.6</td><td>71.5±1.6</td><td>44.6±1.7</td></tr>
<tr><td>FGVC-Aircraft</td><td>54.2±1.9</td><td>59.1±1.2</td><td>51.8±2.3</td><td>54.3±0.7</td><td>59.4±0.4</td><td>51.7±1.2</td></tr>
<tr><td>Herbarium 19</td><td>44.0±0.4</td><td>58.0±0.4</td><td>36.4±0.8</td><td>44.2±0.2</td><td>57.6±0.6</td><td>37.0±0.4</td></tr>
</tbody>
</table>

## Citing this work

If you find this repo useful for your research, please consider citing our paper:

```
@inproceedings{wen2023simgcd,
author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan},
title = {Parametric Classification for Generalized Category Discovery: A Baseline Study},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
year = {2023},
pages = {16590-16600}
}
```

## Acknowledgements

The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
18 changes: 18 additions & 0 deletions config.py
@@ -0,0 +1,18 @@
# -----------------
# DATASET ROOTS
# -----------------
cifar_10_root = '${DATASET_DIR}/cifar10'
cifar_100_root = '${DATASET_DIR}/cifar100'
cub_root = '${DATASET_DIR}/cub'
aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b'
car_root = '${DATASET_DIR}/cars'
herbarium_dataroot = '${DATASET_DIR}/herbarium_19'
imagenet_root = '${DATASET_DIR}/ImageNet'

# OSR Split dir
osr_split_dir = 'data/ssb_splits'

# -----------------
# OTHER PATHS
# -----------------
exp_root = 'dev_outputs' # All logs and checkpoints will be saved here
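Note that the `${DATASET_DIR}` fragments above are literal placeholder strings, not shell expansions; replace them with real absolute paths. Alternatively, as a minimal sketch (an assumption about usage, not repo code), the placeholder could be resolved from an environment variable at import time:

```python
import os

# Expands '${DATASET_DIR}' if an environment variable of that name is set;
# otherwise the string is left unchanged.
cifar_10_root = os.path.expandvars('${DATASET_DIR}/cifar10')
```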
38 changes: 38 additions & 0 deletions data/augmentations/__init__.py
@@ -0,0 +1,38 @@
import torch
from torchvision import transforms


def get_transform(transform_type='imagenet', image_size=32, args=None):

    if transform_type == 'imagenet':

        # ImageNet channel statistics
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        interpolation = args.interpolation
        crop_pct = args.crop_pct

        # Train-time: resize so the subsequent crop covers `crop_pct` of the
        # image, then random crop, horizontal flip, and colour jitter
        train_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

        # Test-time: deterministic resize and centre crop
        test_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

    else:

        raise NotImplementedError

    return (train_transform, test_transform)
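A hedged usage sketch for `get_transform`: it reads `interpolation` and `crop_pct` from `args`, so any object with those attributes works. The values below are common ImageNet defaults, not necessarily what the training scripts pass.

```python
from argparse import Namespace
from torchvision.transforms import InterpolationMode

# Hypothetical args object; BICUBIC and 0.875 are assumed values.
args = Namespace(interpolation=InterpolationMode.BICUBIC, crop_pct=0.875)
train_transform, test_transform = get_transform('imagenet', image_size=224, args=args)
```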
195 changes: 195 additions & 0 deletions data/cifar.py
@@ -0,0 +1,195 @@
from torchvision.datasets import CIFAR10, CIFAR100
from copy import deepcopy
import numpy as np

from data.data_utils import subsample_instances
from config import cifar_10_root, cifar_100_root


class CustomCIFAR10(CIFAR10):

def __init__(self, *args, **kwargs):

super(CustomCIFAR10, self).__init__(*args, **kwargs)

self.uq_idxs = np.array(range(len(self)))

def __getitem__(self, item):

img, label = super().__getitem__(item)
uq_idx = self.uq_idxs[item]

return img, label, uq_idx

def __len__(self):
return len(self.targets)


class CustomCIFAR100(CIFAR100):

def __init__(self, *args, **kwargs):
super(CustomCIFAR100, self).__init__(*args, **kwargs)

self.uq_idxs = np.array(range(len(self)))

def __getitem__(self, item):
img, label = super().__getitem__(item)
uq_idx = self.uq_idxs[item]

return img, label, uq_idx

def __len__(self):
return len(self.targets)


def subsample_dataset(dataset, idxs):

    # Allow for the setting in which an empty set of indices is passed

    if len(idxs) > 0:

        dataset.data = dataset.data[idxs]
        dataset.targets = np.array(dataset.targets)[idxs].tolist()
        dataset.uq_idxs = dataset.uq_idxs[idxs]

        return dataset

    else:

        return None


def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):

cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]

target_xform_dict = {}
for i, k in enumerate(include_classes):
target_xform_dict[k] = i

dataset = subsample_dataset(dataset, cls_idxs)

# dataset.target_transform = lambda x: target_xform_dict[x]

return dataset


def get_train_val_indices(train_dataset, val_split=0.2):

    train_classes = np.unique(train_dataset.targets)

    # Get train/val indices, sampling the validation split per class
    train_idxs = []
    val_idxs = []
    for cls in train_classes:

        cls_idxs = np.where(train_dataset.targets == cls)[0]

        v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
        t_ = [x for x in cls_idxs if x not in v_]

        train_idxs.extend(t_)
        val_idxs.extend(v_)

    return train_idxs, val_idxs


def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
prop_train_labels=0.8, split_train_val=False, seed=0):

np.random.seed(seed)

# Init entire training set
whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)

# Get labelled training set which has subsampled classes, then subsample some indices from that
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

# Split into training and validation sets
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
val_dataset_labelled_split.transform = test_transform

# Get unlabelled data
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

# Get test set for all classes
test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)

# Either split train into train and val or use test set as val
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

all_datasets = {
'train_labelled': train_dataset_labelled,
'train_unlabelled': train_dataset_unlabelled,
'val': val_dataset_labelled,
'test': test_dataset,
}

return all_datasets


def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
prop_train_labels=0.8, split_train_val=False, seed=0):

np.random.seed(seed)

# Init entire training set
whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True, download=True)

# Get labelled training set which has subsampled classes, then subsample some indices from that
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

# Split into training and validation sets
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
val_dataset_labelled_split.transform = test_transform

# Get unlabelled data
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

# Get test set for all classes
test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False, download=True)

# Either split train into train and val or use test set as val
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

all_datasets = {
'train_labelled': train_dataset_labelled,
'train_unlabelled': train_dataset_unlabelled,
'val': val_dataset_labelled,
'test': test_dataset,
}

return all_datasets


if __name__ == '__main__':

    x = get_cifar_100_datasets(None, None, split_train_val=False,
                               train_classes=range(80), prop_train_labels=0.5)

    print('Printing lens...')
    for k, v in x.items():
        if v is not None:
            print(f'{k}: {len(v)}')

    print('Printing labelled and unlabelled overlap...')
    print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
    print('Printing total instances in train...')
    print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))

    print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
    print(f'Len labelled set: {len(x["train_labelled"])}')
    print(f'Len unlabelled set: {len(x["train_unlabelled"])}')