Commit 9d462aa (0 parents): showing 81 changed files with 5,718 additions and 0 deletions.
@@ -0,0 +1,6 @@
*.npy
__pycache__/
checkpoints/
data/
logs/
wandb/
@@ -0,0 +1,63 @@
Copyright (c) 2023, NVIDIA Corporation. All rights reserved.

Nvidia Source Code License-NC

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.

2. License Grant

2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.

3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
@@ -0,0 +1,107 @@
# Heterogeneous Continual Learning

Official PyTorch implementation of [**Heterogeneous Continual Learning**](https://arxiv.org/abs/2306.08593).

**Authors**: [Divyam Madaan](https://dmadaan.com/), [Hongxu Yin](https://hongxu-yin.github.io/), [Wonmin Byeon](https://wonmin-byeon.github.io/), [Pavlo Molchanov](https://research.nvidia.com/person/pavlo-molchanov), and Jan Kautz

For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)

---

**TL;DR: First continual learning approach in which the architecture continuously evolves with the data.**

## Abstract

![concept figure](https://github.com/divyam3897/cvpr_hcl/files/13549399/concept_figure.pdf)

We propose a novel framework and a solution to tackle the continual learning (CL) problem with changing network architectures. Most CL methods focus on adapting a single architecture to a new task/class by modifying its weights. However, with rapid progress in architecture design, the problem of adapting existing solutions to novel architectures becomes relevant. To address this limitation, we propose Heterogeneous Continual Learning (HCL), where a wide range of evolving network architectures emerge continually together with novel data/tasks. As a solution, we build on top of the distillation family of techniques and modify it to a new setting where a weaker model takes the role of a teacher; meanwhile, a new stronger architecture acts as a student. Furthermore, we consider a setup of limited access to previous data and propose Quick Deep Inversion (QDI) to recover prior task visual features to support knowledge transfer. QDI significantly reduces computational costs compared to previous solutions and improves overall performance. In summary, we propose a new setup for CL with a modified knowledge distillation paradigm and design a quick data inversion method to enhance distillation. Our evaluation on various benchmarks shows a significant improvement in accuracy compared to state-of-the-art methods across various network architectures.

__Contributions of this work__

- We propose a novel CL framework called Heterogeneous Continual Learning (HCL) to learn a stream of different architectures on a sequence of tasks while transferring the knowledge from past representations.
- We revisit knowledge distillation and propose Quick Deep Inversion (QDI), which inverts the previous task parameters while interpolating the current task examples with minimal additional cost (a simplified sketch is given below).
- We benchmark existing state-of-the-art solutions in the new setting and outperform them with our proposed method across a diverse stream of architectures for both task-incremental and class-incremental CL.
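To make the setup concrete, here is a minimal, self-contained PyTorch sketch of the two ingredients above: recovering prior-task inputs from a frozen previous-architecture teacher (the flavor of QDI) and distilling from that weaker teacher into the new, stronger student. This is not the repository's implementation; the function names, the BatchNorm-statistics regularizer, the optimizer settings, the temperature `T`, and the weighting `alpha` are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def quick_inversion_sketch(teacher, x_current, prev_labels, steps=20, lr=0.1, bn_weight=1.0):
    """Illustrative only: start from current-task images and optimize them so the frozen
    previous-task teacher predicts previous-task labels, while loosely matching the
    teacher's stored BatchNorm statistics."""
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)  # only the synthesized images are optimized

    x = x_current.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)

    # Compare the batch statistics of the synthesized images against each
    # BatchNorm layer's running statistics (a DeepInversion-style regularizer).
    bn_terms, hooks = [], []

    def bn_hook(module, inputs, _output):
        feat = inputs[0]
        mean = feat.mean(dim=(0, 2, 3))
        var = feat.var(dim=(0, 2, 3), unbiased=False)
        bn_terms.append(F.mse_loss(mean, module.running_mean) +
                        F.mse_loss(var, module.running_var))

    for m in teacher.modules():
        if isinstance(m, torch.nn.BatchNorm2d):
            hooks.append(m.register_forward_hook(bn_hook))

    for _ in range(steps):
        bn_terms.clear()
        opt.zero_grad()
        loss = F.cross_entropy(teacher(x), prev_labels)
        if bn_terms:
            loss = loss + bn_weight * torch.stack(bn_terms).sum()
        loss.backward()
        opt.step()

    for h in hooks:
        h.remove()
    return x.detach()


def distill_step(student, teacher, x_cur, y_cur, x_recovered, T=2.0, alpha=0.3):
    """One illustrative update: supervised loss on current-task data plus KL distillation
    from the (weaker) previous-architecture teacher on the recovered examples.
    Assumes the two models share the output space on the distilled examples."""
    ce = F.cross_entropy(student(x_cur), y_cur)
    with torch.no_grad():
        teacher_logits = teacher(x_recovered)
    student_logits = student(x_recovered)
    kd = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                  F.softmax(teacher_logits / T, dim=1),
                  reduction='batchmean') * T * T
    return (1 - alpha) * ce + alpha * kd
```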
## Prerequisites

```
$ pip install -r requirements.txt
```

## Quick start

### Training

```bash
python main.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --validation --hcl
```

### Evaluation

```bash
python linear_eval_alltasks.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --hcl
```

To change the dataset and method, use the configuration files from `./configs`.
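The schema of these YAML files is not part of this commit excerpt. Purely as an illustration, the fields read by the argument-parsing code further below (`name`, `seed`, `dataset.*`, `model.name`, `train.*`, `eval.*`) suggest a layout along these lines; the values and any fields beyond those are assumptions:

```yaml
name: cifar10_distil_hcl   # experiment name; required by the argument parser
seed: 0
dataset:
  name: cifar10
  image_size: 32
  num_workers: 4
model:
  name: resnet18
train:
  batch_size: 256
  num_epochs: 200
  stop_at_epoch: 200
eval:
  batch_size: 256
  num_epochs: 100
```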
# Contributing

We'd love to accept your contributions to this project. Please feel free to open an issue or submit a pull request as necessary. If you have implementations of this repository in other ML frameworks, please reach out so we may highlight them here.

## Licenses

Copyright © 2023, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC. Click [here](LICENSE) to view a copy of this license.

## Acknowledgment

The code is built upon [aimagelab/mammoth](https://github.com/aimagelab/mammoth), [divyam3897/UCL](https://github.com/divyam3897/UCL), [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar/tree/master), [sutd-visual-computing-group/LS-KD-compatibility](https://github.com/sutd-visual-computing-group/LS-KD-compatibility), and [berniwal/swin-transformer-pytorch](https://github.com/berniwal/swin-transformer-pytorch).

## Citation

If you found the provided code useful, please cite our work.

```bibtex
@inproceedings{madaan2023heterogeneous,
  title={Heterogeneous Continual Learning},
  author={Madaan, Divyam and Yin, Hongxu and Byeon, Wonmin and Kautz, Jan and Molchanov, Pavlo},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}
```
@@ -0,0 +1,117 @@
import argparse
import os
import torch

import numpy as np
import random

import re
import yaml

import shutil
import warnings

from datetime import datetime

class Namespace(object):
    """Recursively wraps the config dict so YAML entries are reachable as attributes."""

    def __init__(self, somedict):
        for key, value in somedict.items():
            assert isinstance(key, str) and re.match("[A-Za-z_-]", key)
            if isinstance(value, dict):
                self.__dict__[key] = Namespace(value)
            else:
                self.__dict__[key] = value

    def __getattr__(self, attribute):
        # Only reached for keys that were never set, i.e. entries missing from the config.
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")

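# Usage sketch (illustrative): Namespace turns the nested dict loaded from a YAML
# config into attribute-style access, e.g.
#
#   cfg = Namespace({"model": {"name": "resnet18"}, "train": {"batch_size": 256}})
#   cfg.model.name          # -> "resnet18"
#   cfg.train.batch_size    # -> 256
#   cfg.optimizer           # -> AttributeError asking you to add it to the config file
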
def set_deterministic(seed):
    # seed by default is None
    if seed is not None:
        print(f"Deterministic with seed = {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_subset_size', type=int, default=8)
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
    parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
    parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--ckpt_dir_1', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    parser.add_argument('--cl_default', action='store_true')
    parser.add_argument('--server', action='store_true')
    parser.add_argument('--hcl', action='store_true')
    parser.add_argument('--buffer_qdi', action='store_true')
    parser.add_argument('--validation', action='store_true',
                        help='Test on the validation set')
    parser.add_argument('--ood_eval', action='store_true',
                        help='Test on the OOD set')
    parser.add_argument('--alpha', type=float, default=0.3)
    args = parser.parse_args()

    # Merge every top-level entry of the YAML config into the argparse namespace.
    with open(args.config_file, 'r') as f:
        for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items():
            vars(args)[key] = value

    if args.debug:
        if args.train:
            args.train.batch_size = 2
            args.train.num_epochs = 1
            args.train.stop_at_epoch = 1
        if args.eval:
            args.eval.batch_size = 2
            args.eval.num_epochs = 1  # train only one epoch
        args.dataset.num_workers = 0

    assert not None in [args.log_dir, args.data_dir, args.ckpt_dir, args.name]

    args.log_dir = os.path.join(args.log_dir, 'in-progress_' + datetime.now().strftime('%m%d%H%M%S_') + args.name)

    os.makedirs(args.log_dir, exist_ok=False)
    print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)

    shutil.copy2(args.config_file, args.log_dir)
    set_deterministic(args.seed)

    vars(args)['aug_kwargs'] = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
        'cl_default': args.cl_default
    }
    vars(args)['dataset_kwargs'] = {
        # 'name': args.model.name,
        # 'image_size': args.dataset.image_size,
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.debug_subset_size if args.debug else None,
        # 'drop_last': True,
        # 'pin_memory': True,
        # 'num_workers': args.dataset.num_workers,
    }
    vars(args)['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }

    return args
@@ -0,0 +1,23 @@
from .simsiam_aug import SimSiamTransform
from .eval_aug import Transform_single


def get_aug(name='simsiam', image_size=224, train=True, train_classifier=None, mean_std=None, **aug_kwargs):
    if train == True:
        augmentation = SimSiamTransform(image_size, mean_std=mean_std, **aug_kwargs)
    elif train == False:
        if train_classifier is None:
            raise Exception
        augmentation = Transform_single(image_size, train=train_classifier, mean_std=mean_std)
    else:
        raise Exception

    return augmentation
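# Usage sketch (illustrative): callers would typically build the two transforms from
# the parsed arguments, roughly like
#
#   train_aug = get_aug(train=True, mean_std=dataset_mean_std, **args.aug_kwargs)
#   eval_aug = get_aug(train=False, train_classifier=False,
#                      image_size=args.dataset.image_size, mean_std=dataset_mean_std)
#
# where dataset_mean_std is the (mean, std) pair of the chosen dataset.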
@@ -0,0 +1,24 @@
from torchvision import transforms
from PIL import Image


class Transform_single():
    def __init__(self, image_size, train, mean_std):
        if train == True:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0), interpolation=Image.BICUBIC),
                # transforms.RandomCrop(image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(*mean_std)
            ])
        else:
            self.transform = transforms.Compose([
                # transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
                # transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(*mean_std)
            ])

    def __call__(self, x):
        return self.transform(x)
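# Usage sketch (illustrative): mean_std is unpacked into Normalize, so it is expected
# to be a (mean, std) pair of per-channel tuples, e.g. CIFAR-10-style statistics:
#
#   eval_transform = Transform_single(image_size=32, train=False,
#                                     mean_std=((0.4914, 0.4822, 0.4465),
#                                               (0.2470, 0.2435, 0.2616)))
#   tensor = eval_transform(pil_image)   # PIL image -> normalized torch.Tensor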