Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
0 parents, commit 6d968a7. Showing 41 changed files with 4,339 additions and 0 deletions.
Binary file added (+158 Bytes, not shown): Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Xin Wen | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Parametric Classification for Generalized Category Discovery: A Baseline Study | ||
|
||
|
||
<p align="center"> | ||
<a href="https://openaccess.thecvf.com/content/ICCV2023/html/Wen_Parametric_Classification_for_Generalized_Category_Discovery_A_Baseline_Study_ICCV_2023_paper.html"><img src="https://img.shields.io/badge/-ICCV%202023-68488b"></a> | ||
<a href="https://arxiv.org/abs/2211.11727"><img src="https://img.shields.io/badge/arXiv-2211.11727-b31b1b"></a> | ||
<a href="https://wen-xin.info/simgcd"><img src="https://img.shields.io/badge/Project-Website-blue"></a> | ||
<a href="https://github.com/CVMI-Lab/SlotCon/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-MIT-blue.svg"></a> | ||
</p> | ||
<p align="center"> | ||
Parametric Classification for Generalized Category Discovery: A Baseline Study (ICCV 2023)<br> | ||
By | ||
<a href="https://wen-xin.info">Xin Wen</a>*, | ||
<a href="https://bzhao.me/">Bingchen Zhao</a>*, and | ||
<a href="https://xjqi.github.io/">Xiaojuan Qi</a>. | ||
</p> | ||
|
||
[Figure: teaser] | ||
|
||
Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples. | ||
Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means. | ||
|
||
However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and produces an imbalanced distribution across seen and novel categories. | ||
Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks and shows strong robustness to unknown class numbers. | ||
We hope the investigation and proposed simple framework can serve as a strong baseline to facilitate future studies in this field. | ||
|
||
## Running | ||
|
||
### Dependencies | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Config | ||
|
||
Set the paths to the datasets and the desired log directories in `config.py`. | ||
|
||
|
||
### Datasets | ||
|
||
We use fine-grained benchmarks in this paper, including: | ||
|
||
* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6) | ||
|
||
We also use generic object recognition datasets, including: | ||
|
||
* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php) | ||
|
||
|
||
### Scripts | ||
|
||
**Train the model**: | ||
|
||
``` | ||
bash scripts/run_${DATASET_NAME}.sh | ||
``` | ||
|
||
We found that selecting the model by 'Old' class performance can lead to over-fitting. Since 'New' class labels on the held-out validation set should be assumed unavailable, we suggest skipping model selection and simply using the last-epoch model. | ||
|
||
## Results | ||
Our results: | ||
|
||
<table><thead><tr><th>Source</th><th colspan="3">Paper (3 runs) </th><th colspan="3">Current Github (5 runs) </th></tr></thead><tbody><tr><td>Dataset</td><td>All</td><td>Old</td><td>New</td><td>All</td><td>Old</td><td>New</td></tr><tr><td>CIFAR10</td><td>97.1±0.0</td><td>95.1±0.1</td><td>98.1±0.1</td><td>97.0±0.1</td><td>93.9±0.1</td><td>98.5±0.1</td></tr><tr><td>CIFAR100</td><td>80.1±0.9</td><td>81.2±0.4</td><td>77.8±2.0</td><td>79.8±0.6</td><td>81.1±0.5</td><td>77.4±2.5</td></tr><tr><td>ImageNet-100</td><td>83.0±1.2</td><td>93.1±0.2</td><td>77.9±1.9</td><td>83.6±1.4</td><td>92.4±0.1</td><td>79.1±2.2</td></tr><tr><td>ImageNet-1K</td><td>57.1±0.1</td><td>77.3±0.1</td><td>46.9±0.2</td><td>57.0±0.4</td><td>77.1±0.1</td><td>46.9±0.5</td></tr><tr><td>CUB</td><td>60.3±0.1</td><td>65.6±0.9</td><td>57.7±0.4</td><td>61.5±0.5</td><td>65.7±0.5</td><td>59.4±0.8</td></tr><tr><td>Stanford Cars</td><td>53.8±2.2</td><td>71.9±1.7</td><td>45.0±2.4</td><td>53.4±1.6</td><td>71.5±1.6</td><td>44.6±1.7</td></tr><tr><td>FGVC-Aircraft</td><td>54.2±1.9</td><td>59.1±1.2</td><td>51.8±2.3</td><td>54.3±0.7</td><td>59.4±0.4</td><td>51.7±1.2</td></tr><tr><td>Herbarium 19</td><td>44.0±0.4</td><td>58.0±0.4</td><td>36.4±0.8</td><td>44.2±0.2</td><td>57.6±0.6</td><td>37.0±0.4</td></tr></tbody></table> | ||
|
||
## Citing this work | ||
|
||
If you find this repo useful for your research, please consider citing our paper: | ||
|
||
``` | ||
@inproceedings{wen2023simgcd, | ||
author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan}, | ||
title = {Parametric Classification for Generalized Category Discovery: A Baseline Study}, | ||
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, | ||
year = {2023}, | ||
pages = {16590-16600} | ||
} | ||
``` | ||
|
||
## Acknowledgements | ||
|
||
The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery. | ||
|
||
## License | ||
|
||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. |
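The README abstract above credits entropy regularisation and notes that the classifier tends to over-predict seen classes. The README does not give the formula, so the following is only an illustrative sketch under the assumption that the regulariser rewards a high entropy of the batch-averaged prediction; the function name `mean_prediction_entropy` is hypothetical and is not taken from the repository.

```python
import math


def mean_prediction_entropy(probs):
    """Entropy (in nats) of the class distribution averaged over a batch.

    `probs` is a list of per-sample probability vectors (hypothetical input).
    """
    num_classes = len(probs[0])
    # Average the predicted probability of each class over the batch.
    mean = [sum(p[k] for p in probs) / len(probs) for k in range(num_classes)]
    # Skip zero entries, since 0 * log(0) is taken to be 0.
    return -sum(m * math.log(m) for m in mean if m > 0)


# Every sample assigned to the same class: the averaged prediction is one-hot.
collapsed = [[1.0, 0.0], [1.0, 0.0]]
# Samples spread evenly over both classes: the averaged prediction is uniform.
balanced = [[1.0, 0.0], [0.0, 1.0]]

print(mean_prediction_entropy(collapsed))
print(mean_prediction_entropy(balanced))
```

In this toy case the collapsed batch scores lower than the balanced one, which is the behaviour a term of this kind would use to discourage predictions from piling onto a few classes.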
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# ----------------- | ||
# DATASET ROOTS | ||
# ----------------- | ||
cifar_10_root = '${DATASET_DIR}/cifar10' | ||
cifar_100_root = '${DATASET_DIR}/cifar100' | ||
cub_root = '${DATASET_DIR}/cub' | ||
aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b' | ||
car_root = '${DATASET_DIR}/cars' | ||
herbarium_dataroot = '${DATASET_DIR}/herbarium_19' | ||
imagenet_root = '${DATASET_DIR}/ImageNet' | ||
|
||
# OSR Split dir | ||
osr_split_dir = 'data/ssb_splits' | ||
|
||
# ----------------- | ||
# OTHER PATHS | ||
# ----------------- | ||
exp_root = 'dev_outputs' # All logs and checkpoints will be saved here |
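Python does not expand shell-style variables inside string literals, so the roots above are read as paths that start with a directory literally named `${DATASET_DIR}`. The binary file committed under `Contextuality-GCD-main/${DATASET_DIR}/cifar10/` in this same commit suggests that this is what happened here. A minimal sketch of one way to resolve the placeholder at import time follows; the helper `resolve_root` and the fallback `./datasets` are assumptions for illustration, not part of the repository.

```python
import os
from string import Template


def resolve_root(template, default_dir="./datasets"):
    """Substitute ${DATASET_DIR} in a path template.

    Uses the DATASET_DIR environment variable when it is set,
    otherwise falls back to `default_dir` (a hypothetical default).
    """
    dataset_dir = os.environ.get("DATASET_DIR", default_dir)
    return Template(template).substitute(DATASET_DIR=dataset_dir)


# Same templates as in the config, resolved to real paths.
cifar_10_root = resolve_root('${DATASET_DIR}/cifar10')
cifar_100_root = resolve_root('${DATASET_DIR}/cifar100')
print(cifar_10_root, cifar_100_root)
```

Running with `DATASET_DIR=/path/to/data` exported in the shell would then point every root at that directory without editing the file.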
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from torchvision import transforms | ||
|
||
import torch | ||
|
||
def get_transform(transform_type='imagenet', image_size=32, args=None): | ||
|
||
if transform_type == 'imagenet': | ||
|
||
mean = (0.485, 0.456, 0.406) | ||
std = (0.229, 0.224, 0.225) | ||
interpolation = args.interpolation | ||
crop_pct = args.crop_pct | ||
|
||
train_transform = transforms.Compose([ | ||
transforms.Resize(int(image_size / crop_pct), interpolation), | ||
transforms.RandomCrop(image_size), | ||
transforms.RandomHorizontalFlip(p=0.5), | ||
transforms.ColorJitter(), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=torch.tensor(mean), | ||
std=torch.tensor(std)) | ||
]) | ||
|
||
test_transform = transforms.Compose([ | ||
transforms.Resize(int(image_size / crop_pct), interpolation), | ||
transforms.CenterCrop(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=torch.tensor(mean), | ||
std=torch.tensor(std)) | ||
]) | ||
|
||
else: | ||
|
||
raise NotImplementedError | ||
|
||
return (train_transform, test_transform) |
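Both pipelines above first resize the image to `int(image_size / crop_pct)` and then crop back down to `image_size`, so `crop_pct` controls how much border is discarded. The arithmetic can be checked in isolation; the value `crop_pct=0.875` below is an assumed example, since the diff does not show what `args.crop_pct` is set to.

```python
def resize_target(image_size, crop_pct):
    """Size passed to Resize before the crop, as computed in get_transform."""
    return int(image_size / crop_pct)


# crop_pct=0.875 is an assumed example value, not taken from the repository.
for image_size in (32, 224):
    target = resize_target(image_size, crop_pct=0.875)
    print(f"image_size={image_size}: resize to {target}, then crop to {image_size}")
```

With this assumed value, a 224-pixel input is resized to 256 and a 32-pixel input to 36 before cropping. Note also that the default `args=None` cannot be used as is, because the function reads `args.interpolation` and `args.crop_pct` unconditionally.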
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from torchvision.datasets import CIFAR10, CIFAR100 | ||
from copy import deepcopy | ||
import numpy as np | ||
|
||
from data.data_utils import subsample_instances | ||
from config import cifar_10_root, cifar_100_root | ||
|
||
|
||
class CustomCIFAR10(CIFAR10): | ||
|
||
def __init__(self, *args, **kwargs): | ||
|
||
super(CustomCIFAR10, self).__init__(*args, **kwargs) | ||
|
||
self.uq_idxs = np.array(range(len(self))) | ||
|
||
def __getitem__(self, item): | ||
|
||
img, label = super().__getitem__(item) | ||
uq_idx = self.uq_idxs[item] | ||
|
||
return img, label, uq_idx | ||
|
||
def __len__(self): | ||
return len(self.targets) | ||
|
||
|
||
class CustomCIFAR100(CIFAR100): | ||
|
||
def __init__(self, *args, **kwargs): | ||
super(CustomCIFAR100, self).__init__(*args, **kwargs) | ||
|
||
self.uq_idxs = np.array(range(len(self))) | ||
|
||
def __getitem__(self, item): | ||
img, label = super().__getitem__(item) | ||
uq_idx = self.uq_idxs[item] | ||
|
||
return img, label, uq_idx | ||
|
||
def __len__(self): | ||
return len(self.targets) | ||
|
||
|
||
def subsample_dataset(dataset, idxs): | ||
|
||
# Allow for the setting in which an empty set of indices is passed | ||
|
||
if len(idxs) > 0: | ||
|
||
dataset.data = dataset.data[idxs] | ||
dataset.targets = np.array(dataset.targets)[idxs].tolist() | ||
dataset.uq_idxs = dataset.uq_idxs[idxs] | ||
|
||
return dataset | ||
|
||
else: | ||
|
||
return None | ||
|
||
|
||
def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): | ||
|
||
cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] | ||
|
||
target_xform_dict = {} | ||
for i, k in enumerate(include_classes): | ||
target_xform_dict[k] = i | ||
|
||
dataset = subsample_dataset(dataset, cls_idxs) | ||
|
||
# dataset.target_transform = lambda x: target_xform_dict[x] | ||
|
||
return dataset | ||
|
||
|
||
def get_train_val_indices(train_dataset, val_split=0.2): | ||
|
||
train_classes = np.unique(train_dataset.targets) | ||
|
||
# Get train/val indices | ||
train_idxs = [] | ||
val_idxs = [] | ||
for cls in train_classes: | ||
|
||
cls_idxs = np.where(train_dataset.targets == cls)[0] | ||
|
||
v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) | ||
t_ = [x for x in cls_idxs if x not in v_] | ||
|
||
train_idxs.extend(t_) | ||
val_idxs.extend(v_) | ||
|
||
return train_idxs, val_idxs | ||
|
||
|
||
def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), | ||
prop_train_labels=0.8, split_train_val=False, seed=0): | ||
|
||
np.random.seed(seed) | ||
|
||
# Init entire training set | ||
whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) | ||
|
||
# Get labelled training set which has subsampled classes, then subsample some indices from that | ||
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) | ||
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) | ||
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) | ||
|
||
# Split into training and validation sets | ||
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) | ||
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) | ||
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) | ||
val_dataset_labelled_split.transform = test_transform | ||
|
||
# Get unlabelled data | ||
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) | ||
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) | ||
|
||
# Get test set for all classes | ||
test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) | ||
|
||
# Either split train into train and val, or keep the full labelled set and leave val as None | ||
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled | ||
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None | ||
|
||
all_datasets = { | ||
'train_labelled': train_dataset_labelled, | ||
'train_unlabelled': train_dataset_unlabelled, | ||
'val': val_dataset_labelled, | ||
'test': test_dataset, | ||
} | ||
|
||
return all_datasets | ||
|
||
|
||
def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), | ||
prop_train_labels=0.8, split_train_val=False, seed=0): | ||
|
||
np.random.seed(seed) | ||
|
||
# Init entire training set | ||
whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True, download=True) | ||
|
||
# Get labelled training set which has subsampled classes, then subsample some indices from that | ||
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) | ||
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) | ||
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) | ||
|
||
# Split into training and validation sets | ||
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) | ||
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) | ||
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) | ||
val_dataset_labelled_split.transform = test_transform | ||
|
||
# Get unlabelled data | ||
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) | ||
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) | ||
|
||
# Get test set for all classes | ||
test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False, download=True) | ||
|
||
# Either split train into train and val, or keep the full labelled set and leave val as None | ||
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled | ||
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None | ||
|
||
all_datasets = { | ||
'train_labelled': train_dataset_labelled, | ||
'train_unlabelled': train_dataset_unlabelled, | ||
'val': val_dataset_labelled, | ||
'test': test_dataset, | ||
} | ||
|
||
return all_datasets | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
x = get_cifar_100_datasets(None, None, split_train_val=False, | ||
train_classes=range(80), prop_train_labels=0.5) | ||
|
||
print('Printing lens...') | ||
for k, v in x.items(): | ||
if v is not None: | ||
print(f'{k}: {len(v)}') | ||
|
||
print('Printing labelled and unlabelled overlap...') | ||
print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) | ||
print('Printing total instances in train...') | ||
print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) | ||
|
||
print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') | ||
print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') | ||
print(f'Len labelled set: {len(x["train_labelled"])}') | ||
print(f'Len unlabelled set: {len(x["train_unlabelled"])}') |
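The split built by `get_cifar_10_datasets` and `get_cifar_100_datasets` can be summarised on plain index lists: keep only samples of the labelled classes, label a fraction of them, and treat every remaining training sample (including the unselected ones from labelled classes) as unlabelled. The sketch below uses only the standard library. `subsample_instances` is not shown in this diff, so its behaviour is assumed here to be a random draw without replacement, and the function `split_labelled_unlabelled` is a hypothetical stand-in rather than repository code.

```python
import random


def split_labelled_unlabelled(targets, labelled_classes, prop_train_labels=0.8, seed=0):
    """Return (labelled, unlabelled) index lists for a list of class targets."""
    rng = random.Random(seed)
    # Step 1: candidates are the samples whose class is in the labelled set.
    candidates = [i for i, t in enumerate(targets) if t in labelled_classes]
    # Step 2: label a fixed proportion of them (assumed: uniform, no replacement).
    n_labelled = int(prop_train_labels * len(candidates))
    labelled = sorted(rng.sample(candidates, n_labelled))
    # Step 3: everything else in the training set is unlabelled.
    labelled_set = set(labelled)
    unlabelled = [i for i in range(len(targets)) if i not in labelled_set]
    return labelled, unlabelled


# Toy dataset: 100 samples, 10 classes, 10 samples per class.
targets = [i % 10 for i in range(100)]
labelled, unlabelled = split_labelled_unlabelled(
    targets, labelled_classes=(0, 1, 8, 9), prop_train_labels=0.5)

print(len(labelled), len(unlabelled))
print(set(labelled) & set(unlabelled))
```

The two printed checks mirror the sanity checks in the `__main__` block above: the set sizes, and an empty overlap between labelled and unlabelled indices.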