From 6d968a7c2fc9adca910262663718eb6d820a9867 Mon Sep 17 00:00:00 2001 From: Clarence <104248624+Clarence-CV@users.noreply.github.com> Date: Wed, 12 Jun 2024 19:08:44 +0800 Subject: [PATCH] Add files via upload --- .../cifar10/cifar-10-batches-py/batches.meta | Bin 0 -> 158 bytes LICENSE | 21 + README.md | 86 ++++ config.py | 18 + data/augmentations/__init__.py | 38 ++ data/cifar.py | 195 +++++++++ data/cub.py | 203 +++++++++ data/data_utils.py | 40 ++ data/fgvc_aircraft.py | 270 ++++++++++++ data/get_datasets.py | 176 ++++++++ data/herbarium_19.py | 164 ++++++++ data/imagenet.py | 201 +++++++++ data/ssb_splits/aircraft_osr_splits.pkl | Bin 0 -> 683 bytes data/ssb_splits/cub_osr_splits.pkl | Bin 0 -> 1268 bytes data/ssb_splits/herbarium_19_class_splits.pkl | Bin 0 -> 7498 bytes data/ssb_splits/scars_osr_splits.pkl | Bin 0 -> 502 bytes data/stanford_cars.py | 166 ++++++++ kmeans_loss.py | 87 ++++ model.py | 231 +++++++++++ requirements.txt | 10 + scripts/run_aircraft.sh | 22 + scripts/run_cars.sh | 22 + scripts/run_cifar10.sh | 22 + scripts/run_cifar100.sh | 24 ++ scripts/run_cub.sh | 24 ++ scripts/run_herb19.sh | 22 + scripts/run_imagenet100.sh | 22 + scripts/run_imagenet1k.sh | 23 ++ torch_clustering/__base__.py | 89 ++++ torch_clustering/__init__.py | 85 ++++ torch_clustering/beta_mixture.py | 84 ++++ torch_clustering/faiss_kmeans.py | 147 +++++++ torch_clustering/gaussian_mixture.py | 222 ++++++++++ torch_clustering/kmeans/__init__.py | 1 + torch_clustering/kmeans/kmeans.py | 192 +++++++++ torch_clustering/kmeans/kmeans_plus_plus.py | 132 ++++++ train.py | 365 +++++++++++++++++ train_mp.py | 325 +++++++++++++++ util/New_Kmeans.py | 42 ++ util/cluster_and_log_utils.py | 184 +++++++++ util/general_utils.py | 384 ++++++++++++++++++ 41 files changed, 4339 insertions(+) create mode 100644 Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta create mode 100644 LICENSE create mode 100644 README.md create mode 100644 config.py create mode 100644 data/augmentations/__init__.py create mode 100644 data/cifar.py create mode 100644 data/cub.py create mode 100644 data/data_utils.py create mode 100644 data/fgvc_aircraft.py create mode 100644 data/get_datasets.py create mode 100644 data/herbarium_19.py create mode 100644 data/imagenet.py create mode 100644 data/ssb_splits/aircraft_osr_splits.pkl create mode 100644 data/ssb_splits/cub_osr_splits.pkl create mode 100644 data/ssb_splits/herbarium_19_class_splits.pkl create mode 100644 data/ssb_splits/scars_osr_splits.pkl create mode 100644 data/stanford_cars.py create mode 100644 kmeans_loss.py create mode 100644 model.py create mode 100644 requirements.txt create mode 100644 scripts/run_aircraft.sh create mode 100644 scripts/run_cars.sh create mode 100644 scripts/run_cifar10.sh create mode 100644 scripts/run_cifar100.sh create mode 100644 scripts/run_cub.sh create mode 100644 scripts/run_herb19.sh create mode 100644 scripts/run_imagenet100.sh create mode 100644 scripts/run_imagenet1k.sh create mode 100644 torch_clustering/__base__.py create mode 100644 torch_clustering/__init__.py create mode 100644 torch_clustering/beta_mixture.py create mode 100644 torch_clustering/faiss_kmeans.py create mode 100644 torch_clustering/gaussian_mixture.py create mode 100644 torch_clustering/kmeans/__init__.py create mode 100644 torch_clustering/kmeans/kmeans.py create mode 100644 torch_clustering/kmeans/kmeans_plus_plus.py create mode 100644 train.py create mode 100644 train_mp.py create mode 100644 util/New_Kmeans.py create mode 100644 
util/cluster_and_log_utils.py create mode 100644 util/general_utils.py diff --git a/Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta b/Contextuality-GCD-main/${DATASET_DIR}/cifar10/cifar-10-batches-py/batches.meta new file mode 100644 index 0000000000000000000000000000000000000000..4467a6ec2e886a9f14f25e31776fb0152d8ac64a GIT binary patch literal 158 zcmWm8OAdlC5CBkxA_(|NJcO*g3CmfUW?DvQER^ZToryN1rS-Id$f%7|y4k|Q$wYU%$P-BX2cFI`d9SCLoz$N4wBUc~>BF}rs o2RCvJ;^BWbP)yDT;ub`h%*qESqEGtCM}qSIc$vVbe$%Gg7eB5uW&i*H literal 0 HcmV?d00001 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a9ec139 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Xin Wen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..3b422db --- /dev/null +++ b/README.md @@ -0,0 +1,86 @@ +# Parametric Classification for Generalized Category Discovery: A Baseline Study + + +

+**Parametric Classification for Generalized Category Discovery: A Baseline Study** (ICCV 2023)
+
+By Xin Wen\*, Bingchen Zhao\*, and Xiaojuan Qi.
+
+
+![teaser](assets/teaser.jpg)
+
+Generalized Category Discovery (GCD) aims to discover novel categories in unlabelled datasets using knowledge learned from labelled samples.
+Previous studies argued that parametric classifiers are prone to overfitting to seen categories, and endorsed using a non-parametric classifier formed with semi-supervised $k$-means.
+
+However, in this study, we investigate the failure of parametric classifiers, verify the effectiveness of previous design choices when high-quality supervision is available, and identify unreliable pseudo-labels as a key problem. We demonstrate that two prediction biases exist: the classifier tends to predict seen classes more often, and it produces an imbalanced distribution across seen and novel categories.
+Based on these findings, we propose a simple yet effective parametric classification method that benefits from entropy regularisation, achieves state-of-the-art performance on multiple GCD benchmarks, and shows strong robustness to unknown class numbers.
+We hope the investigation and the proposed simple framework can serve as a strong baseline to facilitate future studies in this field.
+
+## Running
+
+### Dependencies
+
+```
+pip install -r requirements.txt
+```
+
+### Config
+
+Set the paths to the datasets and the desired log directories in ```config.py```.
+
+### Datasets
+
+We use fine-grained benchmarks in this paper, including:
+
+* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)
+
+We also use generic object recognition datasets, including:
+
+* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet-100/1K](https://image-net.org/download.php)
+
+### Scripts
+
+**Train the model**:
+
+```
+bash scripts/run_${DATASET_NAME}.sh
+```
+
+We found that selecting the model according to 'Old'-class performance can lead to over-fitting, and since 'New'-class labels on the held-out validation set should be assumed unavailable, we suggest not performing model selection and simply using the last-epoch model.
+
+## Results
+
+Our results:
+
+| Dataset | All (paper, 3 runs) | Old | New | All (current GitHub, 5 runs) | Old | New |
+| --- | --- | --- | --- | --- | --- | --- |
+| CIFAR10 | 97.1±0.0 | 95.1±0.1 | 98.1±0.1 | 97.0±0.1 | 93.9±0.1 | 98.5±0.1 |
+| CIFAR100 | 80.1±0.9 | 81.2±0.4 | 77.8±2.0 | 79.8±0.6 | 81.1±0.5 | 77.4±2.5 |
+| ImageNet-100 | 83.0±1.2 | 93.1±0.2 | 77.9±1.9 | 83.6±1.4 | 92.4±0.1 | 79.1±2.2 |
+| ImageNet-1K | 57.1±0.1 | 77.3±0.1 | 46.9±0.2 | 57.0±0.4 | 77.1±0.1 | 46.9±0.5 |
+| CUB | 60.3±0.1 | 65.6±0.9 | 57.7±0.4 | 61.5±0.5 | 65.7±0.5 | 59.4±0.8 |
+| Stanford Cars | 53.8±2.2 | 71.9±1.7 | 45.0±2.4 | 53.4±1.6 | 71.5±1.6 | 44.6±1.7 |
+| FGVC-Aircraft | 54.2±1.9 | 59.1±1.2 | 51.8±2.3 | 54.3±0.7 | 59.4±0.4 | 51.7±1.2 |
+| Herbarium 19 | 44.0±0.4 | 58.0±0.4 | 36.4±0.8 | 44.2±0.2 | 57.6±0.6 | 37.0±0.4 |
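+
+The 'All'/'Old'/'New' columns follow the standard GCD protocol: cluster predictions on the unlabelled data are matched to the ground-truth classes with the Hungarian algorithm, and accuracy is then reported over all instances ('All') and over the instances of seen ('Old') and novel ('New') classes. Below is a minimal, illustrative sketch of this metric; the implementation actually used for the numbers above lives in `util/cluster_and_log_utils.py`:
+
+```
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+def cluster_acc(y_true, y_pred):
+    """Clustering accuracy under an optimal one-to-one matching of cluster ids to classes."""
+    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
+    d = max(y_pred.max(), y_true.max()) + 1
+    w = np.zeros((d, d), dtype=np.int64)
+    for p, t in zip(y_pred, y_true):
+        w[p, t] += 1                                   # contingency table
+    row, col = linear_sum_assignment(w.max() - w)      # maximise matched counts
+    return w[row, col].sum() / y_pred.size
+```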
+ +## Citing this work + +If you find this repo useful for your research, please consider citing our paper: + +``` +@inproceedings{wen2023simgcd, + author = {Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan}, + title = {Parametric Classification for Generalized Category Discovery: A Baseline Study}, + booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, + year = {2023}, + pages = {16590-16600} +} +``` + +## Acknowledgements + +The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery. + +## License + +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. diff --git a/config.py b/config.py new file mode 100644 index 0000000..2a9e34a --- /dev/null +++ b/config.py @@ -0,0 +1,18 @@ +# ----------------- +# DATASET ROOTS +# ----------------- +cifar_10_root = '${DATASET_DIR}/cifar10' +cifar_100_root = '${DATASET_DIR}/cifar100' +cub_root = '${DATASET_DIR}/cub' +aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b' +car_root = '${DATASET_DIR}/cars' +herbarium_dataroot = '${DATASET_DIR}/herbarium_19' +imagenet_root = '${DATASET_DIR}/ImageNet' + +# OSR Split dir +osr_split_dir = 'data/ssb_splits' + +# ----------------- +# OTHER PATHS +# ----------------- +exp_root = 'dev_outputs' # All logs and checkpoints will be saved here \ No newline at end of file diff --git a/data/augmentations/__init__.py b/data/augmentations/__init__.py new file mode 100644 index 0000000..51e934a --- /dev/null +++ b/data/augmentations/__init__.py @@ -0,0 +1,38 @@ +from torchvision import transforms + +import torch + +def get_transform(transform_type='imagenet', image_size=32, args=None): + + if transform_type == 'imagenet': + + mean = (0.485, 0.456, 0.406) + std = (0.229, 0.224, 0.225) + interpolation = args.interpolation + crop_pct = args.crop_pct + + train_transform = transforms.Compose([ + transforms.Resize(int(image_size / crop_pct), interpolation), + transforms.RandomCrop(image_size), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter(), + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ]) + + test_transform = transforms.Compose([ + transforms.Resize(int(image_size / crop_pct), interpolation), + transforms.CenterCrop(image_size), + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ]) + + else: + + raise NotImplementedError + + return (train_transform, test_transform) \ No newline at end of file diff --git a/data/cifar.py b/data/cifar.py new file mode 100644 index 0000000..499f3ea --- /dev/null +++ b/data/cifar.py @@ -0,0 +1,195 @@ +from torchvision.datasets import CIFAR10, CIFAR100 +from copy import deepcopy +import numpy as np + +from data.data_utils import subsample_instances +from config import cifar_10_root, cifar_100_root + + +class CustomCIFAR10(CIFAR10): + + def __init__(self, *args, **kwargs): + + super(CustomCIFAR10, self).__init__(*args, **kwargs) + + self.uq_idxs = np.array(range(len(self))) + + def __getitem__(self, item): + + img, label = super().__getitem__(item) + uq_idx = self.uq_idxs[item] + + return img, label, uq_idx + + def __len__(self): + return len(self.targets) + + +class CustomCIFAR100(CIFAR100): + + def __init__(self, *args, **kwargs): + super(CustomCIFAR100, self).__init__(*args, **kwargs) + + self.uq_idxs = np.array(range(len(self))) + + def __getitem__(self, item): + img, label = super().__getitem__(item) + uq_idx = self.uq_idxs[item] + + return img, label, 
uq_idx + + def __len__(self): + return len(self.targets) + + +def subsample_dataset(dataset, idxs): + + # Allow for setting in which all empty set of indices is passed + + if len(idxs) > 0: + + dataset.data = dataset.data[idxs] + dataset.targets = np.array(dataset.targets)[idxs].tolist() + dataset.uq_idxs = dataset.uq_idxs[idxs] + + return dataset + + else: + + return None + + +def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): + + cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] + + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + + # dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + + +def get_train_val_indices(train_dataset, val_split=0.2): + + train_classes = np.unique(train_dataset.targets) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + + cls_idxs = np.where(train_dataset.targets == cls)[0] + + v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), + prop_train_labels=0.8, split_train_val=False, seed=0): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + + +def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), + prop_train_labels=0.8, split_train_val=False, seed=0): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True, download=True) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = 
subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False, download=True) + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + + +if __name__ == '__main__': + + x = get_cifar_100_datasets(None, None, split_train_val=False, + train_classes=range(80), prop_train_labels=0.5) + + print('Printing lens...') + for k, v in x.items(): + if v is not None: + print(f'{k}: {len(v)}') + + print('Printing labelled and unlabelled overlap...') + print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) + print('Printing total instances in train...') + print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) + + print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') + print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') + print(f'Len labelled set: {len(x["train_labelled"])}') + print(f'Len unlabelled set: {len(x["train_unlabelled"])}') \ No newline at end of file diff --git a/data/cub.py b/data/cub.py new file mode 100644 index 0000000..1198bbf --- /dev/null +++ b/data/cub.py @@ -0,0 +1,203 @@ +import os +import pandas as pd +import numpy as np +from copy import deepcopy + +from torchvision.datasets.folder import default_loader +from torchvision.datasets.utils import download_url +from torch.utils.data import Dataset + +from data.data_utils import subsample_instances +from config import cub_root + + +class CustomCub2011(Dataset): + base_folder = 'CUB_200_2011/images' + url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' + filename = 'CUB_200_2011.tgz' + tgz_md5 = '97eceeb196236b17998738112f37df78' + + def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=True): + + self.root = os.path.expanduser(root) + self.transform = transform + self.target_transform = target_transform + + self.loader = loader + self.train = train + + + if download: + self._download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' 
+ + ' You can use download=True to download it') + + self.uq_idxs = np.array(range(len(self))) + + def _load_metadata(self): + images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', + names=['img_id', 'filepath']) + image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), + sep=' ', names=['img_id', 'target']) + train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img']) + + data = images.merge(image_class_labels, on='img_id') + self.data = data.merge(train_test_split, on='img_id') + + if self.train: + self.data = self.data[self.data.is_training_img == 1] + else: + self.data = self.data[self.data.is_training_img == 0] + + def _check_integrity(self): + try: + self._load_metadata() + except Exception: + return False + + for index, row in self.data.iterrows(): + filepath = os.path.join(self.root, self.base_folder, row.filepath) + if not os.path.isfile(filepath): + print(filepath) + return False + return True + + def _download(self): + import tarfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + download_url(self.url, self.root, self.filename, self.tgz_md5) + + with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: + tar.extractall(path=self.root) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data.iloc[idx] + path = os.path.join(self.root, self.base_folder, sample.filepath) + target = sample.target - 1 # Targets start at 1 by default, so shift to 0 + img = self.loader(path) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target, self.uq_idxs[idx] + + +def subsample_dataset(dataset, idxs): + + mask = np.zeros(len(dataset)).astype('bool') + mask[idxs] = True + + dataset.data = dataset.data[mask] + dataset.uq_idxs = dataset.uq_idxs[mask] + + return dataset + + +def subsample_classes(dataset, include_classes=range(160)): + + include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199 + cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub] + + # TODO: For now have no target transform + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + + dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + + +def get_train_val_indices(train_dataset, val_split=0.2): + + train_classes = np.unique(train_dataset.data['target']) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + + cls_idxs = np.where(train_dataset.data['target'] == cls)[0] + + v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, + split_train_val=False, seed=0, download=False): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + 
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False) + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + +if __name__ == '__main__': + + x = get_cub_datasets(None, None, split_train_val=False, + train_classes=range(100), prop_train_labels=0.5) + + print('Printing lens...') + for k, v in x.items(): + if v is not None: + print(f'{k}: {len(v)}') + + print('Printing labelled and unlabelled overlap...') + print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) + print('Printing total instances in train...') + print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) + + print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}') + print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}') + print(f'Len labelled set: {len(x["train_labelled"])}') + print(f'Len unlabelled set: {len(x["train_unlabelled"])}') \ No newline at end of file diff --git a/data/data_utils.py b/data/data_utils.py new file mode 100644 index 0000000..d05d7a1 --- /dev/null +++ b/data/data_utils.py @@ -0,0 +1,40 @@ +import numpy as np +from torch.utils.data import Dataset + +def subsample_instances(dataset, prop_indices_to_subsample=0.8): + + np.random.seed(0) + subsample_indices = np.random.choice(range(len(dataset)), replace=False, + size=(int(prop_indices_to_subsample * len(dataset)),)) + + return subsample_indices + +class MergedDataset(Dataset): + + """ + Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them + Allows you to iterate over them in parallel + """ + + def __init__(self, labelled_dataset, unlabelled_dataset): + + self.labelled_dataset = labelled_dataset + self.unlabelled_dataset = unlabelled_dataset + self.target_transform = None + + def __getitem__(self, item): + + if item < len(self.labelled_dataset): + img, label, uq_idx = self.labelled_dataset[item] + labeled_or_not = 1 + + else: + + img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)] + labeled_or_not = 0 + + + return img, label, uq_idx, np.array([labeled_or_not]) + + def __len__(self): + return len(self.unlabelled_dataset) + len(self.labelled_dataset) 
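+
+# A minimal usage sketch (illustrative): `labelled` and `unlabelled` are assumed
+# to be datasets that return (img, label, uq_idx) triples, e.g. those built in
+# data/cifar.py. Indexing walks the labelled split first, and the final element
+# flags which split an instance came from:
+#
+#   merged = MergedDataset(labelled, unlabelled)
+#   img, label, uq_idx, mask = merged[0]               # mask == np.array([1]): labelled
+#   img, label, uq_idx, mask = merged[len(labelled)]   # mask == np.array([0]): unlabelled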
diff --git a/data/fgvc_aircraft.py b/data/fgvc_aircraft.py new file mode 100644 index 0000000..97dc745 --- /dev/null +++ b/data/fgvc_aircraft.py @@ -0,0 +1,270 @@ +import os +import numpy as np +from copy import deepcopy + +from torchvision.datasets.folder import default_loader +from torch.utils.data import Dataset + +from data.data_utils import subsample_instances +from config import aircraft_root + +def make_dataset(dir, image_ids, targets): + assert(len(image_ids) == len(targets)) + images = [] + dir = os.path.expanduser(dir) + for i in range(len(image_ids)): + item = (os.path.join(dir, 'data', 'images', + '%s.jpg' % image_ids[i]), targets[i]) + images.append(item) + return images + + +def find_classes(classes_file): + + # read classes file, separating out image IDs and class names + image_ids = [] + targets = [] + f = open(classes_file, 'r') + for line in f: + split_line = line.split(' ') + image_ids.append(split_line[0]) + targets.append(' '.join(split_line[1:])) + f.close() + + # index class names + classes = np.unique(targets) + class_to_idx = {classes[i]: i for i in range(len(classes))} + targets = [class_to_idx[c] for c in targets] + + return (image_ids, targets, classes, class_to_idx) + + +class FGVCAircraft(Dataset): + + """`FGVC-Aircraft `_ Dataset. + + Args: + root (string): Root directory path to dataset. + class_type (string, optional): The level of FGVC-Aircraft fine-grain classification + to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). + transform (callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g. ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in the root directory. If dataset is already downloaded, it is not + downloaded again. + """ + url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' + class_types = ('variant', 'family', 'manufacturer') + splits = ('train', 'val', 'trainval', 'test') + + def __init__(self, root, class_type='variant', split='train', transform=None, + target_transform=None, loader=default_loader, download=False): + if split not in self.splits: + raise ValueError('Split "{}" not found. Valid splits are: {}'.format( + split, ', '.join(self.splits), + )) + if class_type not in self.class_types: + raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( + class_type, ', '.join(self.class_types), + )) + self.root = os.path.expanduser(root) + self.class_type = class_type + self.split = split + self.classes_file = os.path.join(self.root, 'data', + 'images_%s_%s.txt' % (self.class_type, self.split)) + + if download: + self.download() + + (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) + samples = make_dataset(self.root, image_ids, targets) + + self.transform = transform + self.target_transform = target_transform + self.loader = loader + + self.samples = samples + self.classes = classes + self.class_to_idx = class_to_idx + self.train = True if split == 'train' else False + + self.uq_idxs = np.array(range(len(self))) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (sample, target) where target is class_index of the target class. 
+ """ + + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target, self.uq_idxs[index] + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ + os.path.exists(self.classes_file) + + def download(self): + """Download the FGVC-Aircraft data if it doesn't exist already.""" + from six.moves import urllib + import tarfile + + if self._check_exists(): + return + + # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz + print('Downloading %s ... (may take a few minutes)' % self.url) + parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) + tar_name = self.url.rpartition('/')[-1] + tar_path = os.path.join(parent_dir, tar_name) + data = urllib.request.urlopen(self.url) + + # download .tar.gz file + with open(tar_path, 'wb') as f: + f.write(data.read()) + + # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b + data_folder = tar_path.strip('.tar.gz') + print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) + tar = tarfile.open(tar_path) + tar.extractall(parent_dir) + + # if necessary, rename data folder to self.root + if not os.path.samefile(data_folder, self.root): + print('Renaming %s to %s ...' % (data_folder, self.root)) + os.rename(data_folder, self.root) + + # delete .tar.gz file + print('Deleting %s ...' 
% tar_path) + os.remove(tar_path) + + print('Done!') + + +def subsample_dataset(dataset, idxs): + + mask = np.zeros(len(dataset)).astype('bool') + mask[idxs] = True + + dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] + dataset.uq_idxs = dataset.uq_idxs[mask] + + return dataset + + +def subsample_classes(dataset, include_classes=range(60)): + + cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] + + # TODO: Don't transform targets for now + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + + dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + + +def get_train_val_indices(train_dataset, val_split=0.2): + + all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] + train_classes = np.unique(all_targets) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + cls_idxs = np.where(all_targets == cls)[0] + + v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8, + split_train_val=False, seed=0): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + +if __name__ == '__main__': + + x = get_aircraft_datasets(None, None, split_train_val=False) + + print('Printing lens...') + for k, v in x.items(): + if v is not None: + print(f'{k}: {len(v)}') + + print('Printing labelled and unlabelled overlap...') + print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) + print('Printing total instances in train...') + 
print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) + print('Printing number of labelled classes...') + print(len(set([i[1] for i in x['train_labelled'].samples]))) + print('Printing total number of classes...') + print(len(set([i[1] for i in x['train_unlabelled'].samples]))) diff --git a/data/get_datasets.py b/data/get_datasets.py new file mode 100644 index 0000000..dfe12eb --- /dev/null +++ b/data/get_datasets.py @@ -0,0 +1,176 @@ +from data.data_utils import MergedDataset + +from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets +from data.herbarium_19 import get_herbarium_datasets +from data.stanford_cars import get_scars_datasets +from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets +from data.cub import get_cub_datasets +from data.fgvc_aircraft import get_aircraft_datasets + +from copy import deepcopy +import pickle +import os + +from config import osr_split_dir + + +get_dataset_funcs = { + 'cifar10': get_cifar_10_datasets, + 'cifar100': get_cifar_100_datasets, + 'imagenet_100': get_imagenet_100_datasets, + 'imagenet_1k': get_imagenet_1k_datasets, + 'herbarium_19': get_herbarium_datasets, + 'cub': get_cub_datasets, + 'aircraft': get_aircraft_datasets, + 'scars': get_scars_datasets +} + + +def get_datasets(dataset_name, train_transform, test_transform, args): + + """ + :return: train_dataset: MergedDataset which concatenates labelled and unlabelled + test_dataset, + unlabelled_train_examples_test, + datasets + """ + + # + if dataset_name not in get_dataset_funcs.keys(): + raise ValueError + + # Get datasets + get_dataset_f = get_dataset_funcs[dataset_name] + datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, + train_classes=args.train_classes, + prop_train_labels=args.prop_train_labels, + split_train_val=False) + # Set target transforms: + target_transform_dict = {} + for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): + target_transform_dict[cls] = i + target_transform = lambda x: target_transform_dict[x] + + for dataset_name, dataset in datasets.items(): + if dataset is not None: + dataset.target_transform = target_transform + + # Train split (labelled and unlabelled classes) for training + train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), + unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) + + test_dataset = datasets['test'] + unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) + unlabelled_train_examples_test.transform = test_transform + + return train_dataset, test_dataset, unlabelled_train_examples_test, datasets + + +def get_class_splits(args): + + # For FGVC datasets, optionally return bespoke splits + if args.dataset_name in ('scars', 'cub', 'aircraft'): + if hasattr(args, 'use_ssb_splits'): + use_ssb_splits = args.use_ssb_splits + else: + use_ssb_splits = False + + # ------------- + # GET CLASS SPLITS + # ------------- + if args.dataset_name == 'cifar10': + + args.image_size = 32 + args.train_classes = range(5) + args.unlabeled_classes = range(5, 10) + + elif args.dataset_name == 'cifar100': + + args.image_size = 32 + args.train_classes = range(80) + args.unlabeled_classes = range(80, 100) + + elif args.dataset_name == 'herbarium_19': + + args.image_size = 224 + herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl') + + with open(herb_path_splits, 'rb') as handle: + class_splits = pickle.load(handle) + + args.train_classes = class_splits['Old'] + 
args.unlabeled_classes = class_splits['New'] + + elif args.dataset_name == 'imagenet_100': + + args.image_size = 224 + args.train_classes = range(50) + args.unlabeled_classes = range(50, 100) + + elif args.dataset_name == 'imagenet_1k': + + args.image_size = 224 + args.train_classes = range(500) + args.unlabeled_classes = range(500, 1000) + + elif args.dataset_name == 'scars': + + args.image_size = 224 + + if use_ssb_splits: + + split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl') + with open(split_path, 'rb') as handle: + class_info = pickle.load(handle) + + args.train_classes = class_info['known_classes'] + open_set_classes = class_info['unknown_classes'] + args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] + + else: + + args.train_classes = range(98) + args.unlabeled_classes = range(98, 196) + + elif args.dataset_name == 'aircraft': + + args.image_size = 224 + if use_ssb_splits: + + split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl') + with open(split_path, 'rb') as handle: + class_info = pickle.load(handle) + + args.train_classes = class_info['known_classes'] + open_set_classes = class_info['unknown_classes'] + args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] + + else: + + args.train_classes = range(50) + args.unlabeled_classes = range(50, 100) + + elif args.dataset_name == 'cub': + + args.image_size = 224 + + if use_ssb_splits: + + split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl') + with open(split_path, 'rb') as handle: + class_info = pickle.load(handle) + + args.train_classes = class_info['known_classes'] + open_set_classes = class_info['unknown_classes'] + args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] + + else: + + args.train_classes = range(100) + args.unlabeled_classes = range(100, 200) + + else: + + raise NotImplementedError + + return args diff --git a/data/herbarium_19.py b/data/herbarium_19.py new file mode 100644 index 0000000..1d87f85 --- /dev/null +++ b/data/herbarium_19.py @@ -0,0 +1,164 @@ +import os + +import torchvision +import numpy as np +from copy import deepcopy + +from data.data_utils import subsample_instances +from config import herbarium_dataroot + +class HerbariumDataset19(torchvision.datasets.ImageFolder): + + def __init__(self, *args, **kwargs): + + # Process metadata json for training images into a DataFrame + super().__init__(*args, **kwargs) + + self.uq_idxs = np.array(range(len(self))) + + def __getitem__(self, idx): + + img, label = super().__getitem__(idx) + uq_idx = self.uq_idxs[idx] + + return img, label, uq_idx + + +def subsample_dataset(dataset, idxs): + + mask = np.zeros(len(dataset)).astype('bool') + mask[idxs] = True + + dataset.samples = np.array(dataset.samples)[mask].tolist() + dataset.targets = np.array(dataset.targets)[mask].tolist() + + dataset.uq_idxs = dataset.uq_idxs[mask] + + dataset.samples = [[x[0], int(x[1])] for x in dataset.samples] + dataset.targets = [int(x) for x in dataset.targets] + + return dataset + + +def subsample_classes(dataset, include_classes=range(250)): + + cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes] + + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + + dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + + +def get_train_val_indices(train_dataset, 
val_instances_per_class=5): + + train_classes = list(set(train_dataset.targets)) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + + cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] + + # Have a balanced test set + v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8, + seed=0, split_train_val=False): + + np.random.seed(seed) + + # Init entire training set + train_dataset = HerbariumDataset19(transform=train_transform, + root=os.path.join(herbarium_dataroot, 'small-train')) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many instances of dominant class + train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + if split_train_val: + + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, + val_instances_per_class=5) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + else: + + train_dataset_labelled_split, val_dataset_labelled_split = None, None + + # Get unlabelled data + unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices))) + + # Get test dataset + test_dataset = HerbariumDataset19(transform=test_transform, + root=os.path.join(herbarium_dataroot, 'small-validation')) + + # Transform dict + unlabelled_classes = list(set(train_dataset.targets) - set(train_classes)) + target_xform_dict = {} + for i, k in enumerate(list(train_classes) + unlabelled_classes): + target_xform_dict[k] = i + + test_dataset.target_transform = lambda x: target_xform_dict[x] + train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x] + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + +if __name__ == '__main__': + + np.random.seed(0) + train_classes = np.random.choice(range(683,), size=(int(683 / 2)), replace=False) + + x = get_herbarium_datasets(None, None, train_classes=train_classes, + prop_train_labels=0.5) + + assert set(x['train_unlabelled'].targets) == set(range(683)) + + print('Printing lens...') + for k, v in x.items(): + if v is not None: + print(f'{k}: {len(v)}') + + print('Printing labelled and unlabelled overlap...') + print(set.intersection(set(x['train_labelled'].uq_idxs), 
set(x['train_unlabelled'].uq_idxs))) + print('Printing total instances in train...') + print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) + print('Printing number of labelled classes...') + print(len(set(x['train_labelled'].targets))) + print('Printing total number of classes...') + print(len(set(x['train_unlabelled'].targets))) + + print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') + print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') + print(f'Len labelled set: {len(x["train_labelled"])}') + print(f'Len unlabelled set: {len(x["train_unlabelled"])}') \ No newline at end of file diff --git a/data/imagenet.py b/data/imagenet.py new file mode 100644 index 0000000..a7d0489 --- /dev/null +++ b/data/imagenet.py @@ -0,0 +1,201 @@ +import torchvision +import numpy as np + +import os + +from copy import deepcopy +from data.data_utils import subsample_instances +from config import imagenet_root + + +class ImageNetBase(torchvision.datasets.ImageFolder): + + def __init__(self, root, transform): + + super(ImageNetBase, self).__init__(root, transform) + + self.uq_idxs = np.array(range(len(self))) + + def __getitem__(self, item): + + img, label = super().__getitem__(item) + uq_idx = self.uq_idxs[item] + + return img, label, uq_idx + + +def subsample_dataset(dataset, idxs): + + imgs_ = [] + for i in idxs: + imgs_.append(dataset.imgs[i]) + dataset.imgs = imgs_ + + samples_ = [] + for i in idxs: + samples_.append(dataset.samples[i]) + dataset.samples = samples_ + + # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs] + # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs] + + dataset.targets = np.array(dataset.targets)[idxs].tolist() + dataset.uq_idxs = dataset.uq_idxs[idxs] + + return dataset + + +def subsample_classes(dataset, include_classes=list(range(1000))): + + cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] + + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + + +def get_train_val_indices(train_dataset, val_split=0.2): + + train_classes = list(set(train_dataset.targets)) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + + cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] + + v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80), + prop_train_labels=0.8, split_train_val=False, seed=0): + + np.random.seed(seed) + + # Subsample imagenet dataset initially to include 100 classes + subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) + subsampled_100_classes = np.sort(subsampled_100_classes) + print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') + cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} + + # Init entire training set + imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) + whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) + + # Reset dataset + whole_training_set.samples = 
[(s[0], cls_map[s[1]]) for s in whole_training_set.samples] + whole_training_set.targets = [s[1] for s in whole_training_set.samples] + whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) + whole_training_set.target_transform = None + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) + test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) + + # Reset test set + test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples] + test_dataset.targets = [s[1] for s in test_dataset.samples] + test_dataset.uq_idxs = np.array(range(len(test_dataset))) + test_dataset.target_transform = None + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + + +def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500), + prop_train_labels=0.5, split_train_val=False, seed=0): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = 
ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + + + +if __name__ == '__main__': + + x = get_imagenet_100_datasets(None, None, split_train_val=False, + train_classes=range(50), prop_train_labels=0.5) + + print('Printing lens...') + for k, v in x.items(): + if v is not None: + print(f'{k}: {len(v)}') + + print('Printing labelled and unlabelled overlap...') + print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) + print('Printing total instances in train...') + print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) + + print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') + print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') + print(f'Len labelled set: {len(x["train_labelled"])}') + print(f'Len unlabelled set: {len(x["train_unlabelled"])}') \ No newline at end of file diff --git a/data/ssb_splits/aircraft_osr_splits.pkl b/data/ssb_splits/aircraft_osr_splits.pkl new file mode 100644 index 0000000000000000000000000000000000000000..2a70d7065207588cdbc1263a90d554931646a494 GIT binary patch literal 683 zcmY+CSyK~H5QW3OhCmYbeQ)+1!s0Fpeu9eIxLo51w+c!OG-Jy0fhwQ!yX_$*RleQ+ z&dlxZbMBw2YfVUyOe7ND3m3l%`Q!62ihQIMPAi236;x71H8nKSL<_C7k)fRqI_aX9 zKKdD8m@&qgV3H|jm}QQ67FcAJHF9jS#XauxfQLNdDf>LDb{?Iz7^}-=Xtgyr~>uj*i z4tIFWE>C#NJKhuco7;E8=(jdFZ4_SeifP^;eB>Dr@{JKjaeUz`dwk+GclqFpLANzK zKEH_kNj~yF^B0%CaC5i}&#oeEy4GenwTYABv_<~;=bff*SMVimb=oGesV1Z)-pOFb zX;xxK4Yo_Xq`?lSof6w^uuI~t4R$;2kvI*5y`;mE_Brh*6FMCrYvrKRAzG{)CS~P_ zKQ&`UZKN@Y?YG~?IW%Fy>772x&8kv{aS({?P>7rQT zYX+Ajw%Xvb(-nzBH@GTg?+vdxT^Gy$<{g-{A>}|!$~oN>dn`HjP1us#j7qxgPc5FE HVsZX2RdKs> literal 0 HcmV?d00001 diff --git a/data/ssb_splits/cub_osr_splits.pkl b/data/ssb_splits/cub_osr_splits.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8631178dac02a1a0fc47adc9052c8a83a933ef44 GIT binary patch literal 1268 zcmY+EXM2=Y6h#AsP(yE_cZL>vAhgiiqSv8>5Rw-Xl1#EQ3{4|NJSrfHqJW^Ff(nXQ zC?Y5bmcQkG_~xfO_c`~ybI#st-xFDyr%$v)Lr8R^@7);uDrJ zoS*rZH#x~i1fA$lITak`HLCcSZ|K1gmhuk=`J6ZCPFsGVk+1lb4gA13X0V+#%%m@q z8P5b}F_1rbpFDnKGhG?Yr@Y4>e9!CrPEXeJ8J+3CItH_w#r#br-|{YpSkR=SGo?=D-PkBN$lc;4ZTj@p#XE~Fk3w+<2 zq`EfOpmybIpW-qvIn5^e@SH2$;~7_(!V6Zih&x>3Hhb95C?0T}we+SVHBl*aFqCHMxWOq7k&xvCIhJ#c zCPs3M8uqb+mL!)ha;sJKjoGBSBAXnnXlzO{uYNRD);4F=sYRVrHCClCWzTwry4cj~ zp|1AyZm65hzZmLn#ao4X*p$*xPphs_FC%j+)Z1Q65B0GL4WYg!p-ZTrZORW7nVkEf z{>G&sG{EjY3=Op1OG1N8!m!X_Lr@~tItEpe3AhL&2$-l1iV zV8_sMr}JiLg&Uk7TIoEs4y|%T9))ldTSBWHkBOl*PTkzlTE}u%Xq}V!IJ7>gF-sd# zR&Y>gqidPaCSTnqwAoAh&=xyZ7~1N8m>t^Yq@53?T`LN0cY-g3cDPm++UaFZXqT6T zp>m7WFSOf7&V?%cG8_r*ah!*SD!pqC?KLlTp(?)tH$v5ZI8KEU*A9jDdFUjx-!D!! 
zRO8x-P_1{l&;dV4%R_Zm{#vNsA2fv;OytN=#`YfzHCm;bP?JH~7m_dA5o&g=C6rBa Hx#Ir-4a&Ly literal 0 HcmV?d00001 diff --git a/data/ssb_splits/herbarium_19_class_splits.pkl b/data/ssb_splits/herbarium_19_class_splits.pkl new file mode 100644 index 0000000000000000000000000000000000000000..e2980cf0b171781a0d974802a5c3d638fd46e841 GIT binary patch literal 7498 zcmZvhRdAFE6GiC`?!M^a?(Xh1H@Itn;2H?-F2M(Pmvwh{cXxMp>OaIMHchNww>HryQ8kh>tYsoOIl24!3$Kem zBGc@|=nmOu{(CJ`V$B6_LY4`+I%Hp9|FTO^9im@}Y7u=-6ukHat(S>vJ9gz|*8!rN zh_bS)8PQe-KPLK#s5{qpq;(%rPom94B{{VPFFr_=kHPG$|Bk^jv?j2tH_?*}{!Wz2 z;4xbLi8^uWLqrQ0>_PM;(IR#g;ndG)ouc(U>u+Q5at8YlJ<6_i?5a;|6ub5?xPaCV z?D{Wv=|{&7cAX$P$*wDDeL<@!cO9g42T>Q+=Oe1bU|pgFN7N#Ep26pMu?JCG1{)Gh zknF=5FKRjvj2nlRF4x4V6Z4rd5(CG*84;&i1xGVH(F;nbu&>(245t~&5K8gcCh{q zb{%H06|H1i!-#6~VtZOGXw~4wF+{T%{Dt*1X%!$EO_YyapAo%6l$9tq*LR`SpCc+W zIEiQ@gOiEwoI zr`gqtR#l==46fk_F?f<_B7=ijFIqRynn3g)x%4|z`q6rq^(BZN zXV-m1Ner$bx}7M&u4GOvO*DzmJ$3e##w zRGq;kw7#LWp1Yo;)rQu3)_0?ojn-LOjd(GI)@|(SMQbmu#zdDBjb+zZTHkW&bOtvN zeZ{V(L@kL9@M2b?TZpm~%_Vx47gy7oPIQ{V%0!bHtjC8_lhzTUpNP7!>-eQ3QZtg0 zQ|JEk-taB8|z|2Y>Z8@CAP)(*b#}?8M|V4?1_DGD2~L5I2C8&Y+N1J z#&vOH+#I*Wt#Mb}6Zggg@lZS*kHll~L_8DE#f$N3ycVy=8}Vkm74OFf@nL)#pU2nn zUHllo#$WMwoL453nUYMIO*u??OnFTOOa)DaOodHFOvOyaO(jgFOqZL=n97>UnaZ0g zm@1m8n5vqpnQEKrn(CPvm>Qazn3|i~n%bG#n>v^}nmU`hn!1^KntGY~nfjXsng*MO znueK%n^H_;O{u2wrYWWwrZm$m(`-|^X^v^0X})QpX|ZXEX{l+MX{BkEX}xKKDZ{kc zw8ga5w9~ZPw9mBPbjWnnbj)1xw8rt3{NnQk`SZo1QSm+5ZPy{7w3 zkC~n@J!N{v^sMQ5(@UmTO|O|=H@#u{!1ST%W78+5PfcH%em4DK`qT86>2K3{QznbXx>mVXyVkhYxiVZ^Tw7h+T-#kc fU3*;nT>D)ITnAl;T!&ppTt{6eT&F!Xz4reAaZu0T literal 0 HcmV?d00001 diff --git a/data/ssb_splits/scars_osr_splits.pkl b/data/ssb_splits/scars_osr_splits.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c0d279db772e6114133a631955b0bc4c9a151055 GIT binary patch literal 502 zcmY+<6Y8>xa)z39(nGCw?6soi=Tc8nImS6t&TY6xD(F0=&~!Wy5_nY?z!)= zC!TudrB~kh=&K)o`=hojIT~o9l{R`CqR2|C?DX1u&Gax=)OuCE#Uu4=aH+)qyc%|T zIFepd3SyNdK}Iyl3~R2Zd6o<8@I4;Mtf__ZlG4<=AS)VVho#AutD&Y^Ypai;Rw%Yw zwKqP*|J^3Xl3RkDXpkG0E<>ie^3+pbBaOAtPJ116)JbPubk$v0FMajX-#~*5Gu#Lx zjWXI8V~sQ31QSg%*%VV1m}a^e3e7ajYzr*3NW@}GEVWGBT4lcq*0u7t0Ew-2eap literal 0 HcmV?d00001 diff --git a/data/stanford_cars.py b/data/stanford_cars.py new file mode 100644 index 0000000..4979ae4 --- /dev/null +++ b/data/stanford_cars.py @@ -0,0 +1,166 @@ +import os +import pandas as pd +import numpy as np +from copy import deepcopy +from scipy import io as mat_io + +from torchvision.datasets.folder import default_loader +from torch.utils.data import Dataset + +from data.data_utils import subsample_instances +from config import car_root + +class CarsDataset(Dataset): + """ + Cars Dataset + """ + def __init__(self, train=True, limit=0, data_dir=car_root, transform=None): + + metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat') + data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/') + + self.loader = default_loader + self.data_dir = data_dir + self.data = [] + self.target = [] + self.train = train + + self.transform = transform + + if not isinstance(metas, str): + raise Exception("Train metas must be string location !") + labels_meta = mat_io.loadmat(metas) + + for idx, img_ in enumerate(labels_meta['annotations'][0]): + if limit: + if idx > limit: + break + + # self.data.append(img_resized) + self.data.append(data_dir + img_[5][0]) + # if self.mode == 'train': + self.target.append(img_[4][0][0]) + + self.uq_idxs 
= np.array(range(len(self))) + self.target_transform = None + + def __getitem__(self, idx): + + image = self.loader(self.data[idx]) + target = self.target[idx] - 1 + + if self.transform is not None: + image = self.transform(image) + + if self.target_transform is not None: + target = self.target_transform(target) + + idx = self.uq_idxs[idx] + + return image, target, idx + + def __len__(self): + return len(self.data) + + +def subsample_dataset(dataset, idxs): + + dataset.data = np.array(dataset.data)[idxs].tolist() + dataset.target = np.array(dataset.target)[idxs].tolist() + dataset.uq_idxs = dataset.uq_idxs[idxs] + + return dataset + + +def subsample_classes(dataset, include_classes=range(160)): + + include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195 + cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars] + + target_xform_dict = {} + for i, k in enumerate(include_classes): + target_xform_dict[k] = i + + dataset = subsample_dataset(dataset, cls_idxs) + + # dataset.target_transform = lambda x: target_xform_dict[x] + + return dataset + +def get_train_val_indices(train_dataset, val_split=0.2): + + train_classes = np.unique(train_dataset.target) + + # Get train/test indices + train_idxs = [] + val_idxs = [] + for cls in train_classes: + + cls_idxs = np.where(train_dataset.target == cls)[0] + + v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) + t_ = [x for x in cls_idxs if x not in v_] + + train_idxs.extend(t_) + val_idxs.extend(v_) + + return train_idxs, val_idxs + + +def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, + split_train_val=False, seed=0): + + np.random.seed(seed) + + # Init entire training set + whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True) + + # Get labelled training set which has subsampled classes, then subsample some indices from that + train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) + subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) + train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) + + # Split into training and validation sets + train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) + train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) + val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) + val_dataset_labelled_split.transform = test_transform + + # Get unlabelled data + unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) + train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) + + # Get test set for all classes + test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False) + + # Either split train into train and val or use test set as val + train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled + val_dataset_labelled = val_dataset_labelled_split if split_train_val else None + + all_datasets = { + 'train_labelled': train_dataset_labelled, + 'train_unlabelled': train_dataset_unlabelled, + 'val': val_dataset_labelled, + 'test': test_dataset, + } + + return all_datasets + +if __name__ == '__main__': + + x = get_scars_datasets(None, None, 
train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
+
+    print('Printing lens...')
+    for k, v in x.items():
+        if v is not None:
+            print(f'{k}: {len(v)}')
+
+    print('Printing labelled and unlabelled overlap...')
+    print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
+    print('Printing total instances in train...')
+    print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
+
+    print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
+    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
+    print(f'Len labelled set: {len(x["train_labelled"])}')
+    print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
\ No newline at end of file
diff --git a/kmeans_loss.py b/kmeans_loss.py
new file mode 100644
index 0000000..e0aeef2
--- /dev/null
+++ b/kmeans_loss.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from torch_clustering.kmeans.kmeans import PyTorchKMeans
+class Kmeans_Loss(nn.Module):
+    def __init__(self, temperature=0.5, n_clusters=196):
+        super(Kmeans_Loss, self).__init__()
+        self.temperature = temperature
+        self.num_cluster = n_clusters
+        self.clustering_model = PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4, n_clusters=self.num_cluster)
+        self.psedo_labels = None
+
+    def clustering(self, features, n_clusters):
+
+        # kwargs = {
+        #     'metric': 'cosine' if self.l2_normalize else 'euclidean',
+        #     'distributed': True,
+        #     'random_state': 0,
+        #     'n_clusters': n_clusters,
+        #     'verbose': True
+        # }
+        clustering_model = PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4, n_clusters=self.num_cluster)
+
+        psedo_labels = clustering_model.fit_predict(features)  # features arrive without labels; clustering assigns them pseudo-labels during training
+        self.psedo_labels = psedo_labels
+        cluster_centers = clustering_model.cluster_centers_
+        return psedo_labels, cluster_centers
+
+    def compute_cluster_loss(self, q_centers, k_centers, temperature=0.5, psedo_labels=None):
+        # similarity matrix between the cluster centers, d_q
+        d_q = q_centers.mm(q_centers.T) / temperature
+
+        # similarity between each query center and its corresponding key center
+        d_k = (q_centers * k_centers).sum(dim=1) / temperature
+
+        # put each center's similarity to its own key center on the diagonal
+        d_q = d_q.float()
+        d_q[torch.arange(self.num_cluster), torch.arange(self.num_cluster)] = d_k
+
+        # find the clusters that received no pseudo-labels in this batch
+        zero_classes = torch.arange(self.num_cluster).cuda()[torch.sum(F.one_hot(torch.unique(psedo_labels),
+                                                                                 self.num_cluster), dim=0) == 0]
+
+        # mask out the columns of those empty clusters by setting their similarities to -10
+        mask = torch.zeros((self.num_cluster, self.num_cluster), dtype=torch.bool, device=d_q.device)
+        mask[:, zero_classes] = 1
+        d_q.masked_fill_(mask, -10)
+
+        # split the matrix into positive and negative similarities
+        pos = d_q.diag(0)
+        mask = torch.ones((self.num_cluster, self.num_cluster))
+        mask = mask.fill_diagonal_(0).bool()
+        neg = d_q[mask].reshape(-1, self.num_cluster - 1)
+
+        # contrastive loss over the cluster centers
+        loss = -pos + torch.logsumexp(torch.cat([pos.reshape(self.num_cluster, 1), neg], dim=1), dim=1)
+
+        # zero out the loss of clusters that have no samples
+        loss[zero_classes] = 0.
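+        # (empty clusters carry no signal in this batch; zeroing their rows and
+        # shrinking the denominator below averages over non-empty clusters only)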
+
+        # sum the losses and average over the number of non-empty clusters
+        loss = loss.sum() / (self.num_cluster - len(zero_classes))
+
+        return loss
+
+    def compute_centers(self, x, psedo_labels):
+        n_samples = x.size(0)
+        if len(psedo_labels.size()) > 1:
+            weight = psedo_labels.T
+        else:
+            weight = torch.zeros(self.num_cluster, n_samples).to(x)  # L, N
+            weight[psedo_labels, torch.arange(n_samples)] = 1
+        weight = F.normalize(weight, p=1, dim=1)  # l1 normalization
+        centers = torch.mm(weight, x)
+        centers = F.normalize(centers, dim=1)
+        return centers
+
+    def forward(self, im_q, im_k, psedo_labels):
+        batch_all_psedo_labels = psedo_labels
+        q_centers = self.compute_centers(im_q, batch_all_psedo_labels)
+        k_centers = self.compute_centers(im_k, batch_all_psedo_labels)
+
+        cluster_loss = self.compute_cluster_loss(q_centers, k_centers, self.temperature, batch_all_psedo_labels)
+
+        return cluster_loss
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..b7fb885
--- /dev/null
+++ b/model.py
@@ -0,0 +1,231 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class DINOHead(nn.Module):
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
+                 nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        elif nlayers != 0:
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        self.last_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+        if norm_last_layer:
+            self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            torch.nn.init.trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x_proj = self.mlp(x)
+        x = nn.functional.normalize(x, dim=-1, p=2)
+        # x = x.detach()
+        logits = self.last_layer(x)
+        return x_proj, logits
+
+
+class ContrastiveLearningViewGenerator(object):
+    """Take two random crops of one image as the query and key."""
+
+    def __init__(self, base_transform, n_views=2):
+        self.base_transform = base_transform
+        self.n_views = n_views
+
+    def __call__(self, x):
+        if not isinstance(self.base_transform, list):
+            return [self.base_transform(x) for i in range(self.n_views)]
+        else:
+            return [self.base_transform[i](x) for i in range(self.n_views)]
+
+class SupConLoss(torch.nn.Module):
+    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+    It also supports the unsupervised contrastive loss in SimCLR
+    From: https://github.com/HobbitLong/SupContrast"""
+    def __init__(self, temperature=0.07, contrast_mode='all',
+                 base_temperature=0.07):
+        super(SupConLoss, self).__init__()
+        self.temperature = temperature
+        self.contrast_mode = contrast_mode
+        self.base_temperature = base_temperature
+
+    def forward(self, features, labels=None, mask=None):
+        """Compute loss for model. If both `labels` and `mask` are None,
+        it degenerates to SimCLR unsupervised loss:
+        https://arxiv.org/pdf/2002.05709.pdf
+        Args:
+            features: hidden vector of shape [bsz, n_views, ...].
+            labels: ground truth of shape [bsz].
+            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+                has the same class as sample i. Can be asymmetric.
+        Returns:
+            A loss scalar.
+        """
+
+        device = (torch.device('cuda')
+                  if features.is_cuda
+                  else torch.device('cpu'))
+
+        if len(features.shape) < 3:
+            raise ValueError('`features` needs to be [bsz, n_views, ...],'
+                             'at least 3 dimensions are required')
+        if len(features.shape) > 3:
+            features = features.view(features.shape[0], features.shape[1], -1)
+
+        batch_size = features.shape[0]
+        if labels is not None and mask is not None:
+            raise ValueError('Cannot define both `labels` and `mask`')
+        elif labels is None and mask is None:
+            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+        elif labels is not None:
+            labels = labels.contiguous().view(-1, 1)
+            if labels.shape[0] != batch_size:
+                raise ValueError('Num of labels does not match num of features')
+            mask = torch.eq(labels, labels.T).float().to(device)
+        else:
+            mask = mask.float().to(device)
+
+        contrast_count = features.shape[1]
+        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+        if self.contrast_mode == 'one':
+            anchor_feature = features[:, 0]
+            anchor_count = 1
+        elif self.contrast_mode == 'all':
+            anchor_feature = contrast_feature
+            anchor_count = contrast_count
+        else:
+            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
+
+        # compute logits
+        anchor_dot_contrast = torch.div(
+            torch.matmul(anchor_feature, contrast_feature.T),
+            self.temperature)
+
+        # for numerical stability
+        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+        logits = anchor_dot_contrast - logits_max.detach()
+
+        # tile mask
+        mask = mask.repeat(anchor_count, contrast_count)
+        # mask-out self-contrast cases
+        logits_mask = torch.scatter(
+            torch.ones_like(mask),
+            1,
+            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+            0
+        )
+        mask = mask * logits_mask
+
+        # compute log_prob
+        exp_logits = torch.exp(logits) * logits_mask
+        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+        # compute mean of log-likelihood over positive
+        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
+
+        # loss
+        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
+        loss = loss.view(anchor_count, batch_size).mean()
+
+        return loss
+
+
+def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
+
+    b_ = 0.5 * int(features.size(0))
+
+    labels = torch.cat([torch.arange(b_) for i in range(n_views)], dim=0)
+    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
+    labels = labels.to(device)
+
+    features = F.normalize(features, dim=1)
+
+    similarity_matrix = torch.matmul(features, features.T)
+
+    # discard the main diagonal from both: labels and similarities matrix
+    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
+    labels = labels[~mask].view(labels.shape[0], -1)
+    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
+
+    # select and combine multiple positives
+    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
+
+    # select only the negatives
+    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
+
+    logits = torch.cat([positives, negatives], dim=1)
+    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
+
+    logits = logits / temperature
+    return logits, labels
+
+
+def get_params_groups(model):
+    regularized = []
+    not_regularized = []
+    for name, param in
model.named_parameters(): + if not param.requires_grad: + continue + # we do not regularize biases nor Norm parameters + if name.endswith(".bias") or len(param.shape) == 1: + not_regularized.append(param) + else: + regularized.append(param) + return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] + + +class DistillLoss(nn.Module): + def __init__(self, warmup_teacher_temp_epochs, nepochs, + ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04, + student_temp=0.1): + super().__init__() + self.student_temp = student_temp + self.ncrops = ncrops + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + + def forward(self, student_output, teacher_output, epoch): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + student_out = student_output / self.student_temp + student_out = student_out.chunk(self.ncrops) + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + teacher_out = F.softmax(teacher_output / temp, dim=-1) + teacher_out = teacher_out.detach().chunk(2) + + total_loss = 0 + n_loss_terms = 0 + for iq, q in enumerate(teacher_out): + for v in range(len(student_out)): + if v == iq: + # we skip cases where student and teacher operate on the same view + continue + loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) + total_loss += loss.mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + return total_loss diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0b5a84e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +loguru +numpy +pandas +scikit_learn +scipy +torch +torchvision +matplotlib +munkres +tqdm \ No newline at end of file diff --git a/scripts/run_aircraft.sh b/scripts/run_aircraft.sh new file mode 100644 index 0000000..fc2a449 --- /dev/null +++ b/scripts/run_aircraft.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0 python train.py \ + --dataset_name 'aircraft' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name aircraft_simgcd diff --git a/scripts/run_cars.sh b/scripts/run_cars.sh new file mode 100644 index 0000000..db89c2e --- /dev/null +++ b/scripts/run_cars.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0 python train.py \ + --dataset_name 'scars' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name scars_simgcd diff --git a/scripts/run_cifar10.sh b/scripts/run_cifar10.sh new file mode 100644 index 0000000..55d0189 --- /dev/null +++ b/scripts/run_cifar10.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0 python train.py \ + --dataset_name 'cifar10' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 
'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name cifar10_simgcd diff --git a/scripts/run_cifar100.sh b/scripts/run_cifar100.sh new file mode 100644 index 0000000..b7e7315 --- /dev/null +++ b/scripts/run_cifar100.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=1 python train.py \ + --dataset_name 'cifar100' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 4 \ + --nn_per_image 8 \ + --stml_weight 0.1 \ + --exp_name cifar100_simgcd diff --git a/scripts/run_cub.sh b/scripts/run_cub.sh new file mode 100644 index 0000000..b29198d --- /dev/null +++ b/scripts/run_cub.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=1 python train.py \ + --dataset_name 'cub' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 2 \ + --nn_per_image 8 \ + --stml_weight 0.1 \ + --exp_name cub_simgcd diff --git a/scripts/run_herb19.sh b/scripts/run_herb19.sh new file mode 100644 index 0000000..2d11b64 --- /dev/null +++ b/scripts/run_herb19.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0 python train.py \ + --dataset_name 'herbarium_19' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' 'v2b' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name herb19_simgcd diff --git a/scripts/run_imagenet100.sh b/scripts/run_imagenet100.sh new file mode 100644 index 0000000..bedcb8a --- /dev/null +++ b/scripts/run_imagenet100.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0 python train.py \ + --dataset_name 'imagenet_100' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name imagenet100_simgcd diff --git a/scripts/run_imagenet1k.sh b/scripts/run_imagenet1k.sh new file mode 100644 index 0000000..bcb5da7 --- /dev/null +++ b/scripts/run_imagenet1k.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +set -e +set -x + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 train_mp.py \ + --dataset_name 'imagenet_1k' \ + --batch_size 128 \ + --grad_from_block 11 \ + --epochs 200 \ + --num_workers 8 \ + --use_ssb_splits \ + --sup_weight 0.35 \ + --weight_decay 5e-5 \ + --transform 'imagenet' \ + --lr 0.1 \ + --eval_funcs 'v2' \ + --warmup_teacher_temp 0.07 \ + --teacher_temp 0.04 \ + --warmup_teacher_temp_epochs 30 \ + --memax_weight 1 \ + --exp_name imagenet1k_simgcd \ + --print_freq 100 diff --git a/torch_clustering/__base__.py b/torch_clustering/__base__.py new file mode 100644 index 
0000000..e191308
--- /dev/null
+++ b/torch_clustering/__base__.py
@@ -0,0 +1,89 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File    : __base__.py
+@Author  : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email   : zzhuang19@fudan.edu.cn
+@Date    : 2022/10/19 12:20 PM
+'''
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+
+class BasicClustering:
+    def __init__(self,
+                 n_clusters,
+                 init='k-means++',
+                 n_init=10,
+                 random_state=0,
+                 max_iter=300,
+                 tol=1e-4,
+                 distributed=False,
+                 verbose=True):
+        '''
+        :param n_clusters:
+        :param init: {'k-means++', 'random'}, callable or array-like of shape \
+            (n_clusters, n_features), default='k-means++'
+            Method for initialization:
+            'k-means++' : selects initial cluster centers for k-means
+            clustering in a smart way to speed up convergence. See section
+            Notes in k_init for more details.
+            'random': choose `n_clusters` observations (rows) at random from data
+            for the initial centroids.
+            If an array is passed, it should be of shape (n_clusters, n_features)
+            and gives the initial centers.
+            If a callable is passed, it should take arguments X, n_clusters and a
+            random state and return an initialization.
+        :param n_init: int, default=10
+            Number of times the k-means algorithm will be run with different
+            centroid seeds. The final results will be the best output of
+            n_init consecutive runs in terms of inertia.
+        :param random_state: int, RandomState instance or None, default=None
+            Determines random number generation for centroid initialization. Use
+            an int to make the randomness deterministic.
+            See the scikit-learn Glossary.
+        :param max_iter:
+        :param tol:
+        :param verbose: int, default=0 Verbosity mode.
+        '''
+        self.n_clusters = n_clusters
+        self.n_init = n_init
+        self.max_iter = max_iter
+        self.tol = tol
+        self.cluster_centers_ = None
+        self.init = init
+        self.random_state = random_state
+        self.is_root_worker = True if not dist.is_initialized() else (dist.get_rank() == 0)
+        self.verbose = verbose and self.is_root_worker
+        self.distributed = distributed and dist.is_initialized()
+        if verbose and self.distributed and self.is_root_worker:
+            print('Perform K-means in distributed mode.')
+        self.world_size = dist.get_world_size() if self.distributed else 1
+        self.rank = dist.get_rank() if self.distributed else 0
+
+    def fit_predict(self, X):
+        pass
+
+    def distributed_sync(self, tensor):
+        tensors_gather = [torch.ones_like(tensor)
+                          for _ in range(torch.distributed.get_world_size())]
+        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+        output = torch.stack(tensors_gather)
+        return output
+
+
+def pairwise_cosine(x1: torch.Tensor, x2: torch.Tensor, pairwise=True):
+    x1 = F.normalize(x1)
+    x2 = F.normalize(x2)
+    if not pairwise:
+        return (1 - (x1 * x2).sum(dim=1))
+    return 1 - x1.mm(x2.T)
+
+
+def pairwise_euclidean(x1: torch.Tensor, x2: torch.Tensor, pairwise=True):
+    if not pairwise:
+        return ((x1 - x2) ** 2).sum(dim=1).sqrt()
+    return torch.cdist(x1, x2, p=2.)
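+
+
+if __name__ == '__main__':
+    # Minimal smoke test for the distance helpers above (an illustrative sketch;
+    # the tensor sizes here are arbitrary, not taken from the original code).
+    a, b = torch.randn(4, 16), torch.randn(3, 16)
+    assert pairwise_cosine(a, b).shape == (4, 3)       # cosine distance in [0, 2]
+    assert pairwise_euclidean(a, b).shape == (4, 3)    # L2 distance >= 0
+    # pairwise=False compares row i of x1 with row i of x2 and needs equal shapes
+    assert pairwise_cosine(a, a, pairwise=False).shape == (4,)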
diff --git a/torch_clustering/__init__.py b/torch_clustering/__init__.py new file mode 100644 index 0000000..e9a569e --- /dev/null +++ b/torch_clustering/__init__.py @@ -0,0 +1,85 @@ +# -*- coding: UTF-8 -*- +''' +@Project : torch_clustering +@File : __init__.py +@Author : Zhizhong Huang from Fudan University +@Homepage: https://hzzone.github.io/ +@Email : zzhuang19@fudan.edu.cn +@Date : 2022/10/19 12:21 PM +''' + +from .kmeans.kmeans import PyTorchKMeans +from .faiss_kmeans import FaissKMeans +from .gaussian_mixture import PyTorchGaussianMixture +from .beta_mixture import BetaMixture1D + +import numpy as np +from munkres import Munkres +from sklearn import metrics +import warnings + +def evaluate_clustering(label, pred, eval_metric=['nmi', 'acc', 'ari'], phase='train'): + mask = (label != -1) + label = label[mask] + pred = pred[mask] + results = {} + if 'nmi' in eval_metric: + nmi = metrics.normalized_mutual_info_score(label, pred, average_method='arithmetic') + results[f'{phase}_nmi'] = nmi + if 'ari' in eval_metric: + ari = metrics.adjusted_rand_score(label, pred) + results[f'{phase}_ari'] = ari + if 'f' in eval_metric: + f = metrics.fowlkes_mallows_score(label, pred) + results[f'{phase}_f'] = f + if 'acc' in eval_metric: + n_clusters = len(set(label)) + if n_clusters == len(set(pred)): + pred_adjusted = get_y_preds(label, pred, n_clusters=n_clusters) + acc = metrics.accuracy_score(pred_adjusted, label) + else: + acc = 0. + warnings.warn('TODO: the number of classes is not equal...') + results[f'{phase}_acc'] = acc + return results + + +def calculate_cost_matrix(C, n_clusters): + cost_matrix = np.zeros((n_clusters, n_clusters)) + # cost_matrix[i,j] will be the cost of assigning cluster i to label j + for j in range(n_clusters): + s = np.sum(C[:, j]) # number of examples in cluster i + for i in range(n_clusters): + t = C[i, j] + cost_matrix[j, i] = s - t + return cost_matrix + + +def get_cluster_labels_from_indices(indices): + n_clusters = len(indices) + cluster_labels = np.zeros(n_clusters) + for i in range(n_clusters): + cluster_labels[i] = indices[i][1] + return cluster_labels + + +def get_y_preds(y_true, cluster_assignments, n_clusters): + """ + Computes the predicted labels, where label assignments now + correspond to the actual labels in y_true (as estimated by Munkres) + cluster_assignments: array of labels, outputted by kmeans + y_true: true labels + n_clusters: number of clusters in the dataset + returns: a tuple containing the accuracy and confusion matrix, + in that order + """ + confusion_matrix = metrics.confusion_matrix(y_true, cluster_assignments, labels=None) + # compute accuracy based on optimal 1:1 assignment of clusters to labels + cost_matrix = calculate_cost_matrix(confusion_matrix, n_clusters) + indices = Munkres().compute(cost_matrix) + kmeans_to_true_cluster_labels = get_cluster_labels_from_indices(indices) + + if np.min(cluster_assignments) != 0: + cluster_assignments = cluster_assignments - np.min(cluster_assignments) + y_pred = kmeans_to_true_cluster_labels[cluster_assignments] + return y_pred diff --git a/torch_clustering/beta_mixture.py b/torch_clustering/beta_mixture.py new file mode 100644 index 0000000..87fc396 --- /dev/null +++ b/torch_clustering/beta_mixture.py @@ -0,0 +1,84 @@ +# -*- coding: UTF-8 -*- +''' +@Project : torch_clustering +@File : beta_mixture.py +@Author : Zhizhong Huang from Fudan University +@Homepage: https://hzzone.github.io/ +@Email : zzhuang19@fudan.edu.cn +@Date : 2022/10/19 12:21 PM +''' + +import numpy as np +import 
matplotlib.pyplot as plt
+import torch
+import torch.nn as nn
+
+
+class BetaMixture1D(object):
+    def __init__(self,
+                 max_iters=10,
+                 alphas_init=[1, 2],
+                 betas_init=[2, 1],
+                 weights_init=[0.5, 0.5]):
+        self.alphas = np.array(alphas_init, dtype=np.float64)
+        self.betas = np.array(betas_init, dtype=np.float64)
+        self.weight = np.array(weights_init, dtype=np.float64)
+        self.max_iters = max_iters
+        self.eps_nan = 1e-12
+
+    def fit_beta_weighted(self, x, w):
+        def weighted_mean(x, w):
+            return np.sum(w * x) / np.sum(w)
+
+        x_bar = weighted_mean(x, w)
+        s2 = weighted_mean((x - x_bar) ** 2, w)
+        alpha = x_bar * ((x_bar * (1 - x_bar)) / s2 - 1)
+        beta = alpha * (1 - x_bar) / x_bar
+        return alpha, beta
+
+    def likelihood(self, x, y):
+        import scipy.stats as stats
+        return stats.beta.pdf(x, self.alphas[y], self.betas[y])
+
+    def weighted_likelihood(self, x, y):
+        return self.weight[y] * self.likelihood(x, y)
+
+    def probability(self, x):
+        return sum(self.weighted_likelihood(x, y) for y in range(2))
+
+    def responsibilities(self, x):
+        r = np.array([self.weighted_likelihood(x, i) for i in range(2)])
+        # there are ~200 samples below that value
+        r[r <= self.eps_nan] = self.eps_nan
+        r /= r.sum(axis=0)
+        return r.T
+
+    def fit(self, x):
+        x = np.copy(x)
+
+        # EM on beta distributions is unstable with x == 0 or 1, so clamp to (eps, 1 - eps)
+        eps = 1e-4
+        x[x >= 1 - eps] = 1 - eps
+        x[x <= eps] = eps
+
+        for i in range(self.max_iters):
+            # E-step
+            r = self.responsibilities(x).T
+
+            # M-step
+            self.alphas[0], self.betas[0] = self.fit_beta_weighted(x, r[0])
+            self.alphas[1], self.betas[1] = self.fit_beta_weighted(x, r[1])
+            self.weight = r.sum(axis=1)
+            self.weight /= self.weight.sum()
+
+        return self
+
+    def plot(self):
+        x = np.linspace(0, 1, 100)
+        plt.plot(x, self.weighted_likelihood(x, 0), label='negative')
+        plt.plot(x, self.weighted_likelihood(x, 1), label='positive')
+        # plt.plot(x, self.probability(x), lw=2, label='mixture')
+        plt.legend()
+
+    def __repr__(self):
+        return 'BetaMixture1D(w={}, a={}, b={})'.format(self.weight, self.alphas, self.betas)
diff --git a/torch_clustering/faiss_kmeans.py b/torch_clustering/faiss_kmeans.py
new file mode 100644
index 0000000..85e6181
--- /dev/null
+++ b/torch_clustering/faiss_kmeans.py
@@ -0,0 +1,147 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File    : faiss_kmeans.py
+@Author  : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email   : zzhuang19@fudan.edu.cn
+@Date    : 2022/10/19 12:22 PM
+'''
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+try:
+    import faiss
+except ImportError:
+    print('faiss not installed')
+from .__base__ import BasicClustering
+
+
+class FaissKMeans(BasicClustering):
+    def __init__(self,
+                 metric='euclidean',
+                 n_clusters=8,
+                 n_init=10,
+                 max_iter=300,
+                 random_state=1234,
+                 distributed=False,
+                 verbose=True):
+        super().__init__(n_clusters=n_clusters,
+                         n_init=n_init,
+                         max_iter=max_iter,
+                         distributed=distributed,
+                         verbose=verbose)
+
+        if metric == 'euclidean':
+            self.spherical = False
+        elif metric == 'cosine':
+            self.spherical = True
+        else:
+            raise NotImplementedError
+        self.random_state = random_state
+
+    def apply_pca(self, X, dim):
+        n, d = X.shape
+        if self.spherical:
+            X = F.normalize(X, dim=1)
+        mat = faiss.PCAMatrix(d, dim)
+        mat.train(n, X)
+        X = mat.apply_py(X)
+        return X
+
+    def fit_predict(self, input: torch.Tensor, device=-1):
+        n, d = input.shape
+
+        assert isinstance(input, (torch.Tensor, np.ndarray))
+        is_torch_tensor = isinstance(input, torch.Tensor)
+        if is_torch_tensor:
+            if self.spherical:
+                input = F.normalize(input, dim=1)
+
+            if input.is_cuda:
+                device = input.device.index
+            X = input.cpu().numpy().astype(np.float32)
+        else:
+            if self.spherical:
+                X = input / np.linalg.norm(input, 2, axis=1)[:, np.newaxis]
+            else:
+                X = input
+            X = X.astype(np.float32)
+
+        random_states = torch.arange(self.world_size) + self.random_state
+        random_state = random_states[self.rank]
+        if device >= 0:
+            # faiss implementation of k-means
+            clus = faiss.Clustering(int(d), int(self.n_clusters))
+
+            # Change faiss seed at each k-means so that the randomly picked
+            # initialization centroids do not correspond to the same feature ids
+            # from an epoch to another.
+            # clus.seed = np.random.randint(1234)
+            clus.seed = int(random_state)
+
+            clus.niter = self.max_iter
+            clus.max_points_per_centroid = 10000000
+            clus.min_points_per_centroid = 10
+            clus.spherical = self.spherical
+            clus.nredo = self.n_init
+            clus.verbose = self.verbose
+            res = faiss.StandardGpuResources()
+            flat_config = faiss.GpuIndexFlatConfig()
+            flat_config.useFloat16 = False
+            flat_config.device = device
+            flat_config.verbose = self.verbose
+            flat_config.spherical = self.spherical
+            flat_config.nredo = self.n_init
+            index = faiss.GpuIndexFlatL2(res, d, flat_config)
+
+            # perform the training
+            clus.train(X, index)
+            D, I = index.search(X, 1)
+        else:
+            clus = faiss.Kmeans(d=d,
+                                k=self.n_clusters,
+                                niter=self.max_iter,
+                                nredo=self.n_init,
+                                verbose=self.verbose,
+                                spherical=self.spherical)
+            X = X.astype(np.float32)
+            clus.train(X)
+            # self.cluster_centers_ = self.kmeans.centroids
+            D, I = clus.index.search(X, 1)  # for each sample, find cluster distance and assignments
+
+        tensor_device = 'cpu' if device < 0 else f'cuda:{device}'
+
+        best_labels = torch.from_numpy(I.flatten()).to(tensor_device)
+        min_inertia = torch.from_numpy(D.flatten()).to(tensor_device).sum()
+        best_states = faiss.vector_to_array(clus.centroids).reshape(self.n_clusters, d)
+        best_states = torch.from_numpy(best_states).to(tensor_device)
+
+        if self.distributed:
+            min_inertia = self.distributed_sync(min_inertia)
+            best_idx = torch.argmin(min_inertia).item()
+            min_inertia = min_inertia[best_idx]
+            dist.broadcast(best_labels, src=best_idx)
+            dist.broadcast(best_states, src=best_idx)
+
+        if self.verbose:
+            print(f"Final min inertia {min_inertia.item()}.")
+
+        self.cluster_centers_ = best_states
+        return best_labels
+
+
+if __name__ == '__main__':
+    dist.init_process_group(backend='nccl', init_method='env://')
+    torch.cuda.set_device(dist.get_rank())
+    X = torch.randn(1280, 256).cuda()
+    clustering_model = FaissKMeans(metric='euclidean',
+                                   n_clusters=10,
+                                   n_init=2,
+                                   max_iter=1,
+                                   random_state=1234,
+                                   distributed=True,
+                                   verbose=True)
+    clustering_model.fit_predict(X)
diff --git a/torch_clustering/gaussian_mixture.py b/torch_clustering/gaussian_mixture.py
new file mode 100644
index 0000000..4ebd1d5
--- /dev/null
+++ b/torch_clustering/gaussian_mixture.py
@@ -0,0 +1,222 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File    : gaussian_mixture.py
+@Author  : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email   : zzhuang19@fudan.edu.cn
+@Date    : 2022/10/19 12:22 PM
+'''
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributions as D
+import torch.distributed as dist
+from .__base__ import BasicClustering
+from .kmeans.kmeans import PyTorchKMeans
+
+
+class PyTorchGaussianMixture(BasicClustering):
+    def __init__(self,
+                 covariance_type='diag',
metric='euclidean', + reg_covar=1e-6, + init='k-means++', + random_state=0, + n_clusters=8, + n_init=10, + max_iter=300, + tol=1e-4, + distributed=False, + verbose=True): + ''' + pytorch_gaussian_mixture = PyTorchGaussianMixture(metric='cosine', + covariance_type='diag', + reg_covar=1e-6, + init='k-means++', + random_state=0, + n_clusters=10, + n_init=10, + max_iter=300, + tol=1e-5, + verbose=True) + pseudo_labels = pytorch_gaussian_mixture.fit_predict(torch.from_numpy(features).cuda()) + :param metric: + :param reg_covar: + :param init: + :param random_state: + :param n_clusters: + :param n_init: + :param max_iter: + :param tol: + :param verbose: + ''' + super().__init__(n_clusters=n_clusters, + init=init, + distributed=distributed, + random_state=random_state, + n_init=n_init, + max_iter=max_iter, + tol=tol, + verbose=verbose) + self.reg_covar = reg_covar + self.metric = metric + self._estimate_gaussian_covariances = {'diag': self._estimate_gaussian_covariances_diag, + 'spherical': self._estimate_gaussian_covariances_spherical}[ + covariance_type] + self.covariances, self.weights, self.lower_bound_ = None, None, None + + def _estimate_gaussian_covariances_diag(self, resp: torch.Tensor, X: torch.Tensor, nk: torch.Tensor, + means: torch.Tensor): + avg_X2 = torch.matmul(resp.T, X * X) / nk[:, None] + avg_means2 = means ** 2 + avg_X_means = means * torch.matmul(resp.T, X) / nk[:, None] + return avg_X2 - 2 * avg_X_means + avg_means2 + self.reg_covar + # N * K * L + # return (((X.unsqueeze(1) - means.unsqueeze(0)) ** 2) * resp.unsqueeze(-1)).sum(0) / nk[:, None] + self.reg_covar + + def _estimate_gaussian_covariances_spherical(self, resp: torch.Tensor, X: torch.Tensor, nk: torch.Tensor, + means: torch.Tensor): + return self._estimate_gaussian_covariances_diag(resp, X, nk, means).mean(1, keepdim=True) + + def initialize(self, X: torch.Tensor, resp: torch.Tensor): + """Initialization of the Gaussian mixture parameters. 
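+        Responsibilities here are typically the hard one-hot assignments produced
+        by the k-means initialization in ``fit_predict``; means, covariances and
+        weights are then estimated from them (with ``metric='cosine'`` the means
+        are re-normalized onto the unit sphere).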
+ Parameters + ---------- + X : array-like of shape (n_samples, n_features) + resp : array-like of shape (n_samples, n_components) + """ + n_samples, _ = X.shape + weights, means, covariances = self._estimate_gaussian_parameters(X, resp) + return means, covariances, weights + + def _estimate_gaussian_parameters(self, X: torch.Tensor, resp: torch.Tensor): + # N * K * L + nk = resp.sum(dim=0) + 10 * torch.finfo(resp.dtype).eps + # means = torch.sum(X[:, None, :] * resp[:, :, None], dim=0) / nk[:, None] + means = resp.T.mm(X) / nk[:, None] + if self.metric == 'cosine': + means = F.normalize(means, dim=-1) + covariances = self._estimate_gaussian_covariances(resp, X, nk, means) + weights = nk / X.size(0) + return weights, means, covariances + + def fit_predict(self, X: torch.Tensor): + if self.metric == 'cosine': + X = F.normalize(X, dim=1) + best_means, best_covariances, best_weights, best_resp = None, None, None, None + max_lower_bound = - float("Inf") + + g = torch.Generator() + g.manual_seed(self.random_state) + random_states = torch.randperm(10000, generator=g)[:self.n_init * self.world_size] + random_states = random_states[self.rank:self.n_init * self.world_size:self.world_size] + + for n_init in range(self.n_init): + + random_state = int(random_states[n_init]) + # KMeans init + pseudo_labels = PyTorchKMeans(metric=self.metric, + init=self.init, + n_clusters=self.n_clusters, + random_state=random_state, + n_init=self.n_init, + max_iter=self.max_iter, + tol=self.tol, + distributed=self.distributed, + verbose=self.verbose).fit_predict(X) + resp = F.one_hot(pseudo_labels, self.n_clusters).to(X) + means, covariances, weights = self.initialize(X, resp) + previous_lower_bound_ = self.log_likehood(resp.log()) + + for n_iter in range(self.max_iter): + # E step + log_resp = self._e_step(X, means, covariances, weights) + + resp = F.softmax(log_resp, dim=1) + + lower_bound_ = self.log_likehood(log_resp) + + shift = torch.abs(previous_lower_bound_ - lower_bound_) + + if shift < self.tol: + if self.verbose: + print('converge at Iteration {} with shift: {}'.format(n_iter, shift)) + break + + if self.verbose: + print(f'Iteration {n_iter}, loglikehood: {lower_bound_.item()}, shift: {shift.item()}') + previous_lower_bound_ = lower_bound_ + + if lower_bound_ > max_lower_bound: + max_lower_bound = lower_bound_ + best_means, best_covariances, best_weights, best_resp = \ + means, covariances, weights, resp + + # M step + means, covariances, weights = self._m_step(X, resp) + + if self.distributed: + max_lower_bound = self.distributed_sync(max_lower_bound) + best_idx = torch.argmax(max_lower_bound).item() + max_lower_bound = max_lower_bound[best_idx] + dist.broadcast(best_means, src=best_idx) + dist.broadcast(best_covariances, src=best_idx) + dist.broadcast(best_weights, src=best_idx) + dist.broadcast(best_resp, src=best_idx) + if self.verbose: + print(f"Final loglikehood {max_lower_bound.item()}.") + + if self.verbose: + print(f'Converged with loglikehood {max_lower_bound.item()}') + self.cluster_centers_, self.covariances, self.weights, self.lower_bound_ = \ + best_means, best_covariances, best_weights, max_lower_bound + return self.predict_score(X) + + def _e_step(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor): + estimate_precision_error_message = ( + "Fitting the mixture model failed because some components have " + "ill-defined empirical covariance (for instance caused by singleton " + "or collapsed samples). 
Try to decrease the number of components, " + "or increase reg_covar.") + if torch.any(torch.le(covariances, 0.0)): + raise ValueError(estimate_precision_error_message) + log_resp = self.log_prob(X, means, covariances, weights) + return log_resp + + def log_prob(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor): + log_resp = D.Normal(loc=means.unsqueeze(0), + scale=covariances.unsqueeze(0).sqrt()).log_prob(X.unsqueeze(1)).sum(dim=-1) + log_resp = log_resp + weights.unsqueeze(0).log() + return log_resp + + def log_prob_sklearn(self, X: torch.Tensor, means: torch.Tensor, covariances: torch.Tensor, weights: torch.Tensor): + n_samples, n_features = X.size() + n_components, _ = means.size() + + precisions_chol = 1. / torch.sqrt(covariances) + + log_det = torch.sum(precisions_chol.log(), dim=1) + precisions = precisions_chol ** 2 + log_prob = (torch.sum((means ** 2 * precisions), dim=1) - + 2. * torch.matmul(X, (means * precisions).T) + + torch.matmul(X ** 2, precisions.T)) + log_p = -.5 * (n_features * np.log(2 * np.pi) + log_prob) + log_det + weighted_log_p = log_p + weights.unsqueeze(0).log() + + # seems not work + # weighted_log_p = weighted_log_p - weighted_log_p.logsumexp(dim=1, keepdim=True) + return weighted_log_p + + def _m_step(self, X: torch.Tensor, resp: torch.Tensor): + n_samples, _ = X.shape + weights, means, covariances = self._estimate_gaussian_parameters(X, resp) + return means, covariances, weights + + def log_likehood(self, log_resp: torch.Tensor): + # N * K + return log_resp.logsumexp(dim=1).mean() + + def predict_score(self, X: torch.Tensor): + return F.softmax(self._e_step(X, self.cluster_centers_, self.covariances, self.weights), dim=1) diff --git a/torch_clustering/kmeans/__init__.py b/torch_clustering/kmeans/__init__.py new file mode 100644 index 0000000..53cfd8a --- /dev/null +++ b/torch_clustering/kmeans/__init__.py @@ -0,0 +1 @@ +from .kmeans import PyTorchKMeans diff --git a/torch_clustering/kmeans/kmeans.py b/torch_clustering/kmeans/kmeans.py new file mode 100644 index 0000000..68d5a2f --- /dev/null +++ b/torch_clustering/kmeans/kmeans.py @@ -0,0 +1,192 @@ +# -*- coding: UTF-8 -*- +''' +@Project : torch_clustering +@File : kmeans.py +@Author : Zhizhong Huang from Fudan University +@Homepage: https://hzzone.github.io/ +@Email : zzhuang19@fudan.edu.cn +@Date : 2022/10/19 12:23 PM +''' + +import numpy as np +import torch +import tqdm +import torch.distributed as dist +from ..__base__ import BasicClustering, pairwise_euclidean, pairwise_cosine +from .kmeans_plus_plus import _kmeans_plusplus + + +class PyTorchKMeans(BasicClustering): + def __init__(self, + metric='euclidean', + init='k-means++', + random_state=0, + n_clusters=8, + n_init=10, + max_iter=300, + tol=1e-4, + distributed=False, + verbose=True): + super().__init__(n_clusters=n_clusters, + init=init, + random_state=random_state, + n_init=n_init, + max_iter=max_iter, + tol=tol, + verbose=verbose, + distributed=distributed) + self.distance_metric = {'euclidean': pairwise_euclidean, 'cosine': pairwise_cosine}[metric] + # self.distance_metric = lambda a, b: torch.cdist(a, b, p=2.) 
+ if isinstance(self.init, (np.ndarray, torch.Tensor)): self.n_init = 1 + + def initialize(self, X: torch.Tensor, random_state: int): + num_samples = len(X) + if isinstance(self.init, str): + g = torch.Generator() + g.manual_seed(random_state) + if self.init == 'random': + indices = torch.randperm(num_samples, generator=g)[:self.n_clusters] + init_state = X[indices] + elif self.init == 'k-means++': + init_state, _ = _kmeans_plusplus(X, + random_state=random_state, + n_clusters=self.n_clusters, + pairwise_distance=self.distance_metric) + # init_state = X[torch.randperm(num_samples, generator=g)[0]].unsqueeze(0) + # for k in range(1, self.n_clusters): + # d = torch.min(self.distance_metric(X, init_state), dim=1)[0] + # init_state = torch.cat([init_state, X[torch.argmax(d)].unsqueeze(0)], dim=0) + else: + raise NotImplementedError + elif isinstance(self.init, (np.ndarray, torch.Tensor)): + init_state = self.init.to(X) + else: + raise NotImplementedError + + return init_state + + def fit_predict(self, X: torch.Tensor): + + tol = torch.mean(torch.var(X, dim=0)) * self.tol + + min_inertia, best_states, best_labels = float('Inf'), None, None + + random_states = torch.arange(self.n_init * self.world_size) + self.random_state + random_states = random_states[self.rank:len(random_states):self.world_size] + # g = torch.Generator() + # g.manual_seed(self.random_state) + # random_states = torch.randperm(10000, generator=g)[:self.n_init * self.world_size] + # random_states = random_states[self.rank:self.n_init * self.world_size:self.world_size] + + self.stats = {'state': [], 'inertia': [], 'label': []} + for n_init in range(self.n_init): + random_state = int(random_states[n_init]) + old_state = self.initialize(X, random_state=random_state) + old_labels, inertia = self.predict(X, old_state) + + labels = old_labels + + progress_bar = tqdm.tqdm(total=self.max_iter, disable=not self.verbose) + + for n_iter in range(self.max_iter): + + # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335/7 + # n_samples = X.size(0) + # weight = torch.zeros(self.n_clusters, n_samples, dtype=X.dtype, device=X.device) # L, N + # weight[labels, torch.arange(n_samples)] = 1 + # weight = F.normalize(weight, p=1, dim=1) # l1 normalization + # state = torch.mm(weight, X) # L, F + state = torch.zeros(self.n_clusters, X.size(1), dtype=X.dtype, device=X.device) + counts = torch.zeros(self.n_clusters, dtype=X.dtype, device=X.device) + 1e-6 + classes, classes_counts = torch.unique(labels, return_counts=True) + counts[classes] = classes_counts.to(X) + state.index_add_(0, labels, X) + state = state / counts.view(-1, 1) + + # d = self.distance_metric(X, state) + # inertia, labels = d.min(dim=1) + # inertia = inertia.sum() + labels, inertia = self.predict(X, state) + + if inertia < min_inertia: + min_inertia = inertia + best_states, best_labels = state, labels + + if self.verbose: + progress_bar.set_description( + f'nredo {n_init + 1}/{self.n_init:02d}, iteration {n_iter:03d} with inertia {inertia:.2f}') + progress_bar.update(n=1) + + center_shift = self.distance_metric(old_state, state, pairwise=False) + + if torch.equal(labels, old_labels): + # First check the labels for strict convergence. + if self.verbose: + print(f"Converged at iteration {n_iter}: strict convergence.") + break + else: + # center_shift = self.distance_metric(old_state, state).diag().sum() + # No strict convergence, check for tol based convergence. 
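+                    # (tol was rescaled by the mean per-feature variance of X at
+                    # the top of fit_predict, mirroring sklearn's relative tolerance)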
+                    # center_shift_tot = (center_shift ** 2).sum()
+                    center_shift_tot = center_shift.sum()
+                    if center_shift_tot <= tol:
+                        if self.verbose:
+                            print(
+                                f"Converged at iteration {n_iter}: center shift "
+                                f"{center_shift_tot} within tolerance {tol} "
+                                f"and min inertia {min_inertia.item()}."
+                            )
+                        break
+
+                old_labels[:] = labels
+                old_state = state
+            progress_bar.close()
+            self.stats['state'].append(old_state)
+            self.stats['inertia'].append(inertia)
+            self.stats['label'].append(old_labels)
+
+        self.stats['state'] = torch.stack(self.stats['state'])
+        self.stats['inertia'] = torch.stack(self.stats['inertia'])
+        self.stats['label'] = torch.stack(self.stats['label'])
+        if self.distributed:
+            min_inertia = self.distributed_sync(min_inertia)
+            best_idx = torch.argmin(min_inertia).item()
+            min_inertia = min_inertia[best_idx]
+            dist.broadcast(best_labels, src=best_idx)
+            dist.broadcast(best_states, src=best_idx)
+            self.stats['state'] = self.distributed_sync(self.stats['state'])
+            self.stats['inertia'] = self.distributed_sync(self.stats['inertia'])
+            self.stats['label'] = self.distributed_sync(self.stats['label'])
+
+        if self.verbose:
+            print(f"Final min inertia {min_inertia.item()}.")
+
+        self.cluster_centers_ = best_states
+        return best_labels
+
+    def predict(self, X: torch.Tensor, cluster_centers_=None):
+        if cluster_centers_ is None:
+            cluster_centers_ = self.cluster_centers_
+        split_size = min(4096, X.size(0))
+        inertia, pred_labels = 0., []
+        for f in X.split(split_size, dim=0):
+            d = self.distance_metric(f, cluster_centers_)
+            inertia_, labels_ = d.min(dim=1)
+            inertia += inertia_.sum()
+            pred_labels.append(labels_)
+        return torch.cat(pred_labels, dim=0), inertia
+
+
+if __name__ == '__main__':
+    torch.cuda.set_device(1)
+    clustering_model = PyTorchKMeans(metric='cosine',
+                                     init='k-means++',
+                                     random_state=0,
+                                     n_clusters=1000,
+                                     n_init=10,
+                                     max_iter=300,
+                                     tol=1e-4,
+                                     distributed=False,
+                                     verbose=True)
+    X = torch.randn(1280000, 256).cuda()
+    clustering_model.fit_predict(X)
diff --git a/torch_clustering/kmeans/kmeans_plus_plus.py b/torch_clustering/kmeans/kmeans_plus_plus.py
new file mode 100644
index 0000000..fa081df
--- /dev/null
+++ b/torch_clustering/kmeans/kmeans_plus_plus.py
@@ -0,0 +1,132 @@
+# -*- coding: UTF-8 -*-
+'''
+@Project : torch_clustering
+@File    : kmeans_plus_plus.py
+@Author  : Zhizhong Huang from Fudan University
+@Homepage: https://hzzone.github.io/
+@Email   : zzhuang19@fudan.edu.cn
+@Date    : 2022/10/19 12:23 PM
+'''
+
+import torch
+import numpy as np
+import warnings
+
+
+def stable_cumsum(arr, dim=None, rtol=1e-05, atol=1e-08):
+    """Use high precision for cumsum and check that final value matches sum.
+    Parameters
+    ----------
+    arr : array-like
+        To be cumulatively summed as flat.
+    dim : int, default=None
+        Dimension along which the cumulative sum is computed.
+        The default (None) is to compute the cumsum over the flattened array.
+    rtol : float, default=1e-05
+        Relative tolerance, see ``np.allclose``.
+    atol : float, default=1e-08
+        Absolute tolerance, see ``np.allclose``.
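+    Returns
+    -------
+    out : torch.Tensor
+        The cumulative sum along ``dim``, accumulated in float64.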
+ """ + if dim is None: + arr = arr.flatten() + dim = 0 + out = torch.cumsum(arr, dim=dim, dtype=torch.float64) + expected = torch.sum(arr, dim=dim, dtype=torch.float64) + if not torch.all(torch.isclose(out.take(torch.Tensor([-1]).long().to(arr.device)), + expected, rtol=rtol, + atol=atol, equal_nan=True)): + warnings.warn('cumsum was found to be unstable: ' + 'its last element does not correspond to sum', + RuntimeWarning) + return out + + +def _kmeans_plusplus(X, + n_clusters, + random_state, + pairwise_distance, + n_local_trials=None): + """Computational component for initialization of n_clusters by + k-means++. Prior validation of data is assumed. + Parameters + ---------- + X : {ndarray, sparse matrix} of shape (n_samples, n_features) + The data to pick seeds for. + n_clusters : int + The number of seeds to choose. + random_state : RandomState instance + The generator used to initialize the centers. + See :term:`Glossary `. + n_local_trials : int, default=None + The number of seeding trials for each center (except the first), + of which the one reducing inertia the most is greedily chosen. + Set to None to make the number of trials depend logarithmically + on the number of seeds (2+log(k)); this is the default. + Returns + ------- + centers : ndarray of shape (n_clusters, n_features) + The inital centers for k-means. + indices : ndarray of shape (n_clusters,) + The index location of the chosen centers in the data array X. For a + given index and center, X[index] = center. + """ + n_samples, n_features = X.shape + + generator = torch.Generator(device=str(X.device)) + generator.manual_seed(random_state) + + centers = torch.empty((n_clusters, n_features), dtype=X.dtype, device=X.device) + + # Set the number of local seeding trials if none is given + if n_local_trials is None: + # This is what Arthur/Vassilvitskii tried, but did not report + # specific results for other than mentioning in the conclusion + # that it helped. 
+ n_local_trials = 2 + int(np.log(n_clusters)) + + # Pick first center randomly and track index of point + # center_id = random_state.randint(n_samples) + center_id = torch.randint(n_samples, (1,), generator=generator, device=X.device) + + indices = torch.full((n_clusters,), -1, dtype=torch.int, device=X.device) + centers[0] = X[center_id] + indices[0] = center_id + + # Initialize list of closest distances and calculate current potential + closest_dist_sq = pairwise_distance( + centers[0, None], X) + current_pot = closest_dist_sq.sum() + + # Pick the remaining n_clusters-1 points + for c in range(1, n_clusters): + # Choose center candidates by sampling with probability proportional + # to the squared distance to the closest existing center + # rand_vals = random_state.random_sample(n_local_trials) * current_pot + rand_vals = torch.rand(n_local_trials, generator=generator, device=X.device) * current_pot + + candidate_ids = torch.searchsorted(stable_cumsum(closest_dist_sq), + rand_vals) + # XXX: numerical imprecision can result in a candidate_id out of range + torch.clip(candidate_ids, None, closest_dist_sq.numel() - 1, + out=candidate_ids) + + # Compute distances to center candidates + distance_to_candidates = pairwise_distance( + X[candidate_ids], X) + + # update closest distances squared and potential for each candidate + torch.minimum(closest_dist_sq, distance_to_candidates, + out=distance_to_candidates) + candidates_pot = distance_to_candidates.sum(dim=1) + + # Decide which candidate is the best + best_candidate = torch.argmin(candidates_pot) + current_pot = candidates_pot[best_candidate] + closest_dist_sq = distance_to_candidates[best_candidate] + best_candidate = candidate_ids[best_candidate] + + # Permanently add best center candidate found in local tries + centers[c] = X[best_candidate] + indices[c] = best_candidate + + return centers, indices diff --git a/train.py b/train.py new file mode 100644 index 0000000..7d7d382 --- /dev/null +++ b/train.py @@ -0,0 +1,365 @@ +import argparse + +import math +import numpy as np +import torch +import torch.nn as nn +from torch.optim import SGD, lr_scheduler +from torch.utils.data import DataLoader +from tqdm import tqdm + +from data.augmentations import get_transform +from data.get_datasets import get_datasets, get_class_splits + +from util.general_utils import AverageMeter, init_experiment +from util.cluster_and_log_utils import log_accs_from_preds +from config import exp_root +from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups + +from torch.utils.data.sampler import Sampler +from util.general_utils import NNBatchSampler, STML_loss_simgcd + +def train(student, simgcd_train_loader, train_loader, train_dataset, test_loader, unlabelled_train_loader, args): + params_groups = get_params_groups(student) + optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + fp16_scaler = None + if args.fp16: + fp16_scaler = torch.cuda.amp.GradScaler() + + exp_lr_scheduler = lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=args.epochs, + eta_min=args.lr * 1e-3, + ) + + + cluster_criterion = DistillLoss( + args.warmup_teacher_temp_epochs, + args.epochs, + args.n_views, + args.warmup_teacher_temp, + args.teacher_temp, + ) + + # Kmeans loss + from kmeans_loss import Kmeans_Loss + kmeans_loss_f = Kmeans_Loss(n_clusters=args.mlp_out_dim) + + # # inductive + # best_test_acc_lab = 0 + # # transductive + # best_train_acc_lab = 0 + # best_train_acc_ubl = 0 
+ # best_train_acc_all = 0 + + ori_train_loader = train_loader + for epoch in range(args.epochs): + loss_record = AverageMeter() + + loader = simgcd_train_loader + if epoch > args.stml_warmup_ep: + assert isinstance(student, nn.Sequential) + balanced_sampler = NNBatchSampler(train_dataset, student[0], ori_train_loader, args.batch_size, nn_per_image=args.nn_per_image, using_feat=True, is_norm=False) + train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=balanced_sampler, pin_memory=True) + stml_crit = STML_loss_simgcd(topk=args.nn_per_image, view=args.n_views) + # loader = train_loader + + student.train() + # TODO: try loading two batches from two loaders, and compute loss on both + for batch_idx, batch in enumerate(loader): + if epoch > args.stml_warmup_ep: + try: + stml_batch = next(iter(train_loader)) + except StopIteration: + train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_sampler=balanced_sampler, pin_memory=True) + stml_batch = next(iter(train_loader)) + + images, class_labels, uq_idxs, mask_lab = batch + mask_lab = mask_lab[:, 0] + + class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool() + images = torch.cat(images, dim=0).cuda(non_blocking=True) + + with torch.cuda.amp.autocast(fp16_scaler is not None): + # __import__("ipdb").set_trace() + # student_proj, student_out = student(images) + # teacher_out = student_out.detach() + + if epoch > args.stml_warmup_ep: + # __import__("ipdb").set_trace() + stml_images, stml_labels, stml_uq_idxs, stml_mask_lab = stml_batch + stml_images = torch.cat(stml_images, dim=0).cuda(non_blocking=True) + + all_images = torch.cat([images, stml_images], dim=0) + all_proj, all_out = student(all_images) + student_proj, stml_proj = all_proj.chunk(2) + student_out, stml_out = all_out.chunk(2) + teacher_out = student_out.detach() + + # representation learning, contextuality + # PCA as the teacher + # with torch.no_grad(): + # U, S, V = torch.pca_lowrank(stml_proj, q=stml_proj.shape[-1] // 2, ) + # t_emb = torch.matmul(stml_proj, V) + + # cls head as the teacher + t_emb = stml_out + stml_loss = stml_crit(stml_proj, t_emb, torch.cat([stml_uq_idxs] * args.n_views)) + else: + stml_loss = torch.zeros(1).cuda()[0] + + student_proj, student_out = student(images) + teacher_out = student_out.detach() + + # clustering, sup + sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0) + sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0) + cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels) + + # clustering, unsup + cluster_loss = cluster_criterion(student_out, teacher_out, epoch) + avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0) + me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) + cluster_loss += args.memax_weight * me_max_loss + + # represent learning, unsup + contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj) + contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels) + + if epoch > args.stml_warmup_ep: + # rep learning, proto + kmeans_loss = kmeans_loss_f(student_proj.chunk(2)[0], student_proj.chunk(2)[1], student_out.chunk(2)[0].argmax(1)) + else: + kmeans_loss = torch.zeros(1).cuda()[0] + + # representation learning, sup + student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1) + student_proj = torch.nn.functional.normalize(student_proj, dim=-1) + sup_con_labels = 
class_labels[mask_lab] + sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels) + + pstr = '' + pstr += f'cls_loss: {cls_loss.item():.4f} ' + pstr += f'cluster_loss: {cluster_loss.item():.4f} ' + pstr += f'sup_con_loss: {sup_con_loss.item():.4f} ' + pstr += f'contrastive_loss: {contrastive_loss.item():.4f} ' + pstr += f'stml_loss: {stml_loss.item():.4f} ' + pstr += f'kmeans_loss: {kmeans_loss.item():.4f}' + + loss = 0 + loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss + loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss + loss += args.stml_weight * stml_loss + loss += args.stml_weight * kmeans_loss + + # Train acc + loss_record.update(loss.item(), class_labels.size(0)) + optimizer.zero_grad() + if fp16_scaler is None: + loss.backward() + optimizer.step() + else: + fp16_scaler.scale(loss).backward() + fp16_scaler.step(optimizer) + fp16_scaler.update() + + if batch_idx % args.print_freq == 0: + args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}' + .format(epoch, batch_idx, len(train_loader), loss.item(), pstr)) + + args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg)) + + args.logger.info('Testing on unlabelled examples in the training data...') + all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args) + # args.logger.info('Testing on disjoint test set...') + # all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args) + + + args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) + # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test)) + + # Step schedule + exp_lr_scheduler.step() + + save_dict = { + 'model': student.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch + 1, + } + + torch.save(save_dict, args.model_path) + args.logger.info("model saved to {}.".format(args.model_path)) + + # if old_acc_test > best_test_acc_lab: + # + # args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...') + # args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) + # + # torch.save(save_dict, args.model_path[:-3] + f'_best.pt') + # args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt')) + # + # # inductive + # best_test_acc_lab = old_acc_test + # # transductive + # best_train_acc_lab = old_acc + # best_train_acc_ubl = new_acc + # best_train_acc_all = all_acc + # + # args.logger.info(f'Exp Name: {args.exp_name}') + # args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}') + + +def test(model, test_loader, epoch, save_name, args): + + model.eval() + + preds, targets = [], [] + mask = np.array([]) + for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)): + images = images.cuda(non_blocking=True) + with torch.no_grad(): + _, logits = model(images) + preds.append(logits.argmax(1).cpu().numpy()) + targets.append(label.cpu().numpy()) + mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label])) + + preds = np.concatenate(preds) + targets = np.concatenate(targets) + all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, + T=epoch, 
eval_funcs=args.eval_funcs, save_name=save_name,
+                                                    args=args)
+
+    return all_acc, old_acc, new_acc
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--batch_size', default=128, type=int)
+    parser.add_argument('--num_workers', default=8, type=int)
+    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b'])
+
+    parser.add_argument('--warmup_model_dir', type=str, default=None)
+    parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
+    parser.add_argument('--prop_train_labels', type=float, default=0.5)
+    parser.add_argument('--use_ssb_splits', action='store_true', default=True)
+
+    parser.add_argument('--grad_from_block', type=int, default=11)
+    parser.add_argument('--lr', type=float, default=0.1)
+    parser.add_argument('--gamma', type=float, default=0.1)
+    parser.add_argument('--momentum', type=float, default=0.9)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--epochs', default=200, type=int)
+    parser.add_argument('--exp_root', type=str, default=exp_root)
+    parser.add_argument('--transform', type=str, default='imagenet')
+    parser.add_argument('--sup_weight', type=float, default=0.35)
+    parser.add_argument('--n_views', default=2, type=int)
+
+    parser.add_argument('--memax_weight', type=float, default=2)
+    parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
+    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.')
+    parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
+
+    parser.add_argument('--nn_per_image', type=int, default=4)
+    parser.add_argument('--stml_weight', type=float, default=0.1)
+    parser.add_argument('--stml_warmup_ep', type=int, default=50)
+
+    parser.add_argument('--fp16', action='store_true', default=False)
+    parser.add_argument('--print_freq', default=10, type=int)
+    parser.add_argument('--exp_name', default=None, type=str)
+
+    # ----------------------
+    # INIT
+    # ----------------------
+    args = parser.parse_args()
+    device = torch.device('cuda:0')
+    args = get_class_splits(args)
+
+    args.num_labeled_classes = len(args.train_classes)
+    args.num_unlabeled_classes = len(args.unlabeled_classes)
+
+    init_experiment(args, runner_name=['simgcd'])
+    args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
+
+    torch.backends.cudnn.benchmark = True
+
+    # ----------------------
+    # BASE MODEL
+    # ----------------------
+    args.interpolation = 3
+    args.crop_pct = 0.875
+
+
+    backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
+
+    if args.warmup_model_dir is not None:
+        args.logger.info(f'Loading weights from {args.warmup_model_dir}')
+        backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
+
+    # NOTE: Hardcoded image size as we do not finetune the entire ViT model
+    args.image_size = 224
+    args.feat_dim = 768
+    args.num_mlp_layers = 3
+    args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
+
+    # ----------------------
+    # HOW MUCH OF BASE MODEL TO FINETUNE
+    # ----------------------
+    for m in backbone.parameters():
+        m.requires_grad = False
+
+    # Only finetune layers
from block 'args.grad_from_block' onwards + for name, m in backbone.named_parameters(): + if 'block' in name: + block_num = int(name.split('.')[1]) + if block_num >= args.grad_from_block: + m.requires_grad = True + + + args.logger.info('model build') + + # -------------------- + # CONTRASTIVE TRANSFORM + # -------------------- + train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) + train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) + # -------------------- + # DATASETS + # -------------------- + train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, + train_transform, + test_transform, + args) + + # -------------------- + # SAMPLER + # Sampler which balances labelled and unlabelled examples in each batch + # -------------------- + label_len = len(train_dataset.labelled_dataset) + unlabelled_len = len(train_dataset.unlabelled_dataset) + sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] + sample_weights = torch.DoubleTensor(sample_weights) + sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) + + # -------------------- + # DATALOADERS + # -------------------- + simgcd_train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, + sampler=sampler, drop_last=True, pin_memory=True) + train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, pin_memory=True) + test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, + batch_size=256, shuffle=False, pin_memory=False) + # test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, + # batch_size=256, shuffle=False, pin_memory=False) + + # ---------------------- + # PROJECTION HEAD + # ---------------------- + projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers) + model = nn.Sequential(backbone, projector).to(device) + + # ---------------------- + # TRAIN + # ---------------------- + # train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args) + train(model, simgcd_train_loader, train_loader, train_dataset, None, test_loader_unlabelled, args) diff --git a/train_mp.py b/train_mp.py new file mode 100644 index 0000000..c405ef1 --- /dev/null +++ b/train_mp.py @@ -0,0 +1,325 @@ +import argparse +import os + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from tqdm import tqdm + +from data.augmentations import get_transform +from data.get_datasets import get_datasets, get_class_splits + +from util.general_utils import AverageMeter, init_experiment, DistributedWeightedSampler +from util.cluster_and_log_utils import log_accs_from_preds +from config import exp_root +from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups + + +def get_parser(): + parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('--batch_size', default=128, type=int) + parser.add_argument('--num_workers', default=8, type=int) + parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b']) 
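+    # NOTE: these must be keys of EVAL_FUNCS in util/cluster_and_log_utils.py:
+    # 'v2' is instance-averaged Hungarian accuracy, 'v2b' its class-balanced variant.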
+
+    parser.add_argument('--warmup_model_dir', type=str, default=None)
+    parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
+    parser.add_argument('--prop_train_labels', type=float, default=0.5)
+    parser.add_argument('--use_ssb_splits', action='store_true', default=True)
+
+    parser.add_argument('--grad_from_block', type=int, default=11)
+    parser.add_argument('--lr', type=float, default=0.1)
+    parser.add_argument('--gamma', type=float, default=0.1)
+    parser.add_argument('--momentum', type=float, default=0.9)
+    parser.add_argument('--weight_decay', type=float, default=1e-4)
+    parser.add_argument('--epochs', default=200, type=int)
+    parser.add_argument('--exp_root', type=str, default=exp_root)
+    parser.add_argument('--transform', type=str, default='imagenet')
+    parser.add_argument('--sup_weight', type=float, default=0.35)
+    parser.add_argument('--n_views', default=2, type=int)
+
+    parser.add_argument('--memax_weight', type=float, default=2)
+    parser.add_argument('--warmup_teacher_temp', default=0.07, type=float, help='Initial value for the teacher temperature.')
+    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature.')
+    parser.add_argument('--warmup_teacher_temp_epochs', default=30, type=int, help='Number of warmup epochs for the teacher temperature.')
+
+    parser.add_argument('--fp16', action='store_true', default=False)
+    parser.add_argument('--print_freq', default=10, type=int)
+    parser.add_argument('--exp_name', default=None, type=str)
+
+    # ----------------------
+    # INIT
+    # ----------------------
+    args = parser.parse_args()
+    args = get_class_splits(args)
+
+    args.num_labeled_classes = len(args.train_classes)
+    args.num_unlabeled_classes = len(args.unlabeled_classes)
+
+    # .get() avoids a KeyError when the script is launched without torchrun
+    if os.environ.get("LOCAL_RANK") is not None:
+        args.local_rank = int(os.environ["LOCAL_RANK"])
+
+    return args
+
+
+def main(args):
+    # ----------------------
+    # BASE MODEL
+    # ----------------------
+    args.interpolation = 3
+    args.crop_pct = 0.875
+
+    backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
+
+    if args.warmup_model_dir is not None:
+        if dist.get_rank() == 0:
+            args.logger.info(f'Loading weights from {args.warmup_model_dir}')
+        backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
+
+    # NOTE: Hardcoded image size as we do not finetune the entire ViT model
+    args.image_size = 224
+    args.feat_dim = 768
+    args.num_mlp_layers = 3
+    args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
+
+    # ----------------------
+    # HOW MUCH OF BASE MODEL TO FINETUNE
+    # ----------------------
+    for m in backbone.parameters():
+        m.requires_grad = False
+
+    # Only finetune layers from block 'args.grad_from_block' onwards
+    for name, m in backbone.named_parameters():
+        if 'block' in name:
+            block_num = int(name.split('.')[1])
+            if block_num >= args.grad_from_block:
+                m.requires_grad = True
+
+    if dist.get_rank() == 0:
+        args.logger.info('model built')
+
+    # --------------------
+    # CONTRASTIVE TRANSFORM
+    # --------------------
+    train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
+    train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
+    # --------------------
+    # DATASETS
+    # --------------------
+    train_dataset, test_dataset, unlabelled_train_examples_test, datasets =
get_datasets(args.dataset_name, + train_transform, + test_transform, + args) + + # -------------------- + # SAMPLER + # Sampler which balances labelled and unlabelled examples in each batch + # -------------------- + label_len = len(train_dataset.labelled_dataset) + unlabelled_len = len(train_dataset.unlabelled_dataset) + sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] + sample_weights = torch.DoubleTensor(sample_weights) + train_sampler = DistributedWeightedSampler(train_dataset, sample_weights, num_samples=len(train_dataset)) + unlabelled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabelled_train_examples_test) + # test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset) + # -------------------- + # DATALOADERS + # -------------------- + train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, + sampler=train_sampler, drop_last=True, pin_memory=True) + unlabelled_train_loader = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, batch_size=256, + shuffle=False, sampler=unlabelled_train_sampler, pin_memory=False) + # test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=256, + # shuffle=False, sampler=test_sampler, pin_memory=False) + + # ---------------------- + # PROJECTION HEAD + # ---------------------- + projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers) + model = nn.Sequential(backbone, projector).cuda() + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) + + params_groups = get_params_groups(model) + optimizer = torch.optim.SGD( + params_groups, + lr=args.lr * (args.batch_size * dist.get_world_size() / 128), # linear scaling rule + momentum=args.momentum, + weight_decay=args.weight_decay + ) + + fp16_scaler = None + if args.fp16: + fp16_scaler = torch.cuda.amp.GradScaler() + + exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=args.epochs, + eta_min=args.lr * (args.batch_size * dist.get_world_size() / 128) * 1e-3, + ) + + cluster_criterion = DistillLoss( + args.warmup_teacher_temp_epochs, + args.epochs, + args.n_views, + args.warmup_teacher_temp, + args.teacher_temp, + ) + + # # inductive + # best_test_acc_lab = 0 + # # transductive + # best_train_acc_lab = 0 + # best_train_acc_ubl = 0 + # best_train_acc_all = 0 + + for epoch in range(args.epochs): + train_sampler.set_epoch(epoch) + train(model, train_loader, optimizer, fp16_scaler, exp_lr_scheduler, cluster_criterion, epoch, args) + + unlabelled_train_sampler.set_epoch(epoch) + # test_sampler.set_epoch(epoch) + if dist.get_rank() == 0: + args.logger.info('Testing on unlabelled examples in the training data...') + all_acc, old_acc, new_acc = test(model, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args) + # if dist.get_rank() == 0: + # args.logger.info('Testing on disjoint test set...') + # all_acc_test, old_acc_test, new_acc_test = test(model, test_loader, epoch=epoch, save_name='Test ACC', args=args) + + if dist.get_rank() == 0: + args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) + # args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test)) + + save_dict = { + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch + 1, + } + + 
torch.save(save_dict, args.model_path)
+            args.logger.info("model saved to {}.".format(args.model_path))
+
+        # if old_acc_test > best_test_acc_lab and dist.get_rank() == 0:
+        #     args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
+        #     args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
+        #
+        #     torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
+        #     args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
+        #
+        #     # inductive
+        #     best_test_acc_lab = old_acc_test
+        #     # transductive
+        #     best_train_acc_lab = old_acc
+        #     best_train_acc_ubl = new_acc
+        #     best_train_acc_all = all_acc
+        #
+        #     if dist.get_rank() == 0:
+        #         args.logger.info(f'Exp Name: {args.exp_name}')
+        #         args.logger.info(f'Metrics with best model on test set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
+
+
+def train(student, train_loader, optimizer, scaler, scheduler, cluster_criterion, epoch, args):
+    loss_record = AverageMeter()
+
+    student.train()
+    for batch_idx, batch in enumerate(train_loader):
+        images, class_labels, uq_idxs, mask_lab = batch
+        mask_lab = mask_lab[:, 0]
+
+        class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
+        images = torch.cat(images, dim=0).cuda(non_blocking=True)
+
+        with torch.cuda.amp.autocast(scaler is not None):
+            student_proj, student_out = student(images)
+            teacher_out = student_out.detach()
+
+            # clustering, sup
+            sup_logits = torch.cat([f[mask_lab] for f in (student_out / 0.1).chunk(2)], dim=0)
+            sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
+            cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
+
+            # clustering, unsup
+            cluster_loss = cluster_criterion(student_out, teacher_out, epoch)
+            # mean-entropy maximisation regulariser: equals log(K) - H(mean prediction)
+            avg_probs = (student_out / 0.1).softmax(dim=1).mean(dim=0)
+            me_max_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))
+            cluster_loss += args.memax_weight * me_max_loss
+
+            # representation learning, unsup
+            contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj)
+            contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
+
+            # representation learning, sup
+            student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
+            student_proj = torch.nn.functional.normalize(student_proj, dim=-1)
+            sup_con_labels = class_labels[mask_lab]
+            sup_con_loss = SupConLoss()(student_proj, labels=sup_con_labels)
+
+            pstr = ''
+            pstr += f'cls_loss: {cls_loss.item():.4f} '
+            pstr += f'cluster_loss: {cluster_loss.item():.4f} '
+            pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
+            pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
+
+            loss = 0
+            loss += (1 - args.sup_weight) * cluster_loss + args.sup_weight * cls_loss
+            loss += (1 - args.sup_weight) * contrastive_loss + args.sup_weight * sup_con_loss
+
+        # Train acc
+        loss_record.update(loss.item(), class_labels.size(0))
+        optimizer.zero_grad()
+        if scaler is None:
+            loss.backward()
+            optimizer.step()
+        else:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+
+        if batch_idx % args.print_freq == 0 and dist.get_rank() == 0:
+            args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
+                             .format(epoch, batch_idx, len(train_loader), loss.item(), pstr))
+    # Step schedule
+    scheduler.step()
+
+    if dist.get_rank() == 0:
+        args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch,
loss_record.avg))
+
+
+def test(model, test_loader, epoch, save_name, args):
+
+    model.eval()
+
+    preds, targets = [], []
+    mask = np.array([])
+    for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
+        images = images.cuda(non_blocking=True)
+        with torch.no_grad():
+            _, logits = model(images)
+            preds.append(logits.argmax(1).cpu().numpy())
+            targets.append(label.cpu().numpy())
+            mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label]))
+
+    preds = np.concatenate(preds)
+    targets = np.concatenate(targets)
+    all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
+                                                    T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
+                                                    args=args)
+
+    return all_acc, old_acc, new_acc
+
+
+if __name__ == '__main__':
+    args = get_parser()
+
+    torch.cuda.set_device(args.local_rank)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    cudnn.benchmark = True
+
+    if dist.get_rank() == 0:
+        init_experiment(args, runner_name=['simgcd'])
+        args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
+
+    main(args)
\ No newline at end of file
diff --git a/util/New_Kmeans.py b/util/New_Kmeans.py
new file mode 100644
index 0000000..402b779
--- /dev/null
+++ b/util/New_Kmeans.py
@@ -0,0 +1,42 @@
+# import torch
+# import torch.nn as nn
+# import torch.nn.functional as F
+# from torch.autograd import Variable
+# import numpy as np
+# import torch_clustering
+#
+# class Kmeans_Loss(nn.Module):
+#     def __init__(self, temperature=0.5, n_cluster=196):
+#         super(Kmeans_Loss, self).__init__()
+#         self.temperature = temperature
+#         self.n_cluster = n_cluster
+#         self.clustering_model = torch_clustering.PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4,
+#                                                                n_clusters=self.n_cluster)
+#         self.pseudo_labels = None
+#
+#     def clustering(self, n_cluster, features):
+#
+#         clustering_model = torch_clustering.PyTorchKMeans(init='k-means++', max_iter=300, tol=1e-4,
+#                                                           n_clusters=self.n_cluster)
+#         pseudo_labels = clustering_model.fit_predict(features)
+#         self.pseudo_labels = pseudo_labels
+#         cluster_centers = clustering_model.cluster_centers_
+#         return pseudo_labels, cluster_centers
+#
+#     def compute_cluster_loss(self,
+#
+#
+# scratch: explore the chunk()/argmax() shapes used by the k-means loss
+import torch
+
+x = torch.rand((6, 10))  # assume a tensor of shape (6, 10)
+print(x)
+print("==============")
+proj, out = x.chunk(2)
+print(proj)
+
+print("============")
+
+print(out)
+print("==============")
+print(out.argmax(1))
\ No newline at end of file
diff --git a/util/cluster_and_log_utils.py b/util/cluster_and_log_utils.py
new file mode 100644
index 0000000..25874e3
--- /dev/null
+++ b/util/cluster_and_log_utils.py
@@ -0,0 +1,184 @@
+import torch
+import torch.distributed as dist
+import numpy as np
+from scipy.optimize import linear_sum_assignment as linear_assignment
+
+
+def all_sum_item(item):
+    item = torch.tensor(item).cuda()
+    dist.all_reduce(item)
+    return item.item()
+
+def split_cluster_acc_v2(y_true, y_pred, mask):
+    """
+    Calculate clustering accuracy.
Require scikit-learn installed + First compute linear assignment on all data, then look at how good the accuracy is on subsets + + # Arguments + mask: Which instances come from old classes (True) and which ones come from new classes (False) + y: true labels, numpy.array with shape `(n_samples,)` + y_pred: predicted labels, numpy.array with shape `(n_samples,)` + + # Return + accuracy, in [0,1] + """ + y_true = y_true.astype(int) + + old_classes_gt = set(y_true[mask]) + new_classes_gt = set(y_true[~mask]) + + assert y_pred.size == y_true.size + D = max(y_pred.max(), y_true.max()) + 1 + w = np.zeros((D, D), dtype=int) + for i in range(y_pred.size): + w[y_pred[i], y_true[i]] += 1 + + ind = linear_assignment(w.max() - w) + ind = np.vstack(ind).T + + ind_map = {j: i for i, j in ind} + total_acc = sum([w[i, j] for i, j in ind]) + total_instances = y_pred.size + try: + if dist.get_world_size() > 0: + total_acc = all_sum_item(total_acc) + total_instances = all_sum_item(total_instances) + except: + pass + total_acc /= total_instances + + old_acc = 0 + total_old_instances = 0 + for i in old_classes_gt: + old_acc += w[ind_map[i], i] + total_old_instances += sum(w[:, i]) + + try: + if dist.get_world_size() > 0: + old_acc = all_sum_item(old_acc) + total_old_instances = all_sum_item(total_old_instances) + except: + pass + old_acc /= total_old_instances + + new_acc = 0 + total_new_instances = 0 + for i in new_classes_gt: + new_acc += w[ind_map[i], i] + total_new_instances += sum(w[:, i]) + + try: + if dist.get_world_size() > 0: + new_acc = all_sum_item(new_acc) + total_new_instances = all_sum_item(total_new_instances) + except: + pass + new_acc /= total_new_instances + + return total_acc, old_acc, new_acc + + +def split_cluster_acc_v2_balanced(y_true, y_pred, mask): + """ + Calculate clustering accuracy. 
Require scikit-learn installed + First compute linear assignment on all data, then look at how good the accuracy is on subsets + + # Arguments + mask: Which instances come from old classes (True) and which ones come from new classes (False) + y: true labels, numpy.array with shape `(n_samples,)` + y_pred: predicted labels, numpy.array with shape `(n_samples,)` + + # Return + accuracy, in [0,1] + """ + y_true = y_true.astype(int) + + old_classes_gt = set(y_true[mask]) + new_classes_gt = set(y_true[~mask]) + + assert y_pred.size == y_true.size + D = max(y_pred.max(), y_true.max()) + 1 + w = np.zeros((D, D), dtype=int) + for i in range(y_pred.size): + w[y_pred[i], y_true[i]] += 1 + + ind = linear_assignment(w.max() - w) + ind = np.vstack(ind).T + + ind_map = {j: i for i, j in ind} + + old_acc = np.zeros(len(old_classes_gt)) + total_old_instances = np.zeros(len(old_classes_gt)) + for idx, i in enumerate(old_classes_gt): + old_acc[idx] += w[ind_map[i], i] + total_old_instances[idx] += sum(w[:, i]) + + new_acc = np.zeros(len(new_classes_gt)) + total_new_instances = np.zeros(len(new_classes_gt)) + for idx, i in enumerate(new_classes_gt): + new_acc[idx] += w[ind_map[i], i] + total_new_instances[idx] += sum(w[:, i]) + + try: + if dist.get_world_size() > 0: + old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda() + dist.all_reduce(old_acc), dist.all_reduce(new_acc) + dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances) + old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy() + total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy() + except: + pass + + total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances]) + old_acc /= total_old_instances + new_acc /= total_new_instances + total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean() + return total_acc, old_acc, new_acc + + +EVAL_FUNCS = { + 'v2': split_cluster_acc_v2, + 'v2b': split_cluster_acc_v2_balanced +} + +def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None, + print_output=True, args=None): + + """ + Given a list of evaluation functions to use (e.g ['v1', 'v2']) evaluate and log ACC results + + :param y_true: GT labels + :param y_pred: Predicted indices + :param mask: Which instances belong to Old and New classes + :param T: Epoch + :param eval_funcs: Which evaluation functions to use + :param save_name: What are we evaluating ACC on + :param writer: Tensorboard logger + :return: + """ + + mask = mask.astype(bool) + y_true = y_true.astype(int) + y_pred = y_pred.astype(int) + + for i, f_name in enumerate(eval_funcs): + + acc_f = EVAL_FUNCS[f_name] + all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask) + log_name = f'{save_name}_{f_name}' + + if i == 0: + to_return = (all_acc, old_acc, new_acc) + + if print_output: + print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}' + try: + if dist.get_rank() == 0: + try: + args.logger.info(print_str) + except: + print(print_str) + except: + pass + + return to_return \ No newline at end of file diff --git a/util/general_utils.py b/util/general_utils.py new file mode 100644 index 0000000..26ccced --- /dev/null +++ b/util/general_utils.py @@ -0,0 +1,384 @@ +import os +import torch +import inspect + +from datetime import datetime +from torch.utils.data.sampler import Sampler +from loguru import logger + +class AverageMeter(object): + """Computes and stores 
the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def init_experiment(args, runner_name=None, exp_id=None): + # Get filepath of calling script + if runner_name is None: + runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] + + root_dir = os.path.join(args.exp_root, *runner_name) + + if not os.path.exists(root_dir): + os.makedirs(root_dir) + + # Either generate a unique experiment ID, or use one which is passed + if exp_id is None: + + if args.exp_name is None: + raise ValueError("Need to specify the experiment name") + # Unique identifier for experiment + now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \ + datetime.now().strftime("%S.%f")[:-3] + ')' + + log_dir = os.path.join(root_dir, 'log', now) + while os.path.exists(log_dir): + now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ + datetime.now().strftime("%S.%f")[:-3] + ')' + + log_dir = os.path.join(root_dir, 'log', now) + + else: + + log_dir = os.path.join(root_dir, 'log', f'{exp_id}') + + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + + logger.add(os.path.join(log_dir, 'log.txt')) + args.logger = logger + args.log_dir = log_dir + + # Instantiate directory to save models to + model_root_dir = os.path.join(args.log_dir, 'checkpoints') + if not os.path.exists(model_root_dir): + os.mkdir(model_root_dir) + + args.model_dir = model_root_dir + args.model_path = os.path.join(args.model_dir, 'model.pt') + + print(f'Experiment saved to: {args.log_dir}') + + hparam_dict = {} + + for k, v in vars(args).items(): + if isinstance(v, (int, float, str, bool, torch.Tensor)): + hparam_dict[k] = v + + print(runner_name) + print(args) + + return args + + +class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler): + + def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None, + replacement=True, generator=None): + super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank) + if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \ + num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(num_samples)) + if not isinstance(replacement, bool): + raise ValueError("replacement should be a boolean value, but got " + "replacement={}".format(replacement)) + self.weights = torch.as_tensor(weights, dtype=torch.double) + self.num_samples = num_samples + self.replacement = replacement + self.generator = generator + self.weights = self.weights[self.rank::self.num_replicas] + self.num_samples = self.num_samples // self.num_replicas + + def __iter__(self): + rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) + rand_tensor = self.rank + rand_tensor * self.num_replicas + yield from iter(rand_tensor.tolist()) + + def __len__(self): + return self.num_samples + + +import numpy as np +from tqdm import tqdm +import torch.nn.functional as F +class NNBatchSampler(Sampler): + """ + BatchSampler that ensures a fixed amount of images per class are sampled in the minibatch + """ + def __init__(self, data_source, model, seen_dataloader, batch_size, nn_per_image = 5, 
using_feat = True, is_norm = False):
+        self.batch_size = batch_size
+        self.nn_per_image = nn_per_image
+        self.using_feat = using_feat
+        self.is_norm = is_norm
+        self.num_samples = data_source.__len__()
+        self.nn_matrix, self.dist_matrix = self._build_nn_matrix(model, seen_dataloader)
+
+    def __iter__(self):
+        for _ in range(len(self)):
+            yield self.sample_batch()
+
+    def _predict_batchwise(self, model, seen_dataloader):
+        device = "cuda"
+        model_is_training = model.training
+        model.eval()
+
+        ds = seen_dataloader.dataset
+        A = [[] for i in range(len(ds[0]))]
+        with torch.no_grad():
+            # extract batches (A becomes list of samples)
+            for batch in tqdm(seen_dataloader):
+                for i, J in enumerate(batch):
+                    # i = 0: sz_batch * images
+                    # i = 1: sz_batch * labels
+                    # i = 2: sz_batch * indices
+                    # i = 3: sz_batch * mask_lab
+                    if i == 0:
+                        J = J[0]
+                        # move images to device of model (approximate device)
+                        # if self.using_feat:
+                        #     J, _ = model(J.cuda())
+                        # else:
+                        J = model(J.cuda())
+
+                        if self.is_norm:
+                            J = F.normalize(J, p=2, dim=1)
+
+                    for j in J:
+                        A[i].append(j)
+
+        model.train(model_is_training)  # revert to previous training state
+
+        return [torch.stack(A[i]) for i in range(len(A))]
+
+    def _build_nn_matrix(self, model, seen_dataloader):
+        # calculate embeddings with model and get targets
+        X, T, _, _ = self._predict_batchwise(model, seen_dataloader)
+
+        # assign each sample its K nearest neighbours (squared Euclidean distance)
+        K = self.nn_per_image * 1
+        nn_matrix = []
+        dist_matrix = []
+        xs = []
+
+        for x in X:
+            if len(xs) < 5000:
+                xs.append(x)
+            else:
+                xs.append(x)
+                xs = torch.stack(xs, dim=0)
+                # ||x||^2 - 2 X.x + ||X||^2, computed chunk-wise to bound memory
+                dist_emb = xs.pow(2).sum(1) + (-2) * X.mm(xs.t())
+                dist_emb = X.pow(2).sum(1) + dist_emb.t()
+
+                ind = dist_emb.topk(K, largest=False)[1].long().cpu()
+                dist = dist_emb.topk(K, largest=False)[0]
+                nn_matrix.append(ind)
+                dist_matrix.append(dist.cpu())
+                xs = []
+                del ind
+
+        # last (possibly partial) chunk; skip if the data size divides evenly
+        if len(xs) > 0:
+            xs = torch.stack(xs, dim=0)
+            dist_emb = xs.pow(2).sum(1) + (-2) * X.mm(xs.t())
+            dist_emb = X.pow(2).sum(1) + dist_emb.t()
+            ind = dist_emb.topk(K, largest=False)[1]
+            dist = dist_emb.topk(K, largest=False)[0]
+            nn_matrix.append(ind.long().cpu())
+            dist_matrix.append(dist.cpu())
+        nn_matrix = torch.cat(nn_matrix, dim=0)
+        dist_matrix = torch.cat(dist_matrix, dim=0)
+
+        return nn_matrix, dist_matrix
+
+
+    def sample_batch(self):
+        num_image = self.batch_size // self.nn_per_image
+        sampled_queries = np.random.choice(self.num_samples, num_image, replace=False)
+        sampled_indices = self.nn_matrix[sampled_queries].view(-1).tolist()
+
+        return sampled_indices
+
+    def __len__(self):
+        return self.num_samples // self.batch_size
+
+import torch.nn as nn
+class RC_STML(nn.Module):
+    def __init__(self, sigma, delta, view, disable_mu, topk):
+        super(RC_STML, self).__init__()
+        self.sigma = sigma
+        self.delta = delta
+        self.view = view
+        self.disable_mu = disable_mu
+        self.topk = topk
+
+    def k_reciprocal_neigh(self, initial_rank, i, topk):
+        forward_k_neigh_index = initial_rank[i, :topk]
+        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :topk]
+        fi = np.where(backward_k_neigh_index == i)[0]
+        return forward_k_neigh_index[fi]
+
+    def forward(self, s_emb, t_emb, idx, v2=False):
+        if v2:
+            return self.forward_v2(t_emb, s_emb)
+        if self.disable_mu:
+            s_emb = F.normalize(s_emb)
+            t_emb = F.normalize(t_emb)
+
+        N = len(s_emb)
+        S_dist = torch.cdist(s_emb, s_emb)
+        S_dist = S_dist / S_dist.mean(1, keepdim=True)
+
+        with torch.no_grad():
+            T_dist = torch.cdist(t_emb, t_emb)
+            W_P =
torch.exp(-T_dist.pow(2) / self.sigma) + + batch_size = len(s_emb) // self.view + W_P_copy = W_P.clone() + W_P_copy[idx.unsqueeze(1) == idx.unsqueeze(1).t()] = 1 + + topk_index = torch.topk(W_P_copy, self.topk)[1] + topk_half_index = topk_index[:, :int(np.around(self.topk/2))] + + W_NN = torch.zeros_like(W_P).scatter_(1, topk_index, torch.ones_like(W_P)) + V = ((W_NN + W_NN.t())/2 == 1).float() + + W_C_tilda = torch.zeros_like(W_P) + for i in range(N): + indNonzero = torch.where(V[i, :]!=0)[0] + W_C_tilda[i, indNonzero] = (V[:,indNonzero].sum(1) / len(indNonzero))[indNonzero] + + W_C_hat = W_C_tilda[topk_half_index].mean(1) + W_C = (W_C_hat + W_C_hat.t())/2 + W = (W_P + W_C)/2 + + identity_matrix = torch.eye(N).cuda(non_blocking=True) + pos_weight = (W) * (1 - identity_matrix) + neg_weight = (1 - W) * (1 - identity_matrix) + + pull_losses = torch.relu(S_dist).pow(2) * pos_weight + push_losses = torch.relu(self.delta - S_dist).pow(2) * neg_weight + + loss = (pull_losses.sum() + push_losses.sum()) / (len(s_emb) * (len(s_emb)-1)) + + return loss + + + def forward_v2(self, probs, feats): + with torch.no_grad(): + pseudo_labels = probs.argmax(1).cuda() + one_hot = torch.zeros(probs.shape).cuda().scatter(1, pseudo_labels.unsqueeze(1), 1.0) + W_P = torch.mm(one_hot, one_hot.t()) + feats_dist = torch.cdist(feats, feats) + topk_index = torch.topk(feats_dist, self.topk)[1] + W_NN = torch.zeros_like(feats_dist).scatter_(1, topk_index, W_P) + + W = ((W_NN + W_NN.t())/2 == 0.5).float() + + N = len(probs) + identity_matrix = torch.eye(N).cuda(non_blocking=True) + pos_weight = (W) * (1 - identity_matrix) + neg_weight = (1 - W) * (1 - identity_matrix) + + pull_losses = torch.relu(feats_dist).pow(2) * pos_weight + push_losses = torch.relu(self.delta - feats_dist).pow(2) * neg_weight + + loss = (pull_losses.sum() + push_losses.sum()) / (len(probs) * (len(probs)-1)) + + return loss + +class KL_STML(nn.Module): + def __init__(self, disable_mu, temp=1): + super(KL_STML, self).__init__() + self.disable_mu = disable_mu + self.temp = temp + + def kl_div(self, A, B, T = 1): + log_q = F.log_softmax(A/T, dim=-1) + p = F.softmax(B/T, dim=-1) + kl_d = F.kl_div(log_q, p, reduction='sum') * T**2 / A.size(0) + return kl_d + + def forward(self, s_f, s_g): + if self.disable_mu: + s_f, s_g = F.normalize(s_f), F.normalize(s_g) + + N = len(s_f) + S_dist = torch.cdist(s_f, s_f) + S_dist = S_dist / S_dist.mean(1, keepdim=True) + + S_bg_dist = torch.cdist(s_g, s_g) + S_bg_dist = S_bg_dist / S_bg_dist.mean(1, keepdim=True) + + loss = self.kl_div(-S_dist, -S_bg_dist.detach(), T=1) + + return loss + +class STML_loss(nn.Module): + def __init__(self, sigma, delta, view, disable_mu, topk): + super(STML_loss, self).__init__() + self.sigma = sigma + self.delta = delta + self.view = view + self.disable_mu = disable_mu + self.topk = topk + self.RC_criterion = RC_STML(sigma, delta, view, disable_mu, topk) + self.KL_criterion = KL_STML(disable_mu, temp=1) + + def forward(self, s_f, s_g, t_g, idx): + # Relaxed contrastive loss for STML + loss_RC_f = self.RC_criterion(s_f, t_g, idx) + loss_RC_g = self.RC_criterion(s_g, t_g, idx) + loss_RC = (loss_RC_f + loss_RC_g)/2 + + # Self-Distillation for STML + loss_KL = self.KL_criterion(s_f, s_g) + + loss = loss_RC + loss_KL + + total_loss = dict(RC=loss_RC, KL=loss_KL, loss=loss) + + return total_loss + +class STML_loss_simgcd(nn.Module): + def __init__(self, disable_mu=0, topk=4, sigma=1, delta=1, view=2, v2=True): + super().__init__() + self.sigma = sigma + self.delta = delta + self.view = view + 
self.disable_mu = disable_mu + self.topk = topk + self.v2 = v2 + self.RC_criterion = RC_STML(sigma, delta, view, disable_mu, topk) + self.KL_criterion = KL_STML(disable_mu, temp=1) + + def forward(self, s_g, t_g, idx): + # Relaxed contrastive loss for STML + loss_RC = self.RC_criterion(s_g, t_g, idx, v2=self.v2) + + if not self.v2: + # Self-Distillation for STML + loss_KL = self.KL_criterion(s_g, t_g) + else: + loss_KL = 0.0 + loss = loss_RC + loss_KL + + # total_loss = dict(RC=loss_RC, KL=loss_KL, loss=loss) + + return loss
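+
+
+if __name__ == '__main__':
+    # Minimal smoke test for STML_loss_simgcd (a sketch, not part of training;
+    # the batch size, feature dim and class count below are illustrative
+    # assumptions, not values fixed by this module). It mirrors the call in
+    # train.py: projections as the student embedding, classifier logits as the
+    # teacher ("cls head as the teacher"), and uq_idxs repeated once per view.
+    crit = STML_loss_simgcd(topk=4, view=2, v2=True)
+    s_g = torch.randn(2 * 8, 32).cuda()            # student projections, 2 views
+    t_g = torch.randn(2 * 8, 10).cuda()            # teacher logits (cls head)
+    idx = torch.cat([torch.arange(8)] * 2).cuda()  # uq_idxs repeated per view
+    print(crit(s_g, t_g, idx))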