Initial commit
divyam3897 committed Dec 4, 2023
0 parents commit 9d462aa
Showing 81 changed files with 5,718 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
*.npy
__pycache__/
checkpoints/
data/
logs/
wandb/
63 changes: 63 additions & 0 deletions LICENSE
@@ -0,0 +1,63 @@
Copyright (c) 2023, NVIDIA Corporation. All rights reserved.

Nvidia Source Code License-NC

1. Definitions

“Licensor” means any person or entity that distributes its Work.

“Work” means (a) the original work of authorship made available under this license, which may include software, documentation,
or other files, and (b) any additions to or derivative works thereof that are made available under this license.

The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S.
copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that
remain separable from, or merely link (or bind by name) to the interfaces of, the Work.

Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing
the applicability of this license to the Work, or (b) a copy of this license.

2. License Grant

2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual,
worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly
display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a
complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent,
trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution
of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3
applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms.
Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply
to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially.
Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially.
As used herein, “non-commercially” means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim
or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under
this license from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks,
except as necessary to reproduce the notices described in this license.

3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1)
will terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES
OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING
ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT,
OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL,
BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR
HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
107 changes: 107 additions & 0 deletions README.md
@@ -0,0 +1,107 @@
# Heterogeneous Continual Learning

Official PyTorch implementation of [**Heterogeneous Continual Learning**](https://arxiv.org/abs/2306.08593).

**Authors**: [Divyam Madaan](https://dmadaan.com/), [Hongxu Yin](https://hongxu-yin.github.i), [Wonmin Byeon](https://wonmin-byeon.github.i), [Pavlo Molchanov](https://research.nvidia.com/person/pavlo-molchano)

For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/)

---
**TL;DR: First continual learning approach in which the architecture continuously evolves with the data.**
--
## Abstract

![concept figure](https://github.com/divyam3897/cvpr_hcl/files/13549399/concept_figure.pdf)

We propose a novel framework and a solution to tackle
the continual learning (CL) problem with changing network
architectures. Most CL methods focus on adapting a single
architecture to a new task/class by modifying its weights.
However, with rapid progress in architecture design, the
problem of adapting existing solutions to novel architectures
becomes relevant. To address this limitation, we propose
Heterogeneous Continual Learning (HCL), where a wide
range of evolving network architectures emerge continually
together with novel data/tasks. As a solution, we build on
top of the distillation family of techniques and modify it
to a new setting where a weaker model takes the role of a
teacher; meanwhile, a new stronger architecture acts as a
student. Furthermore, we consider a setup of limited access
to previous data and propose Quick Deep Inversion (QDI) to
recover prior task visual features to support knowledge
transfer. QDI significantly reduces computational costs compared
to previous solutions and improves overall performance. In
summary, we propose a new setup for CL with a modified
knowledge distillation paradigm and design a quick data
inversion method to enhance distillation. Our evaluation
on various benchmarks shows a significant improvement in
accuracy in comparison to state-of-the-art methods over
various network architectures.

__Contribution of this work__

- We propose a novel CL framework called Heterogeneous
Continual Learning (HCL) to learn a stream of
different architectures on a sequence of tasks while
transferring the knowledge from past representations.
- We revisit knowledge distillation and propose Quick
Deep Inversion (QDI), which inverts the previous task
parameters while interpolating the current task examples
with minimal additional cost.
- We benchmark existing state-of-the-art solutions in the
new setting and outperform them with our proposed
method across a diverse stream of architectures for both
task-incremental and class-incremental CL.
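
The distillation idea described above, where an earlier (weaker) network acts as teacher for a newer (stronger) student, can be illustrated with a generic knowledge-distillation loss. The following is a minimal sketch in plain Python, not the implementation in this repository: the function names, the temperature value, and the way `alpha` weights the distillation term are all assumptions made for illustration, and QDI itself is not shown.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) between temperature-softened distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(
        pt * math.log(pt / ps)
        for pt, ps in zip(p_teacher, p_student)
        if pt > 0
    )
    # T^2 scaling is the common convention in standard knowledge distillation.
    return kl * temperature ** 2


def total_loss(ce_loss, student_logits, teacher_logits, alpha=0.3, temperature=2.0):
    """Hypothetical combination: new-task loss plus alpha-weighted distillation."""
    return ce_loss + alpha * distillation_loss(student_logits, teacher_logits, temperature)


# Teacher: previous-task model (weaker architecture); student: new architecture.
teacher_logits = [2.0, 0.5, -1.0]
student_logits = [1.5, 0.7, -0.8]
print(total_loss(ce_loss=0.9, student_logits=student_logits, teacher_logits=teacher_logits))
```

When the student reproduces the teacher's logits exactly, the distillation term is zero and only the new-task loss remains.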

## Prerequisites

```
$ pip install -r requirements.txt
```

## Quick start

### Training

```bash
python main.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --validation --hcl
```

### Evaluation

```bash
python linear_eval_alltasks.py --data_dir ../data/ --log_dir ./logs/scl/ -c configs/cifar10/distil.yaml --ckpt_dir ./checkpoints/c10/scl/distil/ --hide_progress --cl_default --hcl
```


To change the dataset and method, use the configuration files from `./configs`.
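
As a rough illustration of how a configuration file and the command-line flags end up in one `args` object, the sketch below merges a nested dictionary (standing in for a parsed YAML file) onto parsed CLI arguments. The helper names `to_namespace` and `merge_config` and every value in `example_config` are hypothetical; the actual files under `./configs` may use different keys and values.

```python
import argparse
from types import SimpleNamespace


def to_namespace(value):
    """Recursively convert nested dicts so that keys become attributes."""
    if isinstance(value, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in value.items()})
    return value


def merge_config(cli_args, config):
    """Copy top-level config entries onto the parsed CLI arguments."""
    for key, value in config.items():
        setattr(cli_args, key, to_namespace(value))
    return cli_args


# Hypothetical config contents; the real YAML files may differ.
example_config = {
    "name": "distil-cifar10",
    "seed": 0,
    "dataset": {"name": "seq-cifar10", "image_size": 32, "num_workers": 4},
    "train": {"batch_size": 256, "num_epochs": 200},
}

parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="../data/")
parser.add_argument("--hcl", action="store_true")
args = merge_config(parser.parse_args(["--hcl"]), example_config)

# Nested sections are reachable with dotted access alongside the CLI flags.
print(args.dataset.image_size, args.train.batch_size, args.hcl)
```

In this sketch, config entries overwrite any CLI argument with the same name, so a key should live in only one of the two places.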

## Contributing

We'd love to accept your contributions to this project. Please feel free to open an issue, or submit a pull request as necessary. If you have implementations of this repository in other ML frameworks, please reach out so we may highlight them here.

## Licenses

Copyright © 2023, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC. Click [here](LICENSE) to view a copy of this license.


## Acknowledgment

The code is built upon [aimagelab/mammoth](https://github.com/aimagelab/mammoth), [divyam3897/UCL](https://github.com/divyam3897/UCL), [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar/tree/master), [sutd-visual-computing-group/LS-KD-compatibility](https://github.com/sutd-visual-computing-group/LS-KD-compatibility), and [berniwal/swin-transformer-pytorch](https://github.com/berniwal/swin-transformer-pytorch).

## Citation

If you found the provided code useful, please cite our work.

```bibtex
@inproceedings{madaan2023heterogeneous,
title={Heterogeneous Continual Learning},
author={Madaan, Divyam and Yin, Hongxu and Byeon, Wonmin and Kautz, Jan and Molchanov, Pavlo},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```
117 changes: 117 additions & 0 deletions arguments.py
@@ -0,0 +1,117 @@
import argparse
import os
import torch

import numpy as np
import random

import re
import yaml

import shutil
import warnings

from datetime import datetime


class Namespace(object):
    def __init__(self, somedict):
        # Expose each config key as an attribute; nested dicts become nested Namespaces.
        for key, value in somedict.items():
            assert isinstance(key, str) and re.match("[A-Za-z_-]", key)
            if isinstance(value, dict):
                self.__dict__[key] = Namespace(value)
            else:
                self.__dict__[key] = value

    def __getattr__(self, attribute):
        # Only reached when normal attribute lookup fails, i.e. the key is missing.
        raise AttributeError(f"Can not find {attribute} in namespace. Please write {attribute} in your config file(xxx.yaml)!")


def set_deterministic(seed):
    # seed by default is None
    if seed is not None:
        print(f"Deterministic with seed = {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config-file', required=True, type=str, help="xxx.yaml")
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--debug_subset_size', type=int, default=8)
    parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web")
    parser.add_argument('--data_dir', type=str, default=os.getenv('DATA'))
    parser.add_argument('--log_dir', type=str, default=os.getenv('LOG'))
    parser.add_argument('--ckpt_dir', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--ckpt_dir_1', type=str, default=os.getenv('CHECKPOINT'))
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--eval_from', type=str, default=None)
    parser.add_argument('--hide_progress', action='store_true')
    parser.add_argument('--cl_default', action='store_true')
    parser.add_argument('--server', action='store_true')
    parser.add_argument('--hcl', action='store_true')
    parser.add_argument('--buffer_qdi', action='store_true')
    parser.add_argument('--validation', action='store_true',
                        help='Test on the validation set')
    parser.add_argument('--ood_eval', action='store_true',
                        help='Test on the OOD set')
    parser.add_argument('--alpha', type=float, default=0.3)
    args = parser.parse_args()

    # Merge the top-level entries of the YAML config into the parsed arguments.
    with open(args.config_file, 'r') as f:
        for key, value in Namespace(yaml.load(f, Loader=yaml.FullLoader)).__dict__.items():
            vars(args)[key] = value

    if args.debug:
        if args.train:
            args.train.batch_size = 2
            args.train.num_epochs = 1
            args.train.stop_at_epoch = 1
        if args.eval:
            args.eval.batch_size = 2
            args.eval.num_epochs = 1  # train only one epoch
        args.dataset.num_workers = 0

    assert None not in [args.log_dir, args.data_dir, args.ckpt_dir, args.name]

    args.log_dir = os.path.join(args.log_dir, 'in-progress_' + datetime.now().strftime('%m%d%H%M%S_') + args.name)

    os.makedirs(args.log_dir, exist_ok=False)
    print(f'creating file {args.log_dir}')
    os.makedirs(args.ckpt_dir, exist_ok=True)

    shutil.copy2(args.config_file, args.log_dir)
    set_deterministic(args.seed)

    vars(args)['aug_kwargs'] = {
        'name': args.model.name,
        'image_size': args.dataset.image_size,
        'cl_default': args.cl_default
    }
    vars(args)['dataset_kwargs'] = {
        # 'name':args.model.name,
        # 'image_size': args.dataset.image_size,
        'dataset': args.dataset.name,
        'data_dir': args.data_dir,
        'download': args.download,
        'debug_subset_size': args.debug_subset_size if args.debug else None,
        # 'drop_last': True,
        # 'pin_memory': True,
        # 'num_workers': args.dataset.num_workers,
    }
    vars(args)['dataloader_kwargs'] = {
        'drop_last': True,
        'pin_memory': True,
        'num_workers': args.dataset.num_workers,
    }

    return args
Binary file added assets/concept_figure.pdf
Binary file not shown.
23 changes: 23 additions & 0 deletions augmentations/__init__.py
@@ -0,0 +1,23 @@
from .simsiam_aug import SimSiamTransform
from .eval_aug import Transform_single


def get_aug(name='simsiam', image_size=224, train=True, train_classifier=None, mean_std=None, **aug_kwargs):
    if train==True:
        augmentation = SimSiamTransform(image_size, mean_std=mean_std, **aug_kwargs)
    elif train==False:
        if train_classifier is None:
            raise Exception
        augmentation = Transform_single(image_size, train=train_classifier, mean_std=mean_std)
    else:
        raise Exception

    return augmentation








24 changes: 24 additions & 0 deletions augmentations/eval_aug.py
@@ -0,0 +1,24 @@
from torchvision import transforms
from PIL import Image


class Transform_single():
    def __init__(self, image_size, train, mean_std):
        if train == True:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
                # transforms.RandomCrop(image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(*mean_std)
            ])
        else:
            self.transform = transforms.Compose([
                # transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
                # transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(*mean_std)
            ])

    def __call__(self, x):
        return self.transform(x)