diff --git a/Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta b/Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta
new file mode 100644
index 0000000..4467a6e
Binary files /dev/null and b/Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta differ
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..a9ec139
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Xin Wen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3b422db
--- /dev/null
+++ b/README.md
@@ -0,0 +1,86 @@
+# Parametric Classification for Generalized Category Discovery: A Baseline Study
+
+By Xin Wen\*, Bingchen Zhao\*, and Xiaojuan Qi (ICCV 2023).
+
+Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples.
+Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means.
+
+However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and produces an imbalanced distribution across seen and novel categories.
+Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks, and shows strong robustness to unknown class numbers.
+We hope this investigation and the proposed simple framework can serve as a strong baseline to facilitate future studies in this field.
+
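+For intuition, here is a minimal sketch of the mean-entropy regulariser described above (an illustrative assumption, not necessarily the exact loss used in this repo; see the training code for the actual implementation):
+
+```
+import torch
+import torch.nn.functional as F
+
+def entropy_regulariser(logits: torch.Tensor) -> torch.Tensor:
+    # Per-sample class distributions, then the batch-mean prediction.
+    probs = F.softmax(logits, dim=1)
+    mean_probs = probs.mean(dim=0)
+    # Entropy of the mean prediction; adding the negated entropy to the training
+    # loss (i.e. maximising it) discourages collapse onto a few (seen) classes.
+    entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum()
+    return -entropy
+```
+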
+## Running
+
+### Dependencies
+
+```
+pip install -r requirements.txt
+```
+
+### Config
+
+Set paths to datasets and desired log directories in ```config.py```.
+
+
+### Datasets
+
+We use fine-grained benchmarks in this paper, including:
+
+* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)
+
+We also use generic object recognition datasets, including:
+
+* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php)
+
+
+### Scripts
+
+**Train the model**:
+
+```
+bash scripts/run_${DATASET_NAME}.sh
+```
+
+We found that selecting the model by 'Old'-class performance can lead to over-fitting, and since 'New'-class labels on a held-out validation set should be assumed unavailable, we suggest skipping model selection and simply using the last-epoch model.
+
+## Results
+Our results:
+
+Paper numbers are averaged over 3 runs; current GitHub numbers over 5 runs.
+
+Dataset | All (Paper) | Old (Paper) | New (Paper) | All (GitHub) | Old (GitHub) | New (GitHub)
+--- | --- | --- | --- | --- | --- | ---
+CIFAR10 | 97.1±0.0 | 95.1±0.1 | 98.1±0.1 | 97.0±0.1 | 93.9±0.1 | 98.5±0.1
+CIFAR100 | 80.1±0.9 | 81.2±0.4 | 77.8±2.0 | 79.8±0.6 | 81.1±0.5 | 77.4±2.5
+ImageNet-100 | 83.0±1.2 | 93.1±0.2 | 77.9±1.9 | 83.6±1.4 | 92.4±0.1 | 79.1±2.2
+ImageNet-1K | 57.1±0.1 | 77.3±0.1 | 46.9±0.2 | 57.0±0.4 | 77.1±0.1 | 46.9±0.5
+CUB | 60.3±0.1 | 65.6±0.9 | 57.7±0.4 | 61.5±0.5 | 65.7±0.5 | 59.4±0.8
+Stanford Cars | 53.8±2.2 | 71.9±1.7 | 45.0±2.4 | 53.4±1.6 | 71.5±1.6 | 44.6±1.7
+FGVC-Aircraft | 54.2±1.9 | 59.1±1.2 | 51.8±2.3 | 54.3±0.7 | 59.4±0.4 | 51.7±1.2
+Herbarium 19 | 44.0±0.4 | 58.0±0.4 | 36.4±0.8 | 44.2±0.2 | 57.6±0.6 | 37.0±0.4
+
+## Citing this work
+
+If you find this repo useful for your research, please consider citing our paper:
+
+```
+@inproceedings{wen2023simgcd,
+ author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan},
+ title = {Parametric Classification for Generalized Category Discovery: A Baseline Study},
+ booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+ year = {2023},
+ pages = {16590-16600}
+}
+```
+
+## Acknowledgements
+
+The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery.
+
+## License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..2a9e34a
--- /dev/null
+++ b/config.py
@@ -0,0 +1,18 @@
+# -----------------
+# DATASET ROOTS
+# -----------------
+cifar_10_root = '${DATASET_DIR}/cifar10'
+cifar_100_root = '${DATASET_DIR}/cifar100'
+cub_root = '${DATASET_DIR}/cub'
+aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b'
+car_root = '${DATASET_DIR}/cars'
+herbarium_dataroot = '${DATASET_DIR}/herbarium_19'
+imagenet_root = '${DATASET_DIR}/ImageNet'
+
+# OSR Split dir
+osr_split_dir = 'data/ssb_splits'
+
+# -----------------
+# OTHER PATHS
+# -----------------
+exp_root = 'dev_outputs' # All logs and checkpoints will be saved here
\ No newline at end of file
diff --git a/data/augmentations/__init__.py b/data/augmentations/__init__.py
new file mode 100644
index 0000000..51e934a
--- /dev/null
+++ b/data/augmentations/__init__.py
@@ -0,0 +1,38 @@
+from torchvision import transforms
+
+import torch
+
+def get_transform(transform_type='imagenet', image_size=32, args=None):
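+    """Return a (train_transform, test_transform) pair.
+
+    For the 'imagenet' type, images are resized to image_size / crop_pct and
+    then cropped to image_size (random crop + flip + jitter for training,
+    center crop for evaluation). `args` must provide `interpolation` and
+    `crop_pct`.
+    """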
+
+ if transform_type == 'imagenet':
+
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ interpolation = args.interpolation
+ crop_pct = args.crop_pct
+
+ train_transform = transforms.Compose([
+ transforms.Resize(int(image_size / crop_pct), interpolation),
+ transforms.RandomCrop(image_size),
+ transforms.RandomHorizontalFlip(p=0.5),
+ transforms.ColorJitter(),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std))
+ ])
+
+ test_transform = transforms.Compose([
+ transforms.Resize(int(image_size / crop_pct), interpolation),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std))
+ ])
+
+ else:
+
+ raise NotImplementedError
+
+ return (train_transform, test_transform)
\ No newline at end of file
diff --git a/data/cifar.py b/data/cifar.py
new file mode 100644
index 0000000..499f3ea
--- /dev/null
+++ b/data/cifar.py
@@ -0,0 +1,195 @@
+from torchvision.datasets import CIFAR10, CIFAR100
+from copy import deepcopy
+import numpy as np
+
+from data.data_utils import subsample_instances
+from config import cifar_10_root, cifar_100_root
+
+
+class CustomCIFAR10(CIFAR10):
+
+ def __init__(self, *args, **kwargs):
+
+ super(CustomCIFAR10, self).__init__(*args, **kwargs)
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def __getitem__(self, item):
+
+ img, label = super().__getitem__(item)
+ uq_idx = self.uq_idxs[item]
+
+ return img, label, uq_idx
+
+ def __len__(self):
+ return len(self.targets)
+
+
+class CustomCIFAR100(CIFAR100):
+
+ def __init__(self, *args, **kwargs):
+ super(CustomCIFAR100, self).__init__(*args, **kwargs)
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def __getitem__(self, item):
+ img, label = super().__getitem__(item)
+ uq_idx = self.uq_idxs[item]
+
+ return img, label, uq_idx
+
+ def __len__(self):
+ return len(self.targets)
+
+
+def subsample_dataset(dataset, idxs):
+
+    # Allow for a setting in which an empty set of indices is passed
+
+ if len(idxs) > 0:
+
+ dataset.data = dataset.data[idxs]
+ dataset.targets = np.array(dataset.targets)[idxs].tolist()
+ dataset.uq_idxs = dataset.uq_idxs[idxs]
+
+ return dataset
+
+ else:
+
+ return None
+
+
+def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
+
+ cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
+
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+
+ # dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+
+def get_train_val_indices(train_dataset, val_split=0.2):
+
+ train_classes = np.unique(train_dataset.targets)
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+
+ cls_idxs = np.where(train_dataset.targets == cls)[0]
+
+ v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
+ prop_train_labels=0.8, split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+
+def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
+ prop_train_labels=0.8, split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True, download=True)
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False, download=True)
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+
+if __name__ == '__main__':
+
+ x = get_cifar_100_datasets(None, None, split_train_val=False,
+ train_classes=range(80), prop_train_labels=0.5)
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+
+ print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
+ print(f'Len labelled set: {len(x["train_labelled"])}')
+ print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/data/cub.py b/data/cub.py
new file mode 100644
index 0000000..1198bbf
--- /dev/null
+++ b/data/cub.py
@@ -0,0 +1,203 @@
+import os
+import pandas as pd
+import numpy as np
+from copy import deepcopy
+
+from torchvision.datasets.folder import default_loader
+from torchvision.datasets.utils import download_url
+from torch.utils.data import Dataset
+
+from data.data_utils import subsample_instances
+from config import cub_root
+
+
+class CustomCub2011(Dataset):
+ base_folder = 'CUB_200_2011/images'
+ url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
+ filename = 'CUB_200_2011.tgz'
+ tgz_md5 = '97eceeb196236b17998738112f37df78'
+
+ def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True):
+
+ self.root = os.path.expanduser(root)
+ self.transform = transform
+ self.target_transform = target_transform
+
+ self.loader = loader
+ self.train = train
+
+
+ if download:
+ self._download()
+
+ if not self._check_integrity():
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def _load_metadata(self):
+ images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
+ names=['img_id', 'filepath'])
+ image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
+ sep=' ', names=['img_id', 'target'])
+ train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
+ sep=' ', names=['img_id', 'is_training_img'])
+
+ data = images.merge(image_class_labels, on='img_id')
+ self.data = data.merge(train_test_split, on='img_id')
+
+ if self.train:
+ self.data = self.data[self.data.is_training_img == 1]
+ else:
+ self.data = self.data[self.data.is_training_img == 0]
+
+ def _check_integrity(self):
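+        # Note: this also populates self.data via _load_metadata(), a side
+        # effect that __init__ relies on after downloading.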
+ try:
+ self._load_metadata()
+ except Exception:
+ return False
+
+ for index, row in self.data.iterrows():
+ filepath = os.path.join(self.root, self.base_folder, row.filepath)
+ if not os.path.isfile(filepath):
+ print(filepath)
+ return False
+ return True
+
+ def _download(self):
+ import tarfile
+
+ if self._check_integrity():
+ print('Files already downloaded and verified')
+ return
+
+ download_url(self.url, self.root, self.filename, self.tgz_md5)
+
+ with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+ tar.extractall(path=self.root)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ sample = self.data.iloc[idx]
+ path = os.path.join(self.root, self.base_folder, sample.filepath)
+ target = sample.target - 1 # Targets start at 1 by default, so shift to 0
+ img = self.loader(path)
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target, self.uq_idxs[idx]
+
+
+def subsample_dataset(dataset, idxs):
+
+ mask = np.zeros(len(dataset)).astype('bool')
+ mask[idxs] = True
+
+ dataset.data = dataset.data[mask]
+ dataset.uq_idxs = dataset.uq_idxs[mask]
+
+ return dataset
+
+
+def subsample_classes(dataset, include_classes=range(160)):
+
+ include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199
+ cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub]
+
+ # TODO: For now have no target transform
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+
+ dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+
+def get_train_val_indices(train_dataset, val_split=0.2):
+
+ train_classes = np.unique(train_dataset.data['target'])
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+
+ cls_idxs = np.where(train_dataset.data['target'] == cls)[0]
+
+ v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
+ split_train_val=False, seed=0, download=False):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download)
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False)
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+if __name__ == '__main__':
+
+ x = get_cub_datasets(None, None, split_train_val=False,
+ train_classes=range(100), prop_train_labels=0.5)
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+
+ print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}')
+ print(f'Len labelled set: {len(x["train_labelled"])}')
+ print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/data/data_utils.py b/data/data_utils.py
new file mode 100644
index 0000000..d05d7a1
--- /dev/null
+++ b/data/data_utils.py
@@ -0,0 +1,40 @@
+import numpy as np
+from torch.utils.data import Dataset
+
+def subsample_instances(dataset, prop_indices_to_subsample=0.8):
+
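+    # NOTE: the seed is fixed to 0 here, so the labelled/unlabelled instance
+    # split is deterministic and identical across runs, independent of the
+    # caller's seed.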
+ np.random.seed(0)
+ subsample_indices = np.random.choice(range(len(dataset)), replace=False,
+ size=(int(prop_indices_to_subsample * len(dataset)),))
+
+ return subsample_indices
+
+class MergedDataset(Dataset):
+
+ """
+    Takes two datasets (labelled_dataset, unlabelled_dataset) and concatenates them,
+    allowing iteration over both in a single loader; each item also returns a flag
+    indicating whether it came from the labelled set
+ """
+
+ def __init__(self, labelled_dataset, unlabelled_dataset):
+
+ self.labelled_dataset = labelled_dataset
+ self.unlabelled_dataset = unlabelled_dataset
+ self.target_transform = None
+
+ def __getitem__(self, item):
+
+ if item < len(self.labelled_dataset):
+ img, label, uq_idx = self.labelled_dataset[item]
+ labeled_or_not = 1
+
+ else:
+
+ img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
+ labeled_or_not = 0
+
+
+ return img, label, uq_idx, np.array([labeled_or_not])
+
+ def __len__(self):
+ return len(self.unlabelled_dataset) + len(self.labelled_dataset)
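+
+
+# Illustrative usage (assumed variable names):
+#   merged = MergedDataset(datasets['train_labelled'], datasets['train_unlabelled'])
+#   img, label, uq_idx, labeled_or_not = merged[0]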
diff --git a/data/fgvc_aircraft.py b/data/fgvc_aircraft.py
new file mode 100644
index 0000000..97dc745
--- /dev/null
+++ b/data/fgvc_aircraft.py
@@ -0,0 +1,270 @@
+import os
+import numpy as np
+from copy import deepcopy
+
+from torchvision.datasets.folder import default_loader
+from torch.utils.data import Dataset
+
+from data.data_utils import subsample_instances
+from config import aircraft_root
+
+def make_dataset(dir, image_ids, targets):
+ assert(len(image_ids) == len(targets))
+ images = []
+ dir = os.path.expanduser(dir)
+ for i in range(len(image_ids)):
+ item = (os.path.join(dir, 'data', 'images',
+ '%s.jpg' % image_ids[i]), targets[i])
+ images.append(item)
+ return images
+
+
+def find_classes(classes_file):
+
+ # read classes file, separating out image IDs and class names
+ image_ids = []
+ targets = []
+    with open(classes_file, 'r') as f:
+        for line in f:
+            split_line = line.split(' ')
+            image_ids.append(split_line[0])
+            targets.append(' '.join(split_line[1:]))
+
+ # index class names
+ classes = np.unique(targets)
+ class_to_idx = {classes[i]: i for i in range(len(classes))}
+ targets = [class_to_idx[c] for c in targets]
+
+ return (image_ids, targets, classes, class_to_idx)
+
+
+class FGVCAircraft(Dataset):
+
+    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
+
+ Args:
+ root (string): Root directory path to dataset.
+ class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
+ to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
+ transform (callable, optional): A function/transform that takes in a PIL image
+ and returns a transformed version. E.g. ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in the root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ """
+ url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
+ class_types = ('variant', 'family', 'manufacturer')
+ splits = ('train', 'val', 'trainval', 'test')
+
+ def __init__(self, root, class_type='variant', split='train', transform=None,
+ target_transform=None, loader=default_loader, download=False):
+ if split not in self.splits:
+ raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
+ split, ', '.join(self.splits),
+ ))
+ if class_type not in self.class_types:
+ raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
+ class_type, ', '.join(self.class_types),
+ ))
+ self.root = os.path.expanduser(root)
+ self.class_type = class_type
+ self.split = split
+ self.classes_file = os.path.join(self.root, 'data',
+ 'images_%s_%s.txt' % (self.class_type, self.split))
+
+ if download:
+ self.download()
+
+ (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
+ samples = make_dataset(self.root, image_ids, targets)
+
+ self.transform = transform
+ self.target_transform = target_transform
+ self.loader = loader
+
+ self.samples = samples
+ self.classes = classes
+ self.class_to_idx = class_to_idx
+ self.train = True if split == 'train' else False
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (sample, target) where target is class_index of the target class.
+ """
+
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return sample, target, self.uq_idxs[index]
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __repr__(self):
+ fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+ fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
+ fmt_str += ' Root Location: {}\n'.format(self.root)
+ tmp = ' Transforms (if any): '
+ fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ tmp = ' Target Transforms (if any): '
+ fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+ return fmt_str
+
+ def _check_exists(self):
+ return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
+ os.path.exists(self.classes_file)
+
+ def download(self):
+ """Download the FGVC-Aircraft data if it doesn't exist already."""
+ from six.moves import urllib
+ import tarfile
+
+ if self._check_exists():
+ return
+
+        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013b.tar.gz
+ print('Downloading %s ... (may take a few minutes)' % self.url)
+ parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
+ tar_name = self.url.rpartition('/')[-1]
+ tar_path = os.path.join(parent_dir, tar_name)
+ data = urllib.request.urlopen(self.url)
+
+ # download .tar.gz file
+ with open(tar_path, 'wb') as f:
+ f.write(data.read())
+
+ # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
+        # remove the '.tar.gz' suffix (str.strip removes a character set, not a suffix)
+        data_folder = tar_path[:-len('.tar.gz')]
+ print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
+        with tarfile.open(tar_path) as tar:
+            tar.extractall(parent_dir)
+
+ # if necessary, rename data folder to self.root
+ if not os.path.samefile(data_folder, self.root):
+ print('Renaming %s to %s ...' % (data_folder, self.root))
+ os.rename(data_folder, self.root)
+
+ # delete .tar.gz file
+ print('Deleting %s ...' % tar_path)
+ os.remove(tar_path)
+
+ print('Done!')
+
+
+def subsample_dataset(dataset, idxs):
+
+ mask = np.zeros(len(dataset)).astype('bool')
+ mask[idxs] = True
+
+    dataset.samples = [s for s, m in zip(dataset.samples, mask) if m]
+ dataset.uq_idxs = dataset.uq_idxs[mask]
+
+ return dataset
+
+
+def subsample_classes(dataset, include_classes=range(60)):
+
+ cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
+
+ # TODO: Don't transform targets for now
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+
+ dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+
+def get_train_val_indices(train_dataset, val_split=0.2):
+
+    all_targets = np.array([t for (p, t) in train_dataset.samples])
+ train_classes = np.unique(all_targets)
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+ cls_idxs = np.where(all_targets == cls)[0]
+
+ v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8,
+ split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval')
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+if __name__ == '__main__':
+
+ x = get_aircraft_datasets(None, None, split_train_val=False)
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+ print('Printing number of labelled classes...')
+ print(len(set([i[1] for i in x['train_labelled'].samples])))
+ print('Printing total number of classes...')
+ print(len(set([i[1] for i in x['train_unlabelled'].samples])))
diff --git a/data/get_datasets.py b/data/get_datasets.py
new file mode 100644
index 0000000..dfe12eb
--- /dev/null
+++ b/data/get_datasets.py
@@ -0,0 +1,176 @@
+from data.data_utils import MergedDataset
+
+from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets
+from data.herbarium_19 import get_herbarium_datasets
+from data.stanford_cars import get_scars_datasets
+from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets
+from data.cub import get_cub_datasets
+from data.fgvc_aircraft import get_aircraft_datasets
+
+from copy import deepcopy
+import pickle
+import os
+
+from config import osr_split_dir
+
+
+get_dataset_funcs = {
+ 'cifar10': get_cifar_10_datasets,
+ 'cifar100': get_cifar_100_datasets,
+ 'imagenet_100': get_imagenet_100_datasets,
+ 'imagenet_1k': get_imagenet_1k_datasets,
+ 'herbarium_19': get_herbarium_datasets,
+ 'cub': get_cub_datasets,
+ 'aircraft': get_aircraft_datasets,
+ 'scars': get_scars_datasets
+}
+
+
+def get_datasets(dataset_name, train_transform, test_transform, args):
+
+ """
+ :return: train_dataset: MergedDataset which concatenates labelled and unlabelled
+ test_dataset,
+ unlabelled_train_examples_test,
+ datasets
+ """
+
+    if dataset_name not in get_dataset_funcs:
+        raise ValueError(f'Unknown dataset: {dataset_name}')
+
+ # Get datasets
+ get_dataset_f = get_dataset_funcs[dataset_name]
+ datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
+ train_classes=args.train_classes,
+ prop_train_labels=args.prop_train_labels,
+ split_train_val=False)
+ # Set target transforms:
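+    # Labelled classes are remapped to [0, num_labelled) and unlabelled classes to
+    # [num_labelled, num_total), so 'Old' vs 'New' can be distinguished by index.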
+ target_transform_dict = {}
+ for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
+ target_transform_dict[cls] = i
+ target_transform = lambda x: target_transform_dict[x]
+
+    for ds_name, dataset in datasets.items():  # don't shadow the `dataset_name` argument
+ if dataset is not None:
+ dataset.target_transform = target_transform
+
+ # Train split (labelled and unlabelled classes) for training
+ train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
+ unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
+
+ test_dataset = datasets['test']
+ unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
+ unlabelled_train_examples_test.transform = test_transform
+
+ return train_dataset, test_dataset, unlabelled_train_examples_test, datasets
+
+
+def get_class_splits(args):
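+    """Attach image_size, train_classes, and unlabeled_classes to `args`
+    for the requested dataset, and return the mutated args."""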
+
+ # For FGVC datasets, optionally return bespoke splits
+    if args.dataset_name in ('scars', 'cub', 'aircraft'):
+        use_ssb_splits = getattr(args, 'use_ssb_splits', False)
+
+ # -------------
+ # GET CLASS SPLITS
+ # -------------
+ if args.dataset_name == 'cifar10':
+
+ args.image_size = 32
+ args.train_classes = range(5)
+ args.unlabeled_classes = range(5, 10)
+
+ elif args.dataset_name == 'cifar100':
+
+ args.image_size = 32
+ args.train_classes = range(80)
+ args.unlabeled_classes = range(80, 100)
+
+ elif args.dataset_name == 'herbarium_19':
+
+ args.image_size = 224
+ herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl')
+
+ with open(herb_path_splits, 'rb') as handle:
+ class_splits = pickle.load(handle)
+
+ args.train_classes = class_splits['Old']
+ args.unlabeled_classes = class_splits['New']
+
+ elif args.dataset_name == 'imagenet_100':
+
+ args.image_size = 224
+ args.train_classes = range(50)
+ args.unlabeled_classes = range(50, 100)
+
+ elif args.dataset_name == 'imagenet_1k':
+
+ args.image_size = 224
+ args.train_classes = range(500)
+ args.unlabeled_classes = range(500, 1000)
+
+ elif args.dataset_name == 'scars':
+
+ args.image_size = 224
+
+ if use_ssb_splits:
+
+ split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl')
+ with open(split_path, 'rb') as handle:
+ class_info = pickle.load(handle)
+
+ args.train_classes = class_info['known_classes']
+ open_set_classes = class_info['unknown_classes']
+ args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
+
+ else:
+
+ args.train_classes = range(98)
+ args.unlabeled_classes = range(98, 196)
+
+ elif args.dataset_name == 'aircraft':
+
+ args.image_size = 224
+ if use_ssb_splits:
+
+ split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl')
+ with open(split_path, 'rb') as handle:
+ class_info = pickle.load(handle)
+
+ args.train_classes = class_info['known_classes']
+ open_set_classes = class_info['unknown_classes']
+ args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
+
+ else:
+
+ args.train_classes = range(50)
+ args.unlabeled_classes = range(50, 100)
+
+ elif args.dataset_name == 'cub':
+
+ args.image_size = 224
+
+ if use_ssb_splits:
+
+ split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl')
+ with open(split_path, 'rb') as handle:
+ class_info = pickle.load(handle)
+
+ args.train_classes = class_info['known_classes']
+ open_set_classes = class_info['unknown_classes']
+ args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
+
+ else:
+
+ args.train_classes = range(100)
+ args.unlabeled_classes = range(100, 200)
+
+ else:
+
+ raise NotImplementedError
+
+ return args
diff --git a/data/herbarium_19.py b/data/herbarium_19.py
new file mode 100644
index 0000000..1d87f85
--- /dev/null
+++ b/data/herbarium_19.py
@@ -0,0 +1,164 @@
+import os
+
+import torchvision
+import numpy as np
+from copy import deepcopy
+
+from data.data_utils import subsample_instances
+from config import herbarium_dataroot
+
+class HerbariumDataset19(torchvision.datasets.ImageFolder):
+
+ def __init__(self, *args, **kwargs):
+
+ # Process metadata json for training images into a DataFrame
+ super().__init__(*args, **kwargs)
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def __getitem__(self, idx):
+
+ img, label = super().__getitem__(idx)
+ uq_idx = self.uq_idxs[idx]
+
+ return img, label, uq_idx
+
+
+def subsample_dataset(dataset, idxs):
+
+ mask = np.zeros(len(dataset)).astype('bool')
+ mask[idxs] = True
+
+ dataset.samples = np.array(dataset.samples)[mask].tolist()
+ dataset.targets = np.array(dataset.targets)[mask].tolist()
+
+ dataset.uq_idxs = dataset.uq_idxs[mask]
+
+ dataset.samples = [[x[0], int(x[1])] for x in dataset.samples]
+ dataset.targets = [int(x) for x in dataset.targets]
+
+ return dataset
+
+
+def subsample_classes(dataset, include_classes=range(250)):
+
+ cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes]
+
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+
+ dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+
+def get_train_val_indices(train_dataset, val_instances_per_class=5):
+
+ train_classes = list(set(train_dataset.targets))
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+
+ cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
+
+ # Have a balanced test set
+ v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8,
+ seed=0, split_train_val=False):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ train_dataset = HerbariumDataset19(transform=train_transform,
+ root=os.path.join(herbarium_dataroot, 'small-train'))
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+    # TODO: the unlabelled set is subsampled uniformly at random from the training
+    # data, so it will contain many instances of the dominant classes
+ train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ if split_train_val:
+
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled,
+ val_instances_per_class=5)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ else:
+
+ train_dataset_labelled_split, val_dataset_labelled_split = None, None
+
+ # Get unlabelled data
+ unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices)))
+
+ # Get test dataset
+ test_dataset = HerbariumDataset19(transform=test_transform,
+ root=os.path.join(herbarium_dataroot, 'small-validation'))
+
+ # Transform dict
+ unlabelled_classes = list(set(train_dataset.targets) - set(train_classes))
+ target_xform_dict = {}
+ for i, k in enumerate(list(train_classes) + unlabelled_classes):
+ target_xform_dict[k] = i
+
+ test_dataset.target_transform = lambda x: target_xform_dict[x]
+ train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x]
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+if __name__ == '__main__':
+
+ np.random.seed(0)
+    train_classes = np.random.choice(range(683), size=683 // 2, replace=False)
+
+ x = get_herbarium_datasets(None, None, train_classes=train_classes,
+ prop_train_labels=0.5)
+
+ assert set(x['train_unlabelled'].targets) == set(range(683))
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+ print('Printing number of labelled classes...')
+ print(len(set(x['train_labelled'].targets)))
+ print('Printing total number of classes...')
+ print(len(set(x['train_unlabelled'].targets)))
+
+ print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
+ print(f'Len labelled set: {len(x["train_labelled"])}')
+ print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/data/imagenet.py b/data/imagenet.py
new file mode 100644
index 0000000..a7d0489
--- /dev/null
+++ b/data/imagenet.py
@@ -0,0 +1,201 @@
+import torchvision
+import numpy as np
+
+import os
+
+from copy import deepcopy
+from data.data_utils import subsample_instances
+from config import imagenet_root
+
+
+class ImageNetBase(torchvision.datasets.ImageFolder):
+
+ def __init__(self, root, transform):
+
+ super(ImageNetBase, self).__init__(root, transform)
+
+ self.uq_idxs = np.array(range(len(self)))
+
+ def __getitem__(self, item):
+
+ img, label = super().__getitem__(item)
+ uq_idx = self.uq_idxs[item]
+
+ return img, label, uq_idx
+
+
+def subsample_dataset(dataset, idxs):
+
+    dataset.imgs = [dataset.imgs[i] for i in idxs]
+    dataset.samples = [dataset.samples[i] for i in idxs]
+
+ dataset.targets = np.array(dataset.targets)[idxs].tolist()
+ dataset.uq_idxs = dataset.uq_idxs[idxs]
+
+ return dataset
+
+
+def subsample_classes(dataset, include_classes=list(range(1000))):
+
+ cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
+
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+ dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+
+def get_train_val_indices(train_dataset, val_split=0.2):
+
+ train_classes = list(set(train_dataset.targets))
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+
+ cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
+
+ v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80),
+ prop_train_labels=0.8, split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Subsample imagenet dataset initially to include 100 classes
+ subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
+ subsampled_100_classes = np.sort(subsampled_100_classes)
+ print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
+ cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
+
+ # Init entire training set
+ imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
+ whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
+
+ # Reset dataset
+ whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
+ whole_training_set.targets = [s[1] for s in whole_training_set.samples]
+ whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
+ whole_training_set.target_transform = None
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
+ test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
+
+ # Reset test set
+ test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
+ test_dataset.targets = [s[1] for s in test_dataset.samples]
+ test_dataset.uq_idxs = np.array(range(len(test_dataset)))
+ test_dataset.target_transform = None
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+
+def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500),
+ prop_train_labels=0.5, split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+
+
+if __name__ == '__main__':
+
+ x = get_imagenet_100_datasets(None, None, split_train_val=False,
+ train_classes=range(50), prop_train_labels=0.5)
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+
+ print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
+ print(f'Len labelled set: {len(x["train_labelled"])}')
+ print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/data/ssb_splits/aircraft_osr_splits.pkl b/data/ssb_splits/aircraft_osr_splits.pkl
new file mode 100644
index 0000000..2a70d70
Binary files /dev/null and b/data/ssb_splits/aircraft_osr_splits.pkl differ
diff --git a/data/ssb_splits/cub_osr_splits.pkl b/data/ssb_splits/cub_osr_splits.pkl
new file mode 100644
index 0000000..8631178
Binary files /dev/null and b/data/ssb_splits/cub_osr_splits.pkl differ
diff --git a/data/ssb_splits/herbarium_19_class_splits.pkl b/data/ssb_splits/herbarium_19_class_splits.pkl
new file mode 100644
index 0000000..e2980cf
Binary files /dev/null and b/data/ssb_splits/herbarium_19_class_splits.pkl differ
diff --git a/data/ssb_splits/scars_osr_splits.pkl b/data/ssb_splits/scars_osr_splits.pkl
new file mode 100644
index 0000000..c0d279d
Binary files /dev/null and b/data/ssb_splits/scars_osr_splits.pkl differ
diff --git a/data/stanford_cars.py b/data/stanford_cars.py
new file mode 100644
index 0000000..4979ae4
--- /dev/null
+++ b/data/stanford_cars.py
@@ -0,0 +1,166 @@
+import os
+import pandas as pd
+import numpy as np
+from copy import deepcopy
+from scipy import io as mat_io
+
+from torchvision.datasets.folder import default_loader
+from torch.utils.data import Dataset
+
+from data.data_utils import subsample_instances
+from config import car_root
+
+class CarsDataset(Dataset):
+ """
+ Cars Dataset
+ """
+ def __init__(self, train=True, limit=0, data_dir=car_root, transform=None):
+
+ metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat')
+ data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/')
+
+ self.loader = default_loader
+ self.data_dir = data_dir
+ self.data = []
+ self.target = []
+ self.train = train
+
+ self.transform = transform
+
+ if not isinstance(metas, str):
+            raise Exception("Train metas must be a string location!")
+ labels_meta = mat_io.loadmat(metas)
+
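+        # Each annotation row follows the devkit .mat layout
+        # (bbox_x1, bbox_y1, bbox_x2, bbox_y2, class, fname):
+        # img_[4] is the class label and img_[5] is the image file name.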
+ for idx, img_ in enumerate(labels_meta['annotations'][0]):
+ if limit:
+ if idx > limit:
+ break
+
+ # self.data.append(img_resized)
+ self.data.append(data_dir + img_[5][0])
+ # if self.mode == 'train':
+ self.target.append(img_[4][0][0])
+
+ self.uq_idxs = np.array(range(len(self)))
+ self.target_transform = None
+
+ def __getitem__(self, idx):
+
+ image = self.loader(self.data[idx])
+ target = self.target[idx] - 1
+
+ if self.transform is not None:
+ image = self.transform(image)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ idx = self.uq_idxs[idx]
+
+ return image, target, idx
+
+ def __len__(self):
+ return len(self.data)
+
+
+def subsample_dataset(dataset, idxs):
+
+ dataset.data = np.array(dataset.data)[idxs].tolist()
+ dataset.target = np.array(dataset.target)[idxs].tolist()
+ dataset.uq_idxs = dataset.uq_idxs[idxs]
+
+ return dataset
+
+
+def subsample_classes(dataset, include_classes=range(160)):
+
+ include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195
+ cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars]
+
+ target_xform_dict = {}
+ for i, k in enumerate(include_classes):
+ target_xform_dict[k] = i
+
+ dataset = subsample_dataset(dataset, cls_idxs)
+
+ # dataset.target_transform = lambda x: target_xform_dict[x]
+
+ return dataset
+
+def get_train_val_indices(train_dataset, val_split=0.2):
+
+ train_classes = np.unique(train_dataset.target)
+
+ # Get train/test indices
+ train_idxs = []
+ val_idxs = []
+ for cls in train_classes:
+
+ cls_idxs = np.where(train_dataset.target == cls)[0]
+
+ v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
+ t_ = [x for x in cls_idxs if x not in v_]
+
+ train_idxs.extend(t_)
+ val_idxs.extend(v_)
+
+ return train_idxs, val_idxs
+
+
+def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
+ split_train_val=False, seed=0):
+
+ np.random.seed(seed)
+
+ # Init entire training set
+ whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True)
+
+ # Get labelled training set which has subsampled classes, then subsample some indices from that
+ train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
+ subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
+ train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
+
+ # Split into training and validation sets
+ train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
+ train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
+ val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
+ val_dataset_labelled_split.transform = test_transform
+
+ # Get unlabelled data
+ unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
+ train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
+
+ # Get test set for all classes
+ test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False)
+
+ # Either split train into train and val or use test set as val
+ train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
+ val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
+
+ all_datasets = {
+ 'train_labelled': train_dataset_labelled,
+ 'train_unlabelled': train_dataset_unlabelled,
+ 'val': val_dataset_labelled,
+ 'test': test_dataset,
+ }
+
+ return all_datasets
+
+if __name__ == '__main__':
+
+ x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
+
+ print('Printing lens...')
+ for k, v in x.items():
+ if v is not None:
+ print(f'{k}: {len(v)}')
+
+ print('Printing labelled and unlabelled overlap...')
+ print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+ print('Printing total instances in train...')
+ print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+
+ print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
+ print(f'Len labelled set: {len(x["train_labelled"])}')
+ print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/kmeans_loss.py b/kmeans_loss.py
new file mode 100644
index 0000000..e0aeef2
--- /dev/null
+++ b/kmeans_loss.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch_clustering.kmeans.kmeans import PyTorchKMeans
+
+
+class Kmeans_Loss(nn.Module):
+ def __init__(self, temperature=0.5, n_clusters=196):
+ super(Kmeans_Loss, self).__init__()
+ self.temperature = temperature
+ self.num_cluster = n_clusters
+ self.clustering_model = PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4, n_clusters=self.num_cluster)
+ self.psedo_labels = None
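+        # Minimal usage sketch (hedged; `features`, `feats_q` and `feats_k`
+        # are hypothetical tensors, not names defined in this repo):
+        #   criterion = Kmeans_Loss(temperature=0.5, n_clusters=196)
+        #   labels, centers = criterion.clustering(features, n_clusters=196)
+        #   loss = criterion(feats_q, feats_k, labels)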
+
+ def clustering(self, features, n_clusters):
+
+ # kwargs = {
+ # 'metric': 'cosine' if self.l2_normalize else 'euclidean',
+ # 'distributed': True,
+ # 'random_state': 0,
+ # 'n_clusters': n_clusters,
+ # 'verbose': True
+ # }
+ clustering_model = PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4, n_clusters=self.num_cluster)
+
+        psedo_labels = clustering_model.fit_predict(features)  # features carry no labels; clustering assigns pseudo-labels on the fly
+ self.psedo_labels = psedo_labels
+ cluster_centers = clustering_model.cluster_centers_
+ return psedo_labels, cluster_centers
+
+ def compute_cluster_loss(self, q_centers, k_centers, temperature=0.5, psedo_labels=None):
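+        """Prototype-level contrastive loss (in the spirit of ProPos): each q
+        center is attracted to its matching k center (the diagonal) and
+        repelled from the other q centers; clusters that received no
+        pseudo-labels are masked out and excluded from the average."""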
+        # pairwise similarity matrix between the q cluster centers
+        d_q = q_centers.mm(q_centers.T) / temperature
+
+        # similarity between each q center and its corresponding k center
+        d_k = (q_centers * k_centers).sum(dim=1) / temperature
+
+        # put the q-k similarities (the positive pairs) on the diagonal
+        d_q = d_q.float()
+        d_q[torch.arange(self.num_cluster), torch.arange(self.num_cluster)] = d_k
+
+        # find clusters that received no samples under the current pseudo-labels
+        zero_classes = torch.arange(self.num_cluster, device=d_q.device)[
+            torch.sum(F.one_hot(torch.unique(psedo_labels), self.num_cluster), dim=0) == 0]
+
+        # mask out the columns of empty clusters by filling them with -10
+        mask = torch.zeros((self.num_cluster, self.num_cluster), dtype=torch.bool, device=d_q.device)
+        mask[:, zero_classes] = 1
+        d_q.masked_fill_(mask, -10)
+
+        # split into positives (diagonal) and negatives (off-diagonal)
+        pos = d_q.diag(0)
+        mask = torch.ones((self.num_cluster, self.num_cluster), dtype=torch.bool, device=d_q.device)
+        mask.fill_diagonal_(0)
+        neg = d_q[mask].reshape(-1, self.num_cluster - 1)
+
+        # contrastive loss over the cluster centers
+        loss = -pos + torch.logsumexp(torch.cat([pos.reshape(self.num_cluster, 1), neg], dim=1), dim=1)
+
+        # zero the loss for clusters that received no samples
+        loss[zero_classes] = 0.
+
+        # average over the non-empty clusters
+        loss = loss.sum() / (self.num_cluster - len(zero_classes))
+
+ return loss
+
+ def compute_centers(self, x, psedo_labels):
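+        """Aggregate features into per-cluster centers: build a (num_cluster, N)
+        assignment matrix from the pseudo-labels, L1-normalise it per cluster,
+        multiply with x, then L2-normalise the resulting centers."""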
+ n_samples = x.size(0)
+ if len(psedo_labels.size()) > 1:
+ weight = psedo_labels.T
+ else:
+ weight = torch.zeros(self.num_cluster, n_samples).to(x) # L, N
+ weight[psedo_labels, torch.arange(n_samples)] = 1
+ weight = F.normalize(weight, p=1, dim=1) # l1 normalization
+ centers = torch.mm(weight, x)
+ centers = F.normalize(centers, dim=1)
+ return centers
+
+ def forward(self, im_q, im_k, psedo_labels):
+ batch_all_psedo_labels = psedo_labels
+ q_centers = self.compute_centers(im_q, batch_all_psedo_labels)
+ k_centers = self.compute_centers(im_k, batch_all_psedo_labels)
+
+ cluster_loss = self.compute_cluster_loss(q_centers, k_centers, self.temperature, batch_all_psedo_labels)
+
+ return cluster_loss
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..b7fb885
--- /dev/null
+++ b/model.py
@@ -0,0 +1,231 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class DINOHead(nn.Module):
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
+ nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ if nlayers == 1:
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
+ elif nlayers != 0:
+ layers = [nn.Linear(in_dim, hidden_dim)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+ self.mlp = nn.Sequential(*layers)
+ self.apply(self._init_weights)
+ self.last_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+ if norm_last_layer:
+ self.last_layer.weight_g.requires_grad = False
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
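+        """Return (projection, logits): the MLP projection of x, and logits
+        from the weight-normalised last layer applied to the L2-normalised
+        input features."""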
+ x_proj = self.mlp(x)
+ x = nn.functional.normalize(x, dim=-1, p=2)
+ # x = x.detach()
+ logits = self.last_layer(x)
+ return x_proj, logits
+
+
+class ContrastiveLearningViewGenerator(object):
+ """Take two random crops of one image as the query and key."""
+
+ def __init__(self, base_transform, n_views=2):
+ self.base_transform = base_transform
+ self.n_views = n_views
+
+ def __call__(self, x):
+ if not isinstance(self.base_transform, list):
+ return [self.base_transform(x) for i in range(self.n_views)]
+ else:
+ return [self.base_transform[i](x) for i in range(self.n_views)]
+
+class SupConLoss(torch.nn.Module):
+ """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+ It also supports the unsupervised contrastive loss in SimCLR
+ From: https://github.com/HobbitLong/SupContrast"""
+ def __init__(self, temperature=0.07, contrast_mode='all',
+ base_temperature=0.07):
+ super(SupConLoss, self).__init__()
+ self.temperature = temperature
+ self.contrast_mode = contrast_mode
+ self.base_temperature = base_temperature
+
+ def forward(self, features, labels=None, mask=None):
+ """Compute loss for model. If both `labels` and `mask` are None,
+ it degenerates to SimCLR unsupervised loss:
+ https://arxiv.org/pdf/2002.05709.pdf
+ Args:
+ features: hidden vector of shape [bsz, n_views, ...].
+ labels: ground truth of shape [bsz].
+ mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+ has the same class as sample i. Can be asymmetric.
+ Returns:
+ A loss scalar.
+ """
+
+ device = (torch.device('cuda')
+ if features.is_cuda
+ else torch.device('cpu'))
+
+ if len(features.shape) < 3:
+ raise ValueError('`features` needs to be [bsz, n_views, ...],'
+ 'at least 3 dimensions are required')
+ if len(features.shape) > 3:
+ features = features.view(features.shape[0], features.shape[1], -1)
+
+ batch_size = features.shape[0]
+ if labels is not None and mask is not None:
+ raise ValueError('Cannot define both `labels` and `mask`')
+ elif labels is None and mask is None:
+ mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+ elif labels is not None:
+ labels = labels.contiguous().view(-1, 1)
+ if labels.shape[0] != batch_size:
+ raise ValueError('Num of labels does not match num of features')
+ mask = torch.eq(labels, labels.T).float().to(device)
+ else:
+ mask = mask.float().to(device)
+
+ contrast_count = features.shape[1]
+ contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+ if self.contrast_mode == 'one':
+ anchor_feature = features[:, 0]
+ anchor_count = 1
+ elif self.contrast_mode == 'all':
+ anchor_feature = contrast_feature
+ anchor_count = contrast_count
+ else:
+ raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
+
+ # compute logits
+ anchor_dot_contrast = torch.div(
+ torch.matmul(anchor_feature, contrast_feature.T),
+ self.temperature)
+
+ # for numerical stability
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+ logits = anchor_dot_contrast - logits_max.detach()
+
+ # tile mask
+ mask = mask.repeat(anchor_count, contrast_count)
+ # mask-out self-contrast cases
+ logits_mask = torch.scatter(
+ torch.ones_like(mask),
+ 1,
+ torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+ 0
+ )
+ mask = mask * logits_mask
+
+ # compute log_prob
+ exp_logits = torch.exp(logits) * logits_mask
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+ # compute mean of log-likelihood over positive
+ mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+
+ # loss
+ loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
+ loss = loss.view(anchor_count, batch_size).mean()
+
+ return loss
+
+
+
+def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
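+    """Build InfoNCE logits/labels for a batch of n_views stacked augmented
+    views: each anchor's positive is its other view(s), and all remaining
+    samples in the batch act as negatives."""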
+
+    b_ = features.size(0) // n_views  # number of samples per view
+
+ labels = torch.cat([torch.arange(b_) for i in range(n_views)], dim=0)
+ labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
+ labels = labels.to(device)
+
+ features = F.normalize(features, dim=1)
+
+ similarity_matrix = torch.matmul(features, features.T)
+
+ # discard the main diagonal from both: labels and similarities matrix
+ mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
+ labels = labels[~mask].view(labels.shape[0], -1)
+ similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
+
+ # select and combine multiple positives
+ positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
+
+    # select only the negatives
+ negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
+
+ logits = torch.cat([positives, negatives], dim=1)
+ labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
+
+ logits = logits / temperature
+ return logits, labels
+
+
+def get_params_groups(model):
+ regularized = []
+ not_regularized = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+ # we do not regularize biases nor Norm parameters
+ if name.endswith(".bias") or len(param.shape) == 1:
+ not_regularized.append(param)
+ else:
+ regularized.append(param)
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
+
+
+class DistillLoss(nn.Module):
+ def __init__(self, warmup_teacher_temp_epochs, nepochs,
+ ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04,
+ student_temp=0.1):
+ super().__init__()
+ self.student_temp = student_temp
+ self.ncrops = ncrops
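+        # teacher temperature schedule: linear warmup from warmup_teacher_temp
+        # to teacher_temp over the first warmup_teacher_temp_epochs, then constant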
+ self.teacher_temp_schedule = np.concatenate((
+ np.linspace(warmup_teacher_temp,
+ teacher_temp, warmup_teacher_temp_epochs),
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
+ ))
+
+ def forward(self, student_output, teacher_output, epoch):
+ """
+ Cross-entropy between softmax outputs of the teacher and student networks.
+ """
+ student_out = student_output / self.student_temp
+ student_out = student_out.chunk(self.ncrops)
+
+ # teacher centering and sharpening
+ temp = self.teacher_temp_schedule[epoch]
+ teacher_out = F.softmax(teacher_output / temp, dim=-1)
+ teacher_out = teacher_out.detach().chunk(2)
+
+ total_loss = 0
+ n_loss_terms = 0
+ for iq, q in enumerate(teacher_out):
+ for v in range(len(student_out)):
+ if v == iq:
+ # we skip cases where student and teacher operate on the same view
+ continue
+ loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
+ total_loss += loss.mean()
+ n_loss_terms += 1
+ total_loss /= n_loss_terms
+ return total_loss
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0b5a84e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+loguru
+numpy
+pandas
+scikit_learn
+scipy
+torch
+torchvision
+matplotlib
+munkres
+tqdm
\ No newline at end of file
diff --git a/scripts/run_aircraft.sh b/scripts/run_aircraft.sh
new file mode 100644
index 0000000..fc2a449
--- /dev/null
+++ b/scripts/run_aircraft.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0 python train.py \
+ --dataset_name 'aircraft' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name aircraft_simgcd
diff --git a/scripts/run_cars.sh b/scripts/run_cars.sh
new file mode 100644
index 0000000..db89c2e
--- /dev/null
+++ b/scripts/run_cars.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0 python train.py \
+ --dataset_name 'scars' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name scars_simgcd
diff --git a/scripts/run_cifar10.sh b/scripts/run_cifar10.sh
new file mode 100644
index 0000000..55d0189
--- /dev/null
+++ b/scripts/run_cifar10.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0 python train.py \
+ --dataset_name 'cifar10' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name cifar10_simgcd
diff --git a/scripts/run_cifar100.sh b/scripts/run_cifar100.sh
new file mode 100644
index 0000000..b7e7315
--- /dev/null
+++ b/scripts/run_cifar100.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=1 python train.py \
+ --dataset_name 'cifar100' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 4 \
+ --nn_per_image 8 \
+ --stml_weight 0.1 \
+ --exp_name cifar100_simgcd
diff --git a/scripts/run_cub.sh b/scripts/run_cub.sh
new file mode 100644
index 0000000..b29198d
--- /dev/null
+++ b/scripts/run_cub.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=1 python train.py \
+ --dataset_name 'cub' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 2 \
+ --nn_per_image 8 \
+ --stml_weight 0.1 \
+ --exp_name cub_simgcd
diff --git a/scripts/run_herb19.sh b/scripts/run_herb19.sh
new file mode 100644
index 0000000..2d11b64
--- /dev/null
+++ b/scripts/run_herb19.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0 python train.py \
+ --dataset_name 'herbarium_19' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' 'v2b' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name herb19_simgcd
diff --git a/scripts/run_imagenet100.sh b/scripts/run_imagenet100.sh
new file mode 100644
index 0000000..bedcb8a
--- /dev/null
+++ b/scripts/run_imagenet100.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0 python train.py \
+ --dataset_name 'imagenet_100' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name imagenet100_simgcd
diff --git a/scripts/run_imagenet1k.sh b/scripts/run_imagenet1k.sh
new file mode 100644
index 0000000..bcb5da7
--- /dev/null
+++ b/scripts/run_imagenet1k.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+set -e
+set -x
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 train_mp.py \
+ --dataset_name 'imagenet_1k' \
+ --batch_size 128 \
+ --grad_from_block 11 \
+ --epochs 200 \
+ --num_workers 8 \
+ --use_ssb_splits \
+ --sup_weight 0.35 \
+ --weight_decay 5e-5 \
+ --transform 'imagenet' \
+ --lr 0.1 \
+ --eval_funcs 'v2' \
+ --warmup_teacher_temp 0.07 \
+ --teacher_temp 0.04 \
+ --warmup_teacher_temp_epochs 30 \
+ --memax_weight 1 \
+ --exp_name imagenet1k_simgcd \
+ --print_freq 100
diff --git a/torch_clustering/__base__.py b/torch_clustering/__base__.py
new file mode 100644
index 0000000..e191308
--- /dev/null
+++ b/torch_clustering/__base__.py
@@ -0,0 +1,89 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : __base__.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:20 PM
+'''
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+
+class BasicClustering:
+ def __init__(self,
+ n_clusters,
+ init='k-means++',
+ n_init=10,
+ random_state=0,
+ max_iter=300,
+ tol=1e-4,
+ distributed=False,
+ verbose=True):
+ '''
+ :param n_clusters:
+ :param init: {'k-means++', 'random'}, callable or array-like of shape \
+ (n_clusters, n_features), default='k-means++'
+ Method for initialization:
+            'k-means++' : selects initial cluster centers for k-means
+ clustering in a smart way to speed up convergence. See section
+ Notes in k_init for more details.
+ 'random': choose `n_clusters` observations (rows) at random from data
+ for the initial centroids.
+ If an array is passed, it should be of shape (n_clusters, n_features)
+ and gives the initial centers.
+ If a callable is passed, it should take arguments X, n_clusters and a
+ random state and return an initialization.
+ :param n_init: int, default=10
+            Number of times the k-means algorithm will be run with different
+ centroid seeds. The final results will be the best output of
+ n_init consecutive runs in terms of inertia.
+ :param random_state: int, RandomState instance or None, default=None
+ Determines random number generation for centroid initialization. Use
+ an int to make the randomness deterministic.
+            See the scikit-learn glossary for random_state.
+ :param max_iter:
+ :param tol:
+ :param verbose: int, default=0 Verbosity mode.
+ '''
+ self.n_clusters = n_clusters
+ self.n_init = n_init
+ self.max_iter = max_iter
+ self.tol = tol
+ self.cluster_centers_ = None
+ self.init = init
+ self.random_state = random_state
+ self.is_root_worker = True if not dist.is_initialized() else (dist.get_rank() == 0)
+ self.verbose = verbose and self.is_root_worker
+ self.distributed = distributed and dist.is_initialized()
+ if verbose and self.distributed and self.is_root_worker:
+ print('Perform K-means in distributed mode.')
+ self.world_size = dist.get_world_size() if self.distributed else 1
+ self.rank = dist.get_rank() if self.distributed else 0
+
+ def fit_predict(self, X):
+ pass
+
+ def distributed_sync(self, tensor):
+ tensors_gather = [torch.ones_like(tensor)
+ for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+ output = torch.stack(tensors_gather)
+ return output
+
+
+def pairwise_cosine(x1: torch.Tensor, x2: torch.Tensor, pairwise=True):
+ x1 = F.normalize(x1)
+ x2 = F.normalize(x2)
+ if not pairwise:
+ return (1 - (x1 * x2).sum(dim=1))
+ return 1 - x1.mm(x2.T)
+
+
+def pairwise_euclidean(x1: torch.Tensor, x2: torch.Tensor, pairwise=True):
+ if not pairwise:
+ return ((x1 - x2) ** 2).sum(dim=1).sqrt()
+ return torch.cdist(x1, x2, p=2.)
diff --git a/torch_clustering/__init__.py b/torch_clustering/__init__.py
new file mode 100644
index 0000000..e9a569e
--- /dev/null
+++ b/torch_clustering/__init__.py
@@ -0,0 +1,85 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : __init__.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:21 PM
+'''
+
+from .kmeans.kmeans import PyTorchKMeans
+from .faiss_kmeans import FaissKMeans
+from .gaussian_mixture import PyTorchGaussianMixture
+from .beta_mixture import BetaMixture1D
+
+import numpy as np
+from munkres import Munkres
+from sklearn import metrics
+import warnings
+
+def evaluate_clustering(label, pred, eval_metric=['nmi', 'acc', 'ari'], phase='train'):
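+    # Clustering metrics over samples with a valid label (label != -1); 'acc'
+    # uses an optimal one-to-one cluster-to-label matching (Munkres/Hungarian).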
+ mask = (label != -1)
+ label = label[mask]
+ pred = pred[mask]
+ results = {}
+ if 'nmi' in eval_metric:
+ nmi = metrics.normalized_mutual_info_score(label, pred, average_method='arithmetic')
+ results[f'{phase}_nmi'] = nmi
+ if 'ari' in eval_metric:
+ ari = metrics.adjusted_rand_score(label, pred)
+ results[f'{phase}_ari'] = ari
+ if 'f' in eval_metric:
+ f = metrics.fowlkes_mallows_score(label, pred)
+ results[f'{phase}_f'] = f
+ if 'acc' in eval_metric:
+ n_clusters = len(set(label))
+ if n_clusters == len(set(pred)):
+ pred_adjusted = get_y_preds(label, pred, n_clusters=n_clusters)
+ acc = metrics.accuracy_score(pred_adjusted, label)
+ else:
+ acc = 0.
+            warnings.warn('Number of predicted clusters does not match the number of classes; accuracy set to 0.')
+ results[f'{phase}_acc'] = acc
+ return results
+
+
+def calculate_cost_matrix(C, n_clusters):
+ cost_matrix = np.zeros((n_clusters, n_clusters))
+ # cost_matrix[i,j] will be the cost of assigning cluster i to label j
+ for j in range(n_clusters):
+        s = np.sum(C[:, j])  # number of examples in cluster j
+ for i in range(n_clusters):
+ t = C[i, j]
+ cost_matrix[j, i] = s - t
+ return cost_matrix
+
+
+def get_cluster_labels_from_indices(indices):
+ n_clusters = len(indices)
+ cluster_labels = np.zeros(n_clusters)
+ for i in range(n_clusters):
+ cluster_labels[i] = indices[i][1]
+ return cluster_labels
+
+
+def get_y_preds(y_true, cluster_assignments, n_clusters):
+ """
+ Computes the predicted labels, where label assignments now
+ correspond to the actual labels in y_true (as estimated by Munkres)
+    cluster_assignments: array of labels output by k-means
+    y_true: true labels
+    n_clusters: number of clusters in the dataset
+    returns: the cluster assignments re-mapped to the label space of y_true
+ """
+ confusion_matrix = metrics.confusion_matrix(y_true, cluster_assignments, labels=None)
+ # compute accuracy based on optimal 1:1 assignment of clusters to labels
+ cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters)
+ indices = Munkres().compute(cost_matrix)
+ kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices)
+
+ if np.min(cluster_assignments) != 0:
+ cluster_assignments = cluster_assignments - np.min(cluster_assignments)
+ y_pred = kmeans_to_true_cluster_labels[cluster_assignments]
+ return y_pred
diff --git a/torch_clustering/beta_mixture.py b/torch_clustering/beta_mixture.py
new file mode 100644
index 0000000..87fc396
--- /dev/null
+++ b/torch_clustering/beta_mixture.py
@@ -0,0 +1,84 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : beta_mixture.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:21 PM
+'''
+
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+
+
+class BetaMixture1D(object):
+ def __init__(self,
+ max_iters=10,
+ alphas_init=[1, 2],
+ betas_init=[2, 1],
+ weights_init=[0.5, 0.5]):
+ self.alphas = np.array(alphas_init, dtype=np.float64)
+ self.betas = np.array(betas_init, dtype=np.float64)
+ self.weight = np.array(weights_init, dtype=np.float64)
+ self.max_iters = max_iters
+ self.eps_nan = 1e-12
+
+ def fit_beta_weighted(self, x, w):
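+        # method-of-moments estimates of Beta(alpha, beta) from the weighted
+        # sample mean and variance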
+ def weighted_mean(x, w):
+ return np.sum(w * x) / np.sum(w)
+
+ x_bar = weighted_mean(x, w)
+ s2 = weighted_mean((x - x_bar) ** 2, w)
+ alpha = x_bar * ((x_bar * (1 - x_bar)) / s2 - 1)
+ beta = alpha * (1 - x_bar) / x_bar
+ return alpha, beta
+
+ def likelihood(self, x, y):
+ import scipy.stats as stats
+ return stats.beta.pdf(x, self.alphas[y], self.betas[y])
+
+ def weighted_likelihood(self, x, y):
+ return self.weight[y] * self.likelihood(x, y)
+
+ def probability(self, x):
+ return sum(self.weighted_likelihood(x, y) for y in range(2))
+
+ def responsibilities(self, x):
+ r = np.array([self.weighted_likelihood(x, i) for i in range(2)])
+ # there are ~200 samples below that value
+ r[r <= self.eps_nan] = self.eps_nan
+ r /= r.sum(axis=0)
+ return r.T
+
+ def fit(self, x):
+ x = np.copy(x)
+
+        # EM on beta distributions is unstable with x == 0 or 1
+ eps = 1e-4
+ x[x >= 1 - eps] = 1 - eps
+ x[x <= eps] = eps
+
+ for i in range(self.max_iters):
+ # E-step
+ r = self.responsibilities(x).T
+
+ # M-step
+ self.alphas[0], self.betas[0] = self.fit_beta_weighted(x, r[0])
+ self.alphas[1], self.betas[1] = self.fit_beta_weighted(x, r[1])
+ self.weight = r.sum(axis=1)
+ self.weight /= self.weight.sum()
+
+ return self
+
+ def plot(self):
+ x = np.linspace(0, 1, 100)
+ plt.plot(x, self.weighted_likelihood(x, 0), label='negative')
+ plt.plot(x, self.weighted_likelihood(x, 1), label='positive')
+ # plt.plot(x, self.probability(x), lw=2, label='mixture')
+ plt.legend()
+
+ def __repr__(self):
+ return 'BetaMixture1D(w={}, a={}, b={})'.format(self.weight, self.alphas, self.betas)
diff --git a/torch_clustering/faiss_kmeans.py b/torch_clustering/faiss_kmeans.py
new file mode 100644
index 0000000..85e6181
--- /dev/null
+++ b/torch_clustering/faiss_kmeans.py
@@ -0,0 +1,147 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : faiss_kmeans.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:22 PM
+'''
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+try:
+ import faiss
+except ImportError:
+    print('faiss not installed')
+from .__base__ import BasicClustering
+
+
+class FaissKMeans(BasicClustering):
+ def __init__(self,
+ metric='euclidean',
+ n_clusters=8,
+ n_init=10,
+ max_iter=300,
+ random_state=1234,
+ distributed=False,
+ verbose=True):
+ super().__init__(n_clusters=n_clusters,
+ n_init=n_init,
+ max_iter=max_iter,
+ distributed=distributed,
+ verbose=verbose)
+
+ if metric == 'euclidean':
+ self.spherical = False
+ elif metric == 'cosine':
+ self.spherical = True
+ else:
+ raise NotImplementedError
+ self.random_state = random_state
+
+    def apply_pca(self, X, dim):
+        n, d = X.shape
+        if self.spherical:
+            X = F.normalize(X, dim=1)
+        mat = faiss.PCAMatrix(d, dim)
+        mat.train(n, X)
+        return mat.apply_py(X)
+
+ def fit_predict(self, input: torch.Tensor, device=-1):
+ n, d = input.shape
+
+ assert isinstance(input, (torch.Tensor, np.ndarray))
+ is_torch_tensor = isinstance(input, torch.Tensor)
+ if is_torch_tensor:
+ if self.spherical:
+ input = F.normalize(input, dim=1)
+
+ if input.is_cuda:
+ device = input.device.index
+ X = input.cpu().numpy().astype(np.float32)
+ else:
+ if self.spherical:
+ X = input / np.linalg.norm(input, 2, axis=1)[:, np.newaxis]
+ else:
+ X = input
+ X = X.astype(np.float32)
+
+ random_states = torch.arange(self.world_size) + self.random_state
+ random_state = random_states[self.rank]
+ if device >= 0:
+ # faiss implementation of k-means
+ clus = faiss.Clustering(int(d), int(self.n_clusters))
+
+ # Change faiss seed at each k-means so that the randomly picked
+ # initialization centroids do not correspond to the same feature ids
+ # from an epoch to another.
+ # clus.seed = np.random.randint(1234)
+ clus.seed = int(random_state)
+
+ clus.niter = self.max_iter
+ clus.max_points_per_centroid = 10000000
+ clus.min_points_per_centroid = 10
+ clus.spherical = self.spherical
+ clus.nredo = self.n_init
+ clus.verbose = self.verbose
+ res = faiss.StandardGpuResources()
+ flat_config = faiss.GpuIndexFlatConfig()
+ flat_config.useFloat16 = False
+ flat_config.device = device
+ flat_config.verbose = self.verbose
+ flat_config.spherical = self.spherical
+ flat_config.nredo = self.n_init
+ index = faiss.GpuIndexFlatL2(res, d, flat_config)
+
+ # perform the training
+ clus.train(X, index)
+ D, I = index.search(X, 1)
+ else:
+ clus = faiss.Kmeans(d=d,
+ k=self.n_clusters,
+ niter=self.max_iter,
+ nredo=self.n_init,
+ verbose=self.verbose,
+ spherical=self.spherical)
+ X = X.astype(np.float32)
+ clus.train(X)
+ # self.cluster_centers_ = self.kmeans.centroids
+        D, I = clus.index.search(X, 1)  # for each sample, find cluster distance and assignment
+
+ tensor_device = 'cpu' if device < 0 else f'cuda:{device}'
+
+ best_labels = torch.from_numpy(I.flatten()).to(tensor_device)
+ min_inertia = torch.from_numpy(D.flatten()).to(tensor_device).sum()
+ best_states = faiss.vector_to_array(clus.centroids).reshape(self.n_clusters, d)
+ best_states = torch.from_numpy(best_states).to(tensor_device)
+
+ if self.distributed:
+ min_inertia = self.distributed_sync(min_inertia)
+ best_idx = torch.argmin(min_inertia).item()
+ min_inertia = min_inertia[best_idx]
+ dist.broadcast(best_labels, src=best_idx)
+ dist.broadcast(best_states, src=best_idx)
+
+ if self.verbose:
+ print(f"Final min inertia {min_inertia.item()}.")
+
+ self.cluster_centers_ = best_states
+ return best_labels
+
+
+if __name__ == '__main__':
+ dist.init_process_group(backend='nccl', init_method='env://')
+ torch.cuda.set_device(dist.get_rank())
+ X = torch.randn(1280, 256).cuda()
+ clustering_model = FaissKMeans(metric='euclidean',
+ n_clusters=10,
+ n_init=2,
+ max_iter=1,
+ random_state=1234,
+ distributed=True,
+ verbose=True)
+ clustering_model.fit_predict(X)
diff --git a/torch_clustering/gaussian_mixture.py b/torch_clustering/gaussian_mixture.py
new file mode 100644
index 0000000..4ebd1d5
--- /dev/null
+++ b/torch_clustering/gaussian_mixture.py
@@ -0,0 +1,222 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : gaussian_mixture.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:22 PM
+'''
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributions as D
+import torch.distributed as dist
+from .__base__ import BasicClustering
+from .kmeans.kmeans import PyTorchKMeans
+
+
+class PyTorchGaussianMixture(BasicClustering):
+ def __init__(self,
+ covariance_type='diag',
+ metric='euclidean',
+ reg_covar=1e-6,
+ init='k-means++',
+ random_state=0,
+ n_clusters=8,
+ n_init=10,
+ max_iter=300,
+ tol=1e-4,
+ distributed=False,
+ verbose=True):
+ '''
+ pytorch_gaussian_mixture = PyTorchGaussianMixture(metric='cosine',
+ covariance_type='diag',
+ reg_covar=1e-6,
+ init='k-means++',
+ random_state=0,
+ n_clusters=10,
+ n_init=10,
+ max_iter=300,
+ tol=1e-5,
+ verbose=True)
+ pseudo_labels = pytorch_gaussian_mixture.fit_predict(torch.from_numpy(features).cuda())
+ :param metric:
+ :param reg_covar:
+ :param init:
+ :param random_state:
+ :param n_clusters:
+ :param n_init:
+ :param max_iter:
+ :param tol:
+ :param verbose:
+ '''
+ super().__init__(n_clusters=n_clusters,
+ init=init,
+ distributed=distributed,
+ random_state=random_state,
+ n_init=n_init,
+ max_iter=max_iter,
+ tol=tol,
+ verbose=verbose)
+ self.reg_covar = reg_covar
+ self.metric = metric
+ self._estimate_gaussian_covariances = {'diag': self._estimate_gaussian_covariances_diag,
+ 'spherical': self._estimate_gaussian_covariances_spherical}[
+ covariance_type]
+ self.covariances, self.weights, self.lower_bound_ = None, None, None
+
+ def _estimate_gaussian_covariances_diag(self, resp: torch.Tensor, X: torch.Tensor, nk: torch.Tensor,
+ means: torch.Tensor):
+ avg_X2 = torch.matmul(resp.T, X * X) / nk[:, None]
+ avg_means2 = means ** 2
+ avg_X_means = means * torch.matmul(resp.T, X) / nk[:, None]
+ return avg_X2 - 2 * avg_X_means + avg_means2 + self.reg_covar
+ # N * K * L
+ # return (((X.unsqueeze(1) - means.unsqueeze(0)) ** 2) * resp.unsqueeze(-1)).sum(0) / nk[:, None] + self.reg_covar
+
+ def _estimate_gaussian_covariances_spherical(self, resp: torch.Tensor, X: torch.Tensor, nk: torch.Tensor,
+ means: torch.Tensor):
+ return self._estimate_gaussian_covariances_diag(resp, X, nk, means).mean(1, keepdim=True)
+
+ def initialize(self, X: torch.Tensor, resp: torch.Tensor):
+ """Initialization of the Gaussian mixture parameters.
+ Parameters
+ ----------
+ X : array-like of shape (n_samples, n_features)
+ resp : array-like of shape (n_samples, n_components)
+ """
+ n_samples, _ = X.shape
+ weights, means, covariances = self._estimate_gaussian_parameters(X, resp)
+ return means, covariances, weights
+
+ def _estimate_gaussian_parameters(self, X: torch.Tensor, resp: torch.Tensor):
+ # N * K * L
+ nk = resp.sum(dim=0) + 10 * torch.finfo(resp.dtype).eps
+ # means = torch.sum(X[:, None, :] * resp[:, :, None], dim=0) / nk[:, None]
+ means = resp.T.mm(X) / nk[:, None]
+ if self.metric == 'cosine':
+ means = F.normalize(means, dim=-1)
+ covariances = self._estimate_gaussian_covariances(resp, X, nk, means)
+ weights = nk / X.size(0)
+ return weights, means, covariances
+
+ def fit_predict(self, X: torch.Tensor):
+ if self.metric == 'cosine':
+ X = F.normalize(X, dim=1)
+ best_means, best_covariances, best_weights, best_resp = None, None, None, None
+ max_lower_bound = - float("Inf")
+
+ g = torch.Generator()
+ g.manual_seed(self.random_state)
+ random_states = torch.randperm(10000, generator=g)[:self.n_init * self.world_size]
+ random_states = random_states[self.rank:self.n_init * self.world_size:self.world_size]
+
+ for n_init in range(self.n_init):
+
+ random_state = int(random_states[n_init])
+ # KMeans init
+ pseudo_labels = PyTorchKMeans(metric=self.metric,
+ init=self.init,
+ n_clusters=self.n_clusters,
+ random_state=random_state,
+ n_init=self.n_init,
+ max_iter=self.max_iter,
+ tol=self.tol,
+ distributed=self.distributed,
+ verbose=self.verbose).fit_predict(X)
+ resp = F.one_hot(pseudo_labels, self.n_clusters).to(X)
+ means, covariances, weights = self.initialize(X, resp)
+ previous_lower_bound_ = self.log_likehood(resp.log())
+
+ for n_iter in range(self.max_iter):
+ # E step
+ log_resp = self._e_step(X, means, covariances, weights)
+
+ resp = F.softmax(log_resp, dim=1)
+
+ lower_bound_ = self.log_likehood(log_resp)
+
+ shift = torch.abs(previous_lower_bound_ - lower_bound_)
+
+ if shift < self.tol:
+ if self.verbose:
+ print('converge at Iteration {} with shift: {}'.format(n_iter, shift))
+ break
+
+ if self.verbose:
+                    print(f'Iteration {n_iter}, log-likelihood: {lower_bound_.item()}, shift: {shift.item()}')
+ previous_lower_bound_ = lower_bound_
+
+ if lower_bound_ > max_lower_bound:
+ max_lower_bound = lower_bound_
+ best_means, best_covariances, best_weights, best_resp = \
+ means, covariances, weights, resp
+
+ # M step
+ means, covariances, weights = self._m_step(X, resp)
+
+ if self.distributed:
+ max_lower_bound = self.distributed_sync(max_lower_bound)
+ best_idx = torch.argmax(max_lower_bound).item()
+ max_lower_bound = max_lower_bound[best_idx]
+ dist.broadcast(best_means, src=best_idx)
+ dist.broadcast(best_covariances, src=best_idx)
+ dist.broadcast(best_weights, src=best_idx)
+ dist.broadcast(best_resp, src=best_idx)
+            if self.verbose:
+                print(f"Final log-likelihood {max_lower_bound.item()}.")
+
+        if self.verbose:
+            print(f'Converged with log-likelihood {max_lower_bound.item()}')
+ self.cluster_centers_, self.covariances, self.weights, self.lower_bound_ = \
+ best_means, best_covariances, best_weights, max_lower_bound
+ return self.predict_score(X)
+
+ def _e_step(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor):
+ estimate_precision_error_message = (
+ "Fitting the mixture model failed because some components have "
+ "ill-defined empirical covariance (for instance caused by singleton "
+ "or collapsed samples). Try to decrease the number of components, "
+ "or increase reg_covar.")
+ if torch.any(torch.le(covariances, 0.0)):
+ raise ValueError(estimate_precision_error_message)
+ log_resp = self.log_prob(X, means, covariances, weights)
+ return log_resp
+
+ def log_prob(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor):
+ log_resp = D.Normal(loc=means.unsqueeze(0),
+ scale=covariances.unsqueeze(0).sqrt()).log_prob(X.unsqueeze(1)).sum(dim=-1)
+ log_resp = log_resp + weights.unsqueeze(0).log()
+ return log_resp
+
+ def log_prob_sklearn(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor):
+ n_samples, n_features = X.size()
+ n_components, _ = means.size()
+
+ precisions_chol = 1. / torch.sqrt(covariances)
+
+ log_det = torch.sum(precisions_chol.log(), dim=1)
+ precisions = precisions_chol ** 2
+ log_prob = (torch.sum((means ** 2 * precisions), dim=1) -
+ 2. * torch.matmul(X, (means * precisions).T) +
+ torch.matmul(X ** 2, precisions.T))
+ log_p = -.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det
+ weighted_log_p = log_p + weights.unsqueeze(0).log()
+
+ # seems not work
+ # weighted_log_p = weighted_log_p - weighted_log_p.logsumexp(dim=1, keepdim=True)
+ return weighted_log_p
+
+ def _m_step(self, X: torch.Tensor, resp: torch.Tensor):
+ n_samples, _ = X.shape
+ weights, means, covariances = self._estimate_gaussian_parameters(X, resp)
+ return means, covariances, weights
+
+ def log_likehood(self, log_resp: torch.Tensor):
+ # N * K
+ return log_resp.logsumexp(dim=1).mean()
+
+ def predict_score(self, X: torch.Tensor):
+ return F.softmax(self._e_step(X, self.cluster_centers_, self.covariances, self.weights), dim=1)
diff --git a/torch_clustering/kmeans/__init__.py b/torch_clustering/kmeans/__init__.py
new file mode 100644
index 0000000..53cfd8a
--- /dev/null
+++ b/torch_clustering/kmeans/__init__.py
@@ -0,0 +1 @@
+from .kmeans import PyTorchKMeans
diff --git a/torch_clustering/kmeans/kmeans.py b/torch_clustering/kmeans/kmeans.py
new file mode 100644
index 0000000..68d5a2f
--- /dev/null
+++ b/torch_clustering/kmeans/kmeans.py
@@ -0,0 +1,192 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : kmeans.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:23 PM
+'''
+
+import numpy as np
+import torch
+import tqdm
+import torch.distributed as dist
+from ..__base__ import BasicClustering, pairwise_euclidean, pairwise_cosine
+from .kmeans_plus_plus import _kmeans_plusplus
+
+
+class PyTorchKMeans(BasicClustering):
+ def __init__(self,
+ metric='euclidean',
+ init='k-means++',
+ random_state=0,
+ n_clusters=8,
+ n_init=10,
+ max_iter=300,
+ tol=1e-4,
+ distributed=False,
+ verbose=True):
+ super().__init__(n_clusters=n_clusters,
+ init=init,
+ random_state=random_state,
+ n_init=n_init,
+ max_iter=max_iter,
+ tol=tol,
+ verbose=verbose,
+ distributed=distributed)
+ self.distance_metric = {'euclidean': pairwise_euclidean, 'cosine': pairwise_cosine}[metric]
+ # self.distance_metric = lambda a, b: torch.cdist(a, b, p=2.)
+ if isinstance(self.init, (np.ndarray, torch.Tensor)): self.n_init = 1
+
+ def initialize(self, X: torch.Tensor, random_state: int):
+ num_samples = len(X)
+ if isinstance(self.init, str):
+ g = torch.Generator()
+ g.manual_seed(random_state)
+ if self.init == 'random':
+ indices = torch.randperm(num_samples, generator=g)[:self.n_clusters]
+ init_state = X[indices]
+ elif self.init == 'k-means++':
+ init_state, _ = _kmeans_plusplus(X,
+ random_state=random_state,
+ n_clusters=self.n_clusters,
+ pairwise_distance=self.distance_metric)
+ # init_state = X[torch.randperm(num_samples, generator=g)[0]].unsqueeze(0)
+ # for k in range(1, self.n_clusters):
+ # d = torch.min(self.distance_metric(X, init_state), dim=1)[0]
+ # init_state = torch.cat([init_state, X[torch.argmax(d)].unsqueeze(0)], dim=0)
+ else:
+ raise NotImplementedError
+ elif isinstance(self.init, (np.ndarray, torch.Tensor)):
+ init_state = self.init.to(X)
+ else:
+ raise NotImplementedError
+
+ return init_state
+
+ def fit_predict(self, X: torch.Tensor):
+
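+        # scale the tolerance by the mean per-feature variance of X, mirroring
+        # scikit-learn's KMeans convergence criterion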
+ tol = torch.mean(torch.var(X, dim=0)) * self.tol
+
+ min_inertia, best_states, best_labels = float('Inf'), None, None
+
+ random_states = torch.arange(self.n_init * self.world_size) + self.random_state
+ random_states = random_states[self.rank:len(random_states):self.world_size]
+ # g = torch.Generator()
+ # g.manual_seed(self.random_state)
+ # random_states = torch.randperm(10000, generator=g)[:self.n_init * self.world_size]
+ # random_states = random_states[self.rank:self.n_init * self.world_size:self.world_size]
+
+ self.stats = {'state': [], 'inertia': [], 'label': []}
+ for n_init in range(self.n_init):
+ random_state = int(random_states[n_init])
+ old_state = self.initialize(X, random_state=random_state)
+ old_labels, inertia = self.predict(X, old_state)
+
+ labels = old_labels
+
+ progress_bar = tqdm.tqdm(total=self.max_iter, disable=not self.verbose)
+
+ for n_iter in range(self.max_iter):
+
+ # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335/7
+ # n_samples = X.size(0)
+ # weight = torch.zeros(self.n_clusters, n_samples, dtype=X.dtype, device=X.device) # L, N
+ # weight[labels, torch.arange(n_samples)] = 1
+ # weight = F.normalize(weight, p=1, dim=1) # l1 normalization
+ # state = torch.mm(weight, X) # L, F
+ state = torch.zeros(self.n_clusters, X.size(1), dtype=X.dtype, device=X.device)
+ counts = torch.zeros(self.n_clusters, dtype=X.dtype, device=X.device) + 1e-6
+ classes, classes_counts = torch.unique(labels, return_counts=True)
+ counts[classes] = classes_counts.to(X)
+ state.index_add_(0, labels, X)
+ state = state / counts.view(-1, 1)
+
+ # d = self.distance_metric(X, state)
+ # inertia, labels = d.min(dim=1)
+ # inertia = inertia.sum()
+ labels, inertia = self.predict(X, state)
+
+ if inertia < min_inertia:
+ min_inertia = inertia
+ best_states, best_labels = state, labels
+
+ if self.verbose:
+ progress_bar.set_description(
+ f'nredo {n_init + 1}/{self.n_init:02d}, iteration {n_iter:03d} with inertia {inertia:.2f}')
+ progress_bar.update(n=1)
+
+ center_shift = self.distance_metric(old_state, state, pairwise=False)
+
+ if torch.equal(labels, old_labels):
+ # First check the labels for strict convergence.
+ if self.verbose:
+ print(f"Converged at iteration {n_iter}: strict convergence.")
+ break
+ else:
+ # center_shift = self.distance_metric(old_state, state).diag().sum()
+ # No strict convergence, check for tol based convergence.
+ # center_shift_tot = (center_shift ** 2).sum()
+ center_shift_tot = center_shift.sum()
+ if center_shift_tot <= tol:
+ if self.verbose:
+ print(
+ f"Converged at iteration {n_iter}: center shift "
+ f"{center_shift_tot} within tolerance {tol} "
+ f"and min inertia {min_inertia.item()}."
+ )
+ break
+
+ old_labels[:] = labels
+ old_state = state
+ progress_bar.close()
+ self.stats['state'].append(old_state)
+ self.stats['inertia'].append(inertia)
+ self.stats['label'].append(old_labels)
+
+ self.stats['state'] = torch.stack(self.stats['state'])
+ self.stats['inertia'] = torch.stack(self.stats['inertia'])
+ self.stats['label'] = torch.stack(self.stats['label'])
+ if self.distributed:
+ min_inertia = self.distributed_sync(min_inertia)
+ best_idx = torch.argmin(min_inertia).item()
+ min_inertia = min_inertia[best_idx]
+ dist.broadcast(best_labels, src=best_idx)
+ dist.broadcast(best_states, src=best_idx)
+ self.stats['state'] = self.distributed_sync(self.stats['state'])
+ self.stats['inertia'] = self.distributed_sync(self.stats['inertia'])
+ self.stats['label'] = self.distributed_sync(self.stats['label'])
+
+ if self.verbose:
+ print(f"Final min inertia {min_inertia.item()}.")
+
+ self.cluster_centers_ = best_states
+ return best_labels
+
+ def predict(self, X: torch.Tensor, cluster_centers_=None):
+ if cluster_centers_ is None:
+ cluster_centers_ = self.cluster_centers_
+ split_size = min(4096, X.size(0))
+ inertia, pred_labels = 0., []
+ for f in X.split(split_size, dim=0):
+ d = self.distance_metric(f, cluster_centers_)
+ inertia_, labels_ = d.min(dim=1)
+ inertia += inertia_.sum()
+ pred_labels.append(labels_)
+ return torch.cat(pred_labels, dim=0), inertia
+
+
+if __name__ == '__main__':
+ torch.cuda.set_device(1)
+ clustering_model = PyTorchKMeans(metric='cosine',
+ init='k-means++',
+ random_state=0,
+ n_clusters=1000,
+ n_init=10,
+ max_iter=300,
+ tol=1e-4,
+ distributed=False,
+ verbose=True)
+ X = torch.randn(1280000, 256).cuda()
+ clustering_model.fit_predict(X)
diff --git a/torch_clustering/kmeans/kmeans_plus_plus.py b/torch_clustering/kmeans/kmeans_plus_plus.py
new file mode 100644
index 0000000..fa081df
--- /dev/null
+++ b/torch_clustering/kmeans/kmeans_plus_plus.py
@@ -0,0 +1,132 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File : kmeans_plus_plus.py
+@Author : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email : zzhuang19@fudan.edu.cn
+@Date : 2022/10/19 12:23 PM
+'''
+
+import torch
+import numpy as np
+import warnings
+
+
+def stable_cumsum(arr, dim=None, rtol=1e-05, atol=1e-08):
+ """Use high precision for cumsum and check that final value matches sum.
+ Parameters
+ ----------
+ arr : array-like
+ To be cumulatively summed as flat.
+ axis : int, default=None
+ Axis along which the cumulative sum is computed.
+ The default (None) is to compute the cumsum over the flattened array.
+ rtol : float, default=1e-05
+ Relative tolerance, see ``np.allclose``.
+ atol : float, default=1e-08
+ Absolute tolerance, see ``np.allclose``.
+ """
+ if dim is None:
+ arr = arr.flatten()
+ dim = 0
+ out = torch.cumsum(arr, dim=dim, dtype=torch.float64)
+ expected = torch.sum(arr, dim=dim, dtype=torch.float64)
+ if not torch.all(torch.isclose(out.take(torch.Tensor([-1]).long().to(arr.device)),
+ expected, rtol=rtol,
+ atol=atol, equal_nan=True)):
+ warnings.warn('cumsum was found to be unstable: '
+ 'its last element does not correspond to sum',
+ RuntimeWarning)
+ return out
+
+
+def _kmeans_plusplus(X,
+ n_clusters,
+ random_state,
+ pairwise_distance,
+ n_local_trials=None):
+ """Computational component for initialization of n_clusters by
+ k-means++. Prior validation of data is assumed.
+ Parameters
+ ----------
+ X : {ndarray, sparse matrix} of shape (n_samples, n_features)
+ The data to pick seeds for.
+ n_clusters : int
+ The number of seeds to choose.
+ random_state : RandomState instance
+ The generator used to initialize the centers.
+        See the scikit-learn glossary for random_state.
+ n_local_trials : int, default=None
+ The number of seeding trials for each center (except the first),
+ of which the one reducing inertia the most is greedily chosen.
+ Set to None to make the number of trials depend logarithmically
+ on the number of seeds (2+log(k)); this is the default.
+ Returns
+ -------
+ centers : ndarray of shape (n_clusters, n_features)
+        The initial centers for k-means.
+ indices : ndarray of shape (n_clusters,)
+ The index location of the chosen centers in the data array X. For a
+ given index and center, X[index] = center.
+ """
+ n_samples, n_features = X.shape
+
+ generator = torch.Generator(device=str(X.device))
+ generator.manual_seed(random_state)
+
+ centers = torch.empty((n_clusters, n_features), dtype=X.dtype, device=X.device)
+
+ # Set the number of local seeding trials if none is given
+ if n_local_trials is None:
+ # This is what Arthur/Vassilvitskii tried, but did not report
+ # specific results for other than mentioning in the conclusion
+ # that it helped.
+ n_local_trials = 2 + int(np.log(n_clusters))
+
+ # Pick first center randomly and track index of point
+ # center_id = random_state.randint(n_samples)
+ center_id = torch.randint(n_samples, (1,), generator=generator, device=X.device)
+
+ indices = torch.full((n_clusters,), -1, dtype=torch.int, device=X.device)
+ centers[0] = X[center_id]
+ indices[0] = center_id
+
+ # Initialize list of closest distances and calculate current potential
+ closest_dist_sq = pairwise_distance(
+ centers[0, None], X)
+ current_pot = closest_dist_sq.sum()
+
+ # Pick the remaining n_clusters-1 points
+ for c in range(1, n_clusters):
+ # Choose center candidates by sampling with probability proportional
+ # to the squared distance to the closest existing center
+ # rand_vals = random_state.random_sample(n_local_trials) * current_pot
+ rand_vals = torch.rand(n_local_trials, generator=generator, device=X.device) * current_pot
+
+ candidate_ids = torch.searchsorted(stable_cumsum(closest_dist_sq),
+ rand_vals)
+ # XXX: numerical imprecision can result in a candidate_id out of range
+ torch.clip(candidate_ids, None, closest_dist_sq.numel() - 1,
+ out=candidate_ids)
+
+ # Compute distances to center candidates
+ distance_to_candidates = pairwise_distance(
+ X[candidate_ids], X)
+
+ # update closest distances squared and potential for each candidate
+ torch.minimum(closest_dist_sq, distance_to_candidates,
+ out=distance_to_candidates)
+ candidates_pot = distance_to_candidates.sum(dim=1)
+
+ # Decide which candidate is the best
+ best_candidate = torch.argmin(candidates_pot)
+ current_pot = candidates_pot[best_candidate]
+ closest_dist_sq = distance_to_candidates[best_candidate]
+ best_candidate = candidate_ids[best_candidate]
+
+ # Permanently add best center candidate found in local tries
+ centers[c] = X[best_candidate]
+ indices[c] = best_candidate
+
+ return centers, indices
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..7d7d382
--- /dev/null
+++ b/train.py
@@ -0,0 +1,365 @@
+import argparse
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import SGD, lr_scheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from data.augmentations import get_transform
+from data.get_datasets import get_datasets, get_class_splits
+
+from util.general_utils import AverageMeter, init_experiment
+from util.cluster_and_log_utils import log_accs_from_preds
+from config import exp_root
+from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups
+
+from torch.utils.data.sampler import Sampler
+from util.general_utils import NNBatchSampler, STML_loss_simgcd
+
+def train(student, simgcd_train_loader, train_loader, train_dataset, test_loader, unlabelled_train_loader, args):
+ params_groups = get_params_groups(student)
+ optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+ fp16_scaler = None
+ if args.fp16:
+ fp16_scaler = torch.cuda.amp.GradScaler()
+
+ exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ T_max=args.epochs,
+ eta_min=args.lr * 1e-3,
+ )
+
+
+ cluster_criterion = DistillLoss(
+ args.warmup_teacher_temp_epochs,
+ args.epochs,
+ args.n_views,
+ args.warmup_teacher_temp,
+ args.teacher_temp,
+ )
+
+ # Kmeans loss
+ from kmeans_loss import Kmeans_Loss
+ kmeans_loss_f = Kmeans_Loss(n_clusters=args.mlp_out_dim)
+
+ # # inductive
+ # best_test_acc_lab = 0
+ # # transductive
+ # best_train_acc_lab = 0
+ # best_train_acc_ubl = 0
+ # best_train_acc_all = 0
+
+ ori_train_loader = train_loader
+ for epoch in range(args.epochs):
+ loss_record = AverageMeter()
+
+ loader = simgcd_train_loader
+        if epoch > args.stml_warmup_ep:
+            assert isinstance(student, nn.Sequential)
+            balanced_sampler = NNBatchSampler(train_dataset, student[0], ori_train_loader, args.batch_size, nn_per_image=args.nn_per_image, using_feat=True, is_norm=False)
+            train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=balanced_sampler, pin_memory=True)
+            stml_iter = iter(train_loader)
+            stml_crit = STML_loss_simgcd(topk=args.nn_per_image, view=args.n_views)
+            # loader = train_loader
+
+ student.train()
+ # TODO: try loading two batches from two loaders, and compute loss on both
+ for batch_idx, batch in enumerate(loader):
+            if epoch > args.stml_warmup_ep:
+                try:
+                    stml_batch = next(stml_iter)
+                except StopIteration:
+                    # the nearest-neighbour loader is exhausted; restart it
+                    stml_iter = iter(train_loader)
+                    stml_batch = next(stml_iter)
+
+ images, class_labels, uq_idxs, mask_lab = batch
+ mask_lab = mask_lab[:, 0]
+
+ class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
+ images = torch.cat(images, dim=0).cuda(non_blocking=True)
+
+ with torch.cuda.amp.autocast(fp16_scaler is not None):
+ # __import__("ipdb").set_trace()
+ # student_proj, student_out = student(images)
+ # teacher_out = student_out.detach()
+
+ if epoch > args.stml_warmup_ep:
+ # __import__("ipdb").set_trace()
+ stml_images, stml_labels, stml_uq_idxs, stml_mask_lab = stml_batch
+ stml_images = torch.cat(stml_images, dim=0).cuda(non_blocking=True)
+
+ all_images = torch.cat([images, stml_images], dim=0)
+ all_proj, all_out = student(all_images)
+ student_proj, stml_proj = all_proj.chunk(2)
+ student_out, stml_out = all_out.chunk(2)
+ teacher_out = student_out.detach()
+
+ # representation learning, contextuality
+ # PCA as the teacher
+ # with torch.no_grad():
+ # U, S, V = torch.pca_lowrank(stml_proj, q=stml_proj.shape[-1] // 2, )
+ # t_emb = torch.matmul(stml_proj, V)
+
+ # cls head as the teacher
+ t_emb = stml_out
+ stml_loss = stml_crit(stml_proj, t_emb, torch.cat([stml_uq_idxs] * args.n_views))
+ else:
+ stml_loss = torch.zeros(1).cuda()[0]
+
+ student_proj, student_out = student(images)
+ teacher_out = student_out.detach()
+
+ # clustering, sup
+ sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
+ sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
+ cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
+
+ # clustering, unsup
+ cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
+ avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
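+                    # mean-entropy maximisation regulariser: log(K) minus the entropy
+                    # of the mean prediction, i.e. KL(avg_probs || uniform), pushing
+                    # the marginal class distribution toward uniform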
+ me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
+ cluster_loss += args.memax_weight * me_max_loss
+
+ # represent learning, unsup
+ contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
+ contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
+
+ if epoch > args.stml_warmup_ep:
+ # rep learning, proto
+ kmeans_loss = kmeans_loss_f(student_proj.chunk(2)[0], student_proj.chunk(2)[1], student_out.chunk(2)[0].argmax(1))
+ else:
+ kmeans_loss = torch.zeros(1).cuda()[0]
+
+ # representation learning, sup
+ student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
+ student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
+ sup_con_labels = class_labels[mask_lab]
+ sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)
+
+ pstr = ''
+ pstr += f'cls_loss: {cls_loss.item():.4f} '
+ pstr += f'cluster_loss: {cluster_loss.item():.4f} '
+ pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
+ pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
+ pstr += f'stml_loss: {stml_loss.item():.4f} '
+ pstr += f'kmeans_loss: {kmeans_loss.item():.4f}'
+
+ loss = 0
+ loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss
+ loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss
+ loss += args.stml_weight * stml_loss
+ loss += args.stml_weight * kmeans_loss
+
+ # Train acc
+ loss_record.update(loss.item(), class_labels.size(0))
+ optimizer.zero_grad()
+ if fp16_scaler is None:
+ loss.backward()
+ optimizer.step()
+ else:
+ fp16_scaler.scale(loss).backward()
+ fp16_scaler.step(optimizer)
+ fp16_scaler.update()
+
+            if batch_idx % args.print_freq == 0:
+                args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
+                            .format(epoch, batch_idx, len(loader), loss.item(), pstr))
+
+ args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg))
+
+ args.logger.info('Testing on unlabelled examples in the training data...')
+ all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
+ # args.logger.info('Testing on disjoint test set...')
+ # all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args)
+
+
+ args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
+ # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
+
+ # Step schedule
+ exp_lr_scheduler.step()
+
+ save_dict = {
+ 'model': student.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch + 1,
+ }
+
+ torch.save(save_dict, args.model_path)
+ args.logger.info("model saved to {}.".format(args.model_path))
+
+ # if old_acc_test > best_test_acc_lab:
+ #
+ # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
+ # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
+ #
+ # torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
+ # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
+ #
+ # # inductive
+ # best_test_acc_lab = old_acc_test
+ # # transductive
+ # best_train_acc_lab = old_acc
+ # best_train_acc_ubl = new_acc
+ # best_train_acc_all = all_acc
+ #
+ # args.logger.info(f'Exp Name: {args.exp_name}')
+ # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
+
+
+def test(model, test_loader, epoch, save_name, args):
+
+ model.eval()
+
+ preds, targets = [], []
+ mask = np.array([])
+ for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
+ images = images.cuda(non_blocking=True)
+ with torch.no_grad():
+ _, logits = model(images)
+ preds.append(logits.argmax(1).cpu().numpy())
+ targets.append(label.cpu().numpy())
+        mask = np.append(mask, np.array([x.item() in range(len(args.train_classes)) for x in label]))
+
+ preds = np.concatenate(preds)
+ targets = np.concatenate(targets)
+ all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
+ T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
+ args=args)
+
+ return all_acc, old_acc, new_acc
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('--batch_size', default=128, type=int)
+ parser.add_argument('--num_workers', default=8, type=int)
+    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b'])
+
+ parser.add_argument('--warmup_model_dir', type=str, default=None)
+    parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
+ parser.add_argument('--prop_train_labels', type=float, default=0.5)
+ parser.add_argument('--use_ssb_splits', action='store_true', default=True)
+
+ parser.add_argument('--grad_from_block', type=int, default=11)
+ parser.add_argument('--lr', type=float, default=0.1)
+ parser.add_argument('--gamma', type=float, default=0.1)
+ parser.add_argument('--momentum', type=float, default=0.9)
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
+ parser.add_argument('--epochs', default=200, type=int)
+ parser.add_argument('--exp_root', type=str, default=exp_root)
+ parser.add_argument('--transform', type=str, default='imagenet')
+ parser.add_argument('--sup_weight', type=float, default=0.35)
+ parser.add_argument('--n_views', default=2, type=int)
+
+ parser.add_argument('--memax_weight', type=float, default=2)
+ parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
+    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.')
+ parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
+
+ parser.add_argument('--nn_per_image', type=int, default=4)
+ parser.add_argument('--stml_weight', type=float, default=0.1)
+ parser.add_argument('--stml_warmup_ep', type=int, default=50)
+
+ parser.add_argument('--fp16', action='store_true', default=False)
+ parser.add_argument('--print_freq', default=10, type=int)
+ parser.add_argument('--exp_name', default=None, type=str)
+
+ # ----------------------
+ # INIT
+ # ----------------------
+ args = parser.parse_args()
+ device = torch.device('cuda:0')
+ args = get_class_splits(args)
+
+ args.num_labeled_classes = len(args.train_classes)
+ args.num_unlabeled_classes = len(args.unlabeled_classes)
+
+ init_experiment(args, runner_name=['simgcd'])
+ args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
+
+ torch.backends.cudnn.benchmark = True
+
+ # ----------------------
+ # BASE MODEL
+ # ----------------------
+ args.interpolation = 3
+ args.crop_pct = 0.875
+
+
+ backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
+
+ if args.warmup_model_dir is not None:
+ args.logger.info(f'Loading weights from {args.warmup_model_dir}')
+ backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
+
+ # NOTE: Hardcoded image size as we do not finetune the entire ViT model
+ args.image_size = 224
+ args.feat_dim = 768
+ args.num_mlp_layers = 3
+ args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
+
+ # ----------------------
+ # HOW MUCH OF BASE MODEL TO FINETUNE
+ # ----------------------
+ for m in backbone.parameters():
+ m.requires_grad = False
+
+ # Only finetune layers from block 'args.grad_from_block' onwards
+ for name, m in backbone.named_parameters():
+ if 'block' in name:
+ block_num = int(name.split('.')[1])
+ if block_num >= args.grad_from_block:
+ m.requires_grad = True
+
+
+    args.logger.info('model built')
+
+ # --------------------
+ # CONTRASTIVE TRANSFORM
+ # --------------------
+ train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
+ train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
+ # --------------------
+ # DATASETS
+ # --------------------
+ train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name,
+ train_transform,
+ test_transform,
+ args)
+
+ # --------------------
+ # SAMPLER
+ # Sampler which balances labelled and unlabelled examples in each batch
+ # --------------------
+ label_len = len(train_dataset.labelled_dataset)
+ unlabelled_len = len(train_dataset.unlabelled_dataset)
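+    # labelled samples get weight 1 and unlabelled ones label_len / unlabelled_len,
+    # so both subsets carry equal total weight (roughly 50/50 batches in expectation)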
+ sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
+ sample_weights = torch.DoubleTensor(sample_weights)
+ sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
+
+ # --------------------
+ # DATALOADERS
+ # --------------------
+ simgcd_train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
+ sampler=sampler, drop_last=True, pin_memory=True)
+ train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, pin_memory=True)
+ test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
+ batch_size=256, shuffle=False, pin_memory=False)
+ # test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
+ # batch_size=256, shuffle=False, pin_memory=False)
+
+ # ----------------------
+ # PROJECTION HEAD
+ # ----------------------
+ projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers)
+ model = nn.Sequential(backbone, projector).to(device)
+
+ # ----------------------
+ # TRAIN
+ # ----------------------
+ # train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args)
+ train(model, simgcd_train_loader, train_loader, train_dataset, None, test_loader_unlabelled, args)
diff --git a/train_mp.py b/train_mp.py
new file mode 100644
index 0000000..c405ef1
--- /dev/null
+++ b/train_mp.py
@@ -0,0 +1,325 @@
+import argparse
+import os
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from data.augmentations import get_transform
+from data.get_datasets import get_datasets, get_class_splits
+
+from util.general_utils import AverageMeter, init_experiment, DistributedWeightedSampler
+from util.cluster_and_log_utils import log_accs_from_preds
+from config import exp_root
+from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('--batch_size', default=128, type=int)
+ parser.add_argument('--num_workers', default=8, type=int)
+ parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b'])
+
+ parser.add_argument('--warmup_model_dir', type=str, default=None)
+    parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
+ parser.add_argument('--prop_train_labels', type=float, default=0.5)
+ parser.add_argument('--use_ssb_splits', action='store_true', default=True)
+
+ parser.add_argument('--grad_from_block', type=int, default=11)
+ parser.add_argument('--lr', type=float, default=0.1)
+ parser.add_argument('--gamma', type=float, default=0.1)
+ parser.add_argument('--momentum', type=float, default=0.9)
+ parser.add_argument('--weight_decay', type=float, default=1e-4)
+ parser.add_argument('--epochs', default=200, type=int)
+ parser.add_argument('--exp_root', type=str, default=exp_root)
+ parser.add_argument('--transform', type=str, default='imagenet')
+ parser.add_argument('--sup_weight', type=float, default=0.35)
+ parser.add_argument('--n_views', default=2, type=int)
+
+ parser.add_argument('--memax_weight', type=float, default=2)
+ parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
+    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.')
+ parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
+
+ parser.add_argument('--fp16', action='store_true', default=False)
+ parser.add_argument('--print_freq', default=10, type=int)
+ parser.add_argument('--exp_name', default=None, type=str)
+
+ # ----------------------
+ # INIT
+ # ----------------------
+ args = parser.parse_args()
+ args = get_class_splits(args)
+
+ args.num_labeled_classes = len(args.train_classes)
+ args.num_unlabeled_classes = len(args.unlabeled_classes)
+
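+    # LOCAL_RANK is exported by torchrun / torch.distributed.launch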
+ if os.environ["LOCAL_RANK"] is not None:
+ args.local_rank = int(os.environ["LOCAL_RANK"])
+
+ return args
+
+
+def main(args):
+ # ----------------------
+ # BASE MODEL
+ # ----------------------
+ args.interpolation = 3
+ args.crop_pct = 0.875
+
+ backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
+
+ if args.warmup_model_dir is not None:
+ if dist.get_rank() == 0:
+ args.logger.info(f'Loading weights from {args.warmup_model_dir}')
+ backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
+
+ # NOTE: Hardcoded image size as we do not finetune the entire ViT model
+ args.image_size = 224
+ args.feat_dim = 768
+ args.num_mlp_layers = 3
+ args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
+
+ # ----------------------
+ # HOW MUCH OF BASE MODEL TO FINETUNE
+ # ----------------------
+ for m in backbone.parameters():
+ m.requires_grad = False
+
+ # Only finetune layers from block 'args.grad_from_block' onwards
+ for name, m in backbone.named_parameters():
+ if 'block' in name:
+ block_num = int(name.split('.')[1])
+ if block_num >= args.grad_from_block:
+ m.requires_grad = True
+
+ if dist.get_rank() == 0:
+        args.logger.info('model built')
+
+ # --------------------
+ # CONTRASTIVE TRANSFORM
+ # --------------------
+ train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
+ train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
+ # --------------------
+ # DATASETS
+ # --------------------
+ train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name,
+ train_transform,
+ test_transform,
+ args)
+
+ # --------------------
+ # SAMPLER
+ # Sampler which balances labelled and unlabelled examples in each batch
+ # --------------------
+ label_len = len(train_dataset.labelled_dataset)
+ unlabelled_len = len(train_dataset.unlabelled_dataset)
+ sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
+ sample_weights = torch.DoubleTensor(sample_weights)
+ train_sampler = DistributedWeightedSampler(train_dataset, sample_weights, num_samples=len(train_dataset))
+ unlabelled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabelled_train_examples_test)
+ # test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
+ # --------------------
+ # DATALOADERS
+ # --------------------
+ train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
+ sampler=train_sampler, drop_last=True, pin_memory=True)
+ unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, batch_size=256,
+ shuffle=False, sampler=unlabelled_train_sampler, pin_memory=False)
+ # test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256,
+ # shuffle=False, sampler=test_sampler, pin_memory=False)
+
+ # ----------------------
+ # PROJECTION HEAD
+ # ----------------------
+ projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers)
+ model = nn.Sequential(backbone, projector).cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
+
+ params_groups = get_params_groups(model)
+ optimizer = torch.optim.SGD(
+ params_groups,
+ lr=args.lr * (args.batch_size * dist.get_world_size() / 128), # linear scaling rule
+ momentum=args.momentum,
+ weight_decay=args.weight_decay
+ )
+
+ fp16_scaler = None
+ if args.fp16:
+ fp16_scaler = torch.cuda.amp.GradScaler()
+
+ exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ T_max=args.epochs,
+ eta_min=args.lr * (args.batch_size * dist.get_world_size() / 128) * 1e-3,
+ )
+
+ cluster_criterion = DistillLoss(
+ args.warmup_teacher_temp_epochs,
+ args.epochs,
+ args.n_views,
+ args.warmup_teacher_temp,
+ args.teacher_temp,
+ )
+
+ # # inductive
+ # best_test_acc_lab = 0
+ # # transductive
+ # best_train_acc_lab = 0
+ # best_train_acc_ubl = 0
+ # best_train_acc_all = 0
+
+ for epoch in range(args.epochs):
+ train_sampler.set_epoch(epoch)
+ train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler, cluster_criterion, epoch, args)
+
+ unlabelled_train_sampler.set_epoch(epoch)
+ # test_sampler.set_epoch(epoch)
+ if dist.get_rank() == 0:
+ args.logger.info('Testing on unlabelled examples in the training data...')
+ all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
+ # if dist.get_rank() == 0:
+ # args.logger.info('Testing on disjoint test set...')
+ # all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args)
+
+ if dist.get_rank() == 0:
+ args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
+ # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
+
+        if dist.get_rank() == 0:
+            # only rank 0 runs init_experiment, so only it has args.logger / args.model_path
+            save_dict = {
+                'model': model.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'epoch': epoch + 1,
+            }
+
+            torch.save(save_dict, args.model_path)
+            args.logger.info("model saved to {}.".format(args.model_path))
+
+ # if old_acc_test > best_test_acc_lab and dist.get_rank() == 0:
+ # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
+ # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
+ #
+ # torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
+ # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
+ #
+ # # inductive
+ # best_test_acc_lab = old_acc_test
+ # # transductive
+ # best_train_acc_lab = old_acc
+ # best_train_acc_ubl = new_acc
+ # best_train_acc_all = all_acc
+ #
+ # if dist.get_rank() == 0:
+ # args.logger.info(f'Exp Name: {args.exp_name}')
+ # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
+
+
+def train(student, train_loader, optimizer, scaler, scheduler, cluster_criterion, epoch, args):
+ loss_record = AverageMeter()
+
+ student.train()
+ for batch_idx, batch in enumerate(train_loader):
+ images, class_labels, uq_idxs, mask_lab = batch
+ mask_lab = mask_lab[:, 0]
+
+ class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
+ images = torch.cat(images, dim=0).cuda(non_blocking=True)
+
+ with torch.cuda.amp.autocast(scaler is not None):
+ student_proj, student_out = student(images)
+ teacher_out = student_out.detach()
+
+ # clustering, sup
+ sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
+ sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
+ cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
+
+ # clustering, unsup
+ cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
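+            # me-max regulariser: entropy of the mean prediction encourages uniform cluster usage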
+ avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
+ me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
+ cluster_loss += args.memax_weight * me_max_loss
+
+            # representation learning, unsup
+ contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
+ contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
+
+ # representation learning, sup
+ student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
+ student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
+ sup_con_labels = class_labels[mask_lab]
+ sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)
+
+ pstr = ''
+ pstr += f'cls_loss: {cls_loss.item():.4f} '
+ pstr += f'cluster_loss: {cluster_loss.item():.4f} '
+ pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
+ pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
+
+ loss = 0
+ loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss
+ loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss
+
+ # Train acc
+ loss_record.update(loss.item(), class_labels.size(0))
+ optimizer.zero_grad()
+ if scaler is None:
+ loss.backward()
+ optimizer.step()
+ else:
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ if batch_idx % args.print_freq == 0 and dist.get_rank() == 0:
+ args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
+ .format(epoch, batch_idx, len(train_loader), loss.item(), pstr))
+ # Step schedule
+ scheduler.step()
+
+ if dist.get_rank() == 0:
+ args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg))
+
+
+def test(model, test_loader, epoch, save_name, args):
+
+ model.eval()
+
+ preds, targets = [], []
+ mask = np.array([])
+ for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
+ images = images.cuda(non_blocking=True)
+ with torch.no_grad():
+ _, logits = model(images)
+ preds.append(logits.argmax(1).cpu().numpy())
+ targets.append(label.cpu().numpy())
+        mask = np.append(mask, np.array([x.item() in range(len(args.train_classes)) for x in label]))
+
+ preds = np.concatenate(preds)
+ targets = np.concatenate(targets)
+ all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
+ T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
+ args=args)
+
+ return all_acc, old_acc, new_acc
+
+
+if __name__ == '__main__':
+ args = get_parser()
+
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
+ cudnn.benchmark = True
+
+ if dist.get_rank() == 0:
+ init_experiment(args, runner_name=['simgcd'])
+ args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
+
+ main(args)
\ No newline at end of file
diff --git a/util/New_Kmeans.py b/util/New_Kmeans.py
new file mode 100644
index 0000000..402b779
--- /dev/null
+++ b/util/New_Kmeans.py
@@ -0,0 +1,42 @@
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.autograd import Variable
+# import numpy as np
+# import torch_clustering
+#
+# class KMeans_Loss(nn.Module):
+#     def __init__(self, temperature = 0.5, n_cluster = 196):
+#         super(KMeans_Loss, self).__init__()
+#         self.temperature = temperature
+#         self.n_cluster = n_cluster
+#         self.clustering_model = torch_clustering.PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4,
+#                                                                n_clusters=self.n_cluster)
+#         self.pseudo_labels = None
+#
+#     def clustering(self, n_cluster, features):
+#
+#         clustering_model = torch_clustering.PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4,
+#                                                           n_clusters=self.n_cluster)
+#         pseudo_labels = clustering_model.fit_predict(features)
+#         self.pseudo_labels = pseudo_labels
+#         cluster_centers = clustering_model.cluster_centers_
+#         return pseudo_labels, cluster_centers
+#
+# def compute_cluster_loss(self,
+#
+#
+#
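+# scratch check of torch.chunk semantics as used in the training loop
+# (chunk(2) splits the two augmented views along the batch dimension)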
+import torch
+
+x = torch.rand((6, 10))  # suppose a tensor of shape (6, 10)
+print(x)
+print("==============")
+proj,out = x.chunk(2)
+print(proj)
+
+print("============")
+
+print(out)
+print("==============")
+print(out.argmax(1))  # per-row argmax of the second chunk; out[0].argmax(1) would fail on a 1-D row
\ No newline at end of file
diff --git a/util/cluster_and_log_utils.py b/util/cluster_and_log_utils.py
new file mode 100644
index 0000000..25874e3
--- /dev/null
+++ b/util/cluster_and_log_utils.py
@@ -0,0 +1,184 @@
+import torch
+import torch.distributed as dist
+import numpy as np
+from scipy.optimize import linear_sum_assignment as linear_assignment
+
+
+def all_sum_item(item):
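+    # sum a Python scalar across all distributed ranks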
+ item = torch.tensor(item).cuda()
+ dist.all_reduce(item)
+ return item.item()
+
+def split_cluster_acc_v2(y_true, y_pred, mask):
+ """
+ Calculate clustering accuracy. Require scikit-learn installed
+ First compute linear assignment on all data, then look at how good the accuracy is on subsets
+
+ # Arguments
+ mask: Which instances come from old classes (True) and which ones come from new classes (False)
+ y: true labels, numpy.array with shape `(n_samples,)`
+ y_pred: predicted labels, numpy.array with shape `(n_samples,)`
+
+ # Return
+ accuracy, in [0,1]
+ """
+ y_true = y_true.astype(int)
+
+ old_classes_gt = set(y_true[mask])
+ new_classes_gt = set(y_true[~mask])
+
+ assert y_pred.size == y_true.size
+ D = max(y_pred.max(), y_true.max()) + 1
+ w = np.zeros((D, D), dtype=int)
+ for i in range(y_pred.size):
+ w[y_pred[i], y_true[i]] += 1
+
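+    # Hungarian matching on the contingency matrix w: maximising the total number of
+    # matched instances is equivalent to minimising (w.max() - w)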
+ ind = linear_assignment(w.max() - w)
+ ind = np.vstack(ind).T
+
+ ind_map = {j: i for i, j in ind}
+ total_acc = sum([w[i, j] for i, j in ind])
+ total_instances = y_pred.size
+    if dist.is_available() and dist.is_initialized():
+        total_acc = all_sum_item(total_acc)
+        total_instances = all_sum_item(total_instances)
+ total_acc /= total_instances
+
+ old_acc = 0
+ total_old_instances = 0
+ for i in old_classes_gt:
+ old_acc += w[ind_map[i], i]
+ total_old_instances += sum(w[:, i])
+
+    if dist.is_available() and dist.is_initialized():
+        old_acc = all_sum_item(old_acc)
+        total_old_instances = all_sum_item(total_old_instances)
+ old_acc /= total_old_instances
+
+ new_acc = 0
+ total_new_instances = 0
+ for i in new_classes_gt:
+ new_acc += w[ind_map[i], i]
+ total_new_instances += sum(w[:, i])
+
+    if dist.is_available() and dist.is_initialized():
+        new_acc = all_sum_item(new_acc)
+        total_new_instances = all_sum_item(total_new_instances)
+ new_acc /= total_new_instances
+
+ return total_acc, old_acc, new_acc
+
+
+def split_cluster_acc_v2_balanced(y_true, y_pred, mask):
+ """
+ Calculate clustering accuracy. Require scikit-learn installed
+ First compute linear assignment on all data, then look at how good the accuracy is on subsets
+
+ # Arguments
+ mask: Which instances come from old classes (True) and which ones come from new classes (False)
+ y: true labels, numpy.array with shape `(n_samples,)`
+ y_pred: predicted labels, numpy.array with shape `(n_samples,)`
+
+ # Return
+ accuracy, in [0,1]
+ """
+ y_true = y_true.astype(int)
+
+ old_classes_gt = set(y_true[mask])
+ new_classes_gt = set(y_true[~mask])
+
+ assert y_pred.size == y_true.size
+ D = max(y_pred.max(), y_true.max()) + 1
+ w = np.zeros((D, D), dtype=int)
+ for i in range(y_pred.size):
+ w[y_pred[i], y_true[i]] += 1
+
+ ind = linear_assignment(w.max() - w)
+ ind = np.vstack(ind).T
+
+ ind_map = {j: i for i, j in ind}
+
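+    # balanced variant: accuracy is averaged per class rather than per instance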
+ old_acc = np.zeros(len(old_classes_gt))
+ total_old_instances = np.zeros(len(old_classes_gt))
+ for idx, i in enumerate(old_classes_gt):
+ old_acc[idx] += w[ind_map[i], i]
+ total_old_instances[idx] += sum(w[:, i])
+
+ new_acc = np.zeros(len(new_classes_gt))
+ total_new_instances = np.zeros(len(new_classes_gt))
+ for idx, i in enumerate(new_classes_gt):
+ new_acc[idx] += w[ind_map[i], i]
+ total_new_instances[idx] += sum(w[:, i])
+
+    if dist.is_available() and dist.is_initialized():
+        old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda()
+        # the per-class counts are numpy arrays as well and must be tensors for all_reduce
+        total_old_instances = torch.from_numpy(total_old_instances).cuda()
+        total_new_instances = torch.from_numpy(total_new_instances).cuda()
+        dist.all_reduce(old_acc), dist.all_reduce(new_acc)
+        dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances)
+        old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy()
+        total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy()
+
+ total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances])
+ old_acc /= total_old_instances
+ new_acc /= total_new_instances
+ total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean()
+ return total_acc, old_acc, new_acc
+
+
+EVAL_FUNCS = {
+ 'v2': split_cluster_acc_v2,
+ 'v2b': split_cluster_acc_v2_balanced
+}
+
+def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None,
+ print_output=True, args=None):
+
+ """
+    Given a list of evaluation functions to use (e.g. ['v2', 'v2b']), evaluate and log ACC results
+
+ :param y_true: GT labels
+ :param y_pred: Predicted indices
+ :param mask: Which instances belong to Old and New classes
+ :param T: Epoch
+ :param eval_funcs: Which evaluation functions to use
+ :param save_name: What are we evaluating ACC on
+    :param args: experiment namespace providing the logger
+ :return:
+ """
+
+ mask = mask.astype(bool)
+ y_true = y_true.astype(int)
+ y_pred = y_pred.astype(int)
+
+ for i, f_name in enumerate(eval_funcs):
+
+ acc_f = EVAL_FUNCS[f_name]
+ all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
+ log_name = f'{save_name}_{f_name}'
+
+ if i == 0:
+ to_return = (all_acc, old_acc, new_acc)
+
+ if print_output:
+ print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
+            # log on the main process only (or always when not running distributed)
+            if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
+                if args is not None and hasattr(args, 'logger'):
+                    args.logger.info(print_str)
+                else:
+                    print(print_str)
+
+ return to_return
\ No newline at end of file
diff --git a/util/general_utils.py b/util/general_utils.py
new file mode 100644
index 0000000..26ccced
--- /dev/null
+++ b/util/general_utils.py
@@ -0,0 +1,384 @@
+import os
+import torch
+import inspect
+
+from datetime import datetime
+from torch.utils.data.sampler import Sampler
+from loguru import logger
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_experiment(args, runner_name=None, exp_id=None):
+ # Get filepath of calling script
+ if runner_name is None:
+ runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
+
+ root_dir = os.path.join(args.exp_root, *runner_name)
+
+ if not os.path.exists(root_dir):
+ os.makedirs(root_dir)
+
+ # Either generate a unique experiment ID, or use one which is passed
+ if exp_id is None:
+
+ if args.exp_name is None:
+ raise ValueError("Need to specify the experiment name")
+ # Unique identifier for experiment
+ now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \
+ datetime.now().strftime("%S.%f")[:-3] + ')'
+
+ log_dir = os.path.join(root_dir, 'log', now)
+ while os.path.exists(log_dir):
+ now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
+ datetime.now().strftime("%S.%f")[:-3] + ')'
+
+ log_dir = os.path.join(root_dir, 'log', now)
+
+ else:
+
+ log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
+
+ if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+
+
+ logger.add(os.path.join(log_dir, 'log.txt'))
+ args.logger = logger
+ args.log_dir = log_dir
+
+ # Instantiate directory to save models to
+ model_root_dir = os.path.join(args.log_dir, 'checkpoints')
+ if not os.path.exists(model_root_dir):
+ os.mkdir(model_root_dir)
+
+ args.model_dir = model_root_dir
+ args.model_path = os.path.join(args.model_dir, 'model.pt')
+
+ print(f'Experiment saved to: {args.log_dir}')
+
+ hparam_dict = {}
+
+ for k, v in vars(args).items():
+ if isinstance(v, (int, float, str, bool, torch.Tensor)):
+ hparam_dict[k] = v
+
+ print(runner_name)
+ print(args)
+
+ return args
+
+
+class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler):
+
+ def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None,
+ replacement=True, generator=None):
+ super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank)
+ if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
+ num_samples <= 0:
+ raise ValueError("num_samples should be a positive integer "
+ "value, but got num_samples={}".format(num_samples))
+ if not isinstance(replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(replacement))
+ self.weights = torch.as_tensor(weights, dtype=torch.double)
+ self.num_samples = num_samples
+ self.replacement = replacement
+ self.generator = generator
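+        # each rank keeps a strided slice of the weights and draws its own share of
+        # samples; __iter__ maps the local draws back to global dataset indices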
+ self.weights = self.weights[self.rank::self.num_replicas]
+ self.num_samples = self.num_samples // self.num_replicas
+
+ def __iter__(self):
+ rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
+ rand_tensor = self.rank + rand_tensor * self.num_replicas
+ yield from iter(rand_tensor.tolist())
+
+ def __len__(self):
+ return self.num_samples
+
+
+import numpy as np
+from tqdm import tqdm
+import torch.nn.functional as F
+class NNBatchSampler(Sampler):
+ """
+    BatchSampler that fills each minibatch with randomly chosen query images plus their nn_per_image nearest neighbours
+ """
+ def __init__(self, data_source, model, seen_dataloader, batch_size, nn_per_image = 5, using_feat = True, is_norm = False):
+ self.batch_size = batch_size
+ self.nn_per_image = nn_per_image
+ self.using_feat = using_feat
+ self.is_norm = is_norm
+ self.num_samples = data_source.__len__()
+ self.nn_matrix, self.dist_matrix = self._build_nn_matrix(model, seen_dataloader)
+
+ def __iter__(self):
+ for _ in range(len(self)):
+ yield self.sample_batch()
+
+ def _predict_batchwise(self, model, seen_dataloader):
+ device = "cuda"
+ model_is_training = model.training
+ model.eval()
+
+ ds = seen_dataloader.dataset
+ A = [[] for i in range(len(ds[0]))]
+ with torch.no_grad():
+ # extract batches (A becomes list of samples)
+ for batch in tqdm(seen_dataloader):
+ for i, J in enumerate(batch):
+ # i = 0: sz_batch * images
+ # i = 1: sz_batch * labels
+ # i = 2: sz_batch * indices
+ # i = 3: sz_batch * mask_lab
+ if i == 0:
+ J = J[0]
+ # move images to device of model (approximate device)
+ # if self.using_feat:
+ # J, _ = model(J.cuda())
+ # else:
+ J = model(J.cuda())
+
+ if self.is_norm:
+ J = F.normalize(J, p=2, dim=1)
+
+ for j in J:
+ A[i].append(j)
+
+        model.train(model_is_training)  # revert to previous training state
+
+ return [torch.stack(A[i]) for i in range(len(A))]
+
+ def _build_nn_matrix(self, model, seen_dataloader):
+ # calculate embeddings with model and get targets
+ X, T, _, _ = self._predict_batchwise(model, seen_dataloader)
+
+        # assign each image its K nearest neighbours under squared Euclidean distance
+        K = self.nn_per_image
+ nn_matrix = []
+ dist_matrix = []
+ xs = []
+
+        for x in X:
+            xs.append(x)
+            if len(xs) <= 5000:
+                continue
+            # flush a chunk of queries at once to bound peak memory
+            xs = torch.stack(xs, dim=0)
+            # squared Euclidean distances: ||xs||^2 + ||X||^2 - 2 * xs @ X^T
+            dist_emb = xs.pow(2).sum(1) + (-2) * X.mm(xs.t())
+            dist_emb = X.pow(2).sum(1) + dist_emb.t()
+            ind = dist_emb.topk(K, largest=False)[1].long().cpu()
+            dist = dist_emb.topk(K, largest=False)[0]
+            nn_matrix.append(ind)
+            dist_matrix.append(dist.cpu())
+            xs = []
+            del ind
+
+        # last (possibly partial) chunk
+        if len(xs) > 0:
+            xs = torch.stack(xs, dim=0)
+            dist_emb = xs.pow(2).sum(1) + (-2) * X.mm(xs.t())
+            dist_emb = X.pow(2).sum(1) + dist_emb.t()
+            ind = dist_emb.topk(K, largest=False)[1]
+            dist = dist_emb.topk(K, largest=False)[0]
+            nn_matrix.append(ind.long().cpu())
+            dist_matrix.append(dist.cpu())
+        nn_matrix = torch.cat(nn_matrix, dim=0)
+        dist_matrix = torch.cat(dist_matrix, dim=0)
+
+ return nn_matrix, dist_matrix
+
+
+ def sample_batch(self):
+ num_image = self.batch_size // self.nn_per_image
+ sampled_queries = np.random.choice(self.num_samples, num_image, replace=False)
+ sampled_indices = self.nn_matrix[sampled_queries].view(-1).tolist()
+
+ return sampled_indices
+
+ def __len__(self):
+ return self.num_samples // self.batch_size
+
+import torch.nn as nn
+class RC_STML(nn.Module):
+ def __init__(self, sigma, delta, view, disable_mu, topk):
+ super(RC_STML, self).__init__()
+ self.sigma = sigma
+ self.delta = delta
+ self.view = view
+ self.disable_mu = disable_mu
+ self.topk = topk
+
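+    # k-reciprocal neighbours of i: those j in i's top-k whose own top-k contains i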
+ def k_reciprocal_neigh(self, initial_rank, i, topk):
+ forward_k_neigh_index = initial_rank[i,:topk]
+ backward_k_neigh_index = initial_rank[forward_k_neigh_index,:topk]
+ fi = np.where(backward_k_neigh_index==i)[0]
+ return forward_k_neigh_index[fi]
+
+ def forward(self, s_emb, t_emb, idx, v2=False):
+ if v2:
+ return self.forward_v2(t_emb, s_emb)
+ if self.disable_mu:
+ s_emb = F.normalize(s_emb)
+ t_emb = F.normalize(t_emb)
+
+ N = len(s_emb)
+ S_dist = torch.cdist(s_emb, s_emb)
+ S_dist = S_dist / S_dist.mean(1, keepdim=True)
+
+ with torch.no_grad():
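+            # the teacher embedding defines pairwise similarity W_P (Gaussian kernel on
+            # distances); a mutual k-NN graph then yields contextual similarity W_C, and
+            # the final weight W averages the two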
+ T_dist = torch.cdist(t_emb, t_emb)
+ W_P = torch.exp(-T_dist.pow(2) / self.sigma)
+
+ batch_size = len(s_emb) // self.view
+ W_P_copy = W_P.clone()
+ W_P_copy[idx.unsqueeze(1) == idx.unsqueeze(1).t()] = 1
+
+ topk_index = torch.topk(W_P_copy, self.topk)[1]
+ topk_half_index = topk_index[:, :int(np.around(self.topk/2))]
+
+ W_NN = torch.zeros_like(W_P).scatter_(1, topk_index, torch.ones_like(W_P))
+ V = ((W_NN + W_NN.t())/2 == 1).float()
+
+ W_C_tilda = torch.zeros_like(W_P)
+ for i in range(N):
+ indNonzero = torch.where(V[i, :]!=0)[0]
+ W_C_tilda[i, indNonzero] = (V[:,indNonzero].sum(1) / len(indNonzero))[indNonzero]
+
+ W_C_hat = W_C_tilda[topk_half_index].mean(1)
+ W_C = (W_C_hat + W_C_hat.t())/2
+ W = (W_P + W_C)/2
+
+        # the loss itself must be computed outside torch.no_grad(), otherwise no
+        # gradient flows back through S_dist and the student embeddings
+        identity_matrix = torch.eye(N).cuda(non_blocking=True)
+        pos_weight = (W) * (1 - identity_matrix)
+        neg_weight = (1 - W) * (1 - identity_matrix)
+
+        # pull similar pairs together, push dissimilar pairs beyond the margin delta
+        pull_losses = torch.relu(S_dist).pow(2) * pos_weight
+        push_losses = torch.relu(self.delta - S_dist).pow(2) * neg_weight
+
+        loss = (pull_losses.sum() + push_losses.sum()) / (len(s_emb) * (len(s_emb)-1))
+
+        return loss
+
+
+    def forward_v2(self, probs, feats):
+        # build the distance matrix outside torch.no_grad() so the pull/push losses
+        # below stay differentiable w.r.t. the features
+        feats_dist = torch.cdist(feats, feats)
+        with torch.no_grad():
+            pseudo_labels = probs.argmax(1).cuda()
+            one_hot = torch.zeros(probs.shape).cuda().scatter(1, pseudo_labels.unsqueeze(1), 1.0)
+            # W_P[i, j] = 1 iff i and j share the same pseudo-label
+            W_P = torch.mm(one_hot, one_hot.t())
+            topk_index = torch.topk(feats_dist, self.topk)[1]
+            W_NN = torch.zeros_like(feats_dist).scatter_(1, topk_index, W_P)
+
+ W = ((W_NN + W_NN.t())/2 == 0.5).float()
+
+ N = len(probs)
+ identity_matrix = torch.eye(N).cuda(non_blocking=True)
+ pos_weight = (W) * (1 - identity_matrix)
+ neg_weight = (1 - W) * (1 - identity_matrix)
+
+ pull_losses = torch.relu(feats_dist).pow(2) * pos_weight
+ push_losses = torch.relu(self.delta - feats_dist).pow(2) * neg_weight
+
+ loss = (pull_losses.sum() + push_losses.sum()) / (len(probs) * (len(probs)-1))
+
+ return loss
+
+class KL_STML(nn.Module):
+ def __init__(self, disable_mu, temp=1):
+ super(KL_STML, self).__init__()
+ self.disable_mu = disable_mu
+ self.temp = temp
+
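+    # self-distillation: match the row-normalised distance structure of embedding f
+    # to that of g via a KL divergence on softened similarities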
+ def kl_div(self, A, B, T = 1):
+ log_q = F.log_softmax(A/T, dim=-1)
+ p = F.softmax(B/T, dim=-1)
+ kl_d = F.kl_div(log_q, p, reduction='sum') * T**2 / A.size(0)
+ return kl_d
+
+ def forward(self, s_f, s_g):
+ if self.disable_mu:
+ s_f, s_g = F.normalize(s_f), F.normalize(s_g)
+
+ N = len(s_f)
+ S_dist = torch.cdist(s_f, s_f)
+ S_dist = S_dist / S_dist.mean(1, keepdim=True)
+
+ S_bg_dist = torch.cdist(s_g, s_g)
+ S_bg_dist = S_bg_dist / S_bg_dist.mean(1, keepdim=True)
+
+ loss = self.kl_div(-S_dist, -S_bg_dist.detach(), T=1)
+
+ return loss
+
+class STML_loss(nn.Module):
+ def __init__(self, sigma, delta, view, disable_mu, topk):
+ super(STML_loss, self).__init__()
+ self.sigma = sigma
+ self.delta = delta
+ self.view = view
+ self.disable_mu = disable_mu
+ self.topk = topk
+ self.RC_criterion = RC_STML(sigma, delta, view, disable_mu, topk)
+ self.KL_criterion = KL_STML(disable_mu, temp=1)
+
+ def forward(self, s_f, s_g, t_g, idx):
+ # Relaxed contrastive loss for STML
+ loss_RC_f = self.RC_criterion(s_f, t_g, idx)
+ loss_RC_g = self.RC_criterion(s_g, t_g, idx)
+ loss_RC = (loss_RC_f + loss_RC_g)/2
+
+ # Self-Distillation for STML
+ loss_KL = self.KL_criterion(s_f, s_g)
+
+ loss = loss_RC + loss_KL
+
+ total_loss = dict(RC=loss_RC, KL=loss_KL, loss=loss)
+
+ return total_loss
+
+class STML_loss_simgcd(nn.Module):
+ def __init__(self, disable_mu=0, topk=4, sigma=1, delta=1, view=2, v2=True):
+ super().__init__()
+ self.sigma = sigma
+ self.delta = delta
+ self.view = view
+ self.disable_mu = disable_mu
+ self.topk = topk
+ self.v2 = v2
+ self.RC_criterion = RC_STML(sigma, delta, view, disable_mu, topk)
+ self.KL_criterion = KL_STML(disable_mu, temp=1)
+
+ def forward(self, s_g, t_g, idx):
+ # Relaxed contrastive loss for STML
+ loss_RC = self.RC_criterion(s_g, t_g, idx, v2=self.v2)
+
+ if not self.v2:
+ # Self-Distillation for STML
+ loss_KL = self.KL_criterion(s_g, t_g)
+ else:
+ loss_KL = 0.0
+ loss = loss_RC + loss_KL
+
+ # total_loss = dict(RC=loss_RC, KL=loss_KL, loss=loss)
+
+ return loss