upload code

Zechen Bai committed Aug 24, 2024
1 parent 9796f57 commit 1285541
Showing 113 changed files with 21,735 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**/.DS_Store
89 changes: 80 additions & 9 deletions README.md
@@ -1,17 +1,88 @@
# Official PyTorch Implementation of Adaptive Slot Attention: Object Discovery with Dynamic Slot Number

[![ArXiv](https://img.shields.io/badge/ArXiv-2406.09196-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.09196) [![HomePage](https://img.shields.io/badge/HomePage-Visit-blue.svg?logo=homeadvisor&logoColor=f5f5f5)](https://kfan21.github.io/AdaSlot/) ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)

> [**Adaptive Slot Attention: Object Discovery with Dynamic Slot Number**](https://arxiv.org/abs/2406.09196)<br>
> [Ke Fan](https://kfan21.github.io/), [Zechen Bai](https://www.baizechen.site/), [Tianjun Xiao](http://tianjunxiao.com/), [Tong He](https://hetong007.github.io/), [Max Horn](https://expectationmax.github.io/), [Yanwei Fu†](http://yanweifu.github.io/), [Francesco Locatello](https://www.francescolocatello.com/), [Zheng Zhang](https://scholar.google.com/citations?hl=zh-CN&user=k0KiE4wAAAAJ)

This is the official implementation of the CVPR'24 paper [Adaptive Slot Attention: Object Discovery with Dynamic Slot Number](https://openaccess.thecvf.com/content/CVPR2024/html/Fan_Adaptive_Slot_Attention_Object_Discovery_with_Dynamic_Slot_Number_CVPR_2024_paper.html).

## Introduction

![framework](framework.png)

Object-centric learning (OCL) uses slots to extract object representations, enhancing flexibility and interpretability. Slot attention, a common OCL method, refines slot representations with attention mechanisms but requires predefined slot numbers, ignoring object variability. To address this, a novel complexity-aware object auto-encoder framework introduces adaptive slot attention (AdaSlot), dynamically determining the optimal slot count based on data content through a discrete slot sampling module. A masked slot decoder suppresses unselected slots during decoding. Extensive testing shows this framework matches or exceeds fixed-slot models, adapting slot numbers based on instance complexity and promising further research opportunities.

## Development Setup

Installing AdaSlot requires at least Python 3.8. Installation can be done using [poetry](https://python-poetry.org/docs/#installation). After installing `poetry`, check out the repo and set up a development environment:

```bash
# install python3.8
sudo apt update
sudo apt install software-properties-common
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt install python3.8

# install poetry with python3.8
curl -sSL https://install.python-poetry.org | python3.8 - --version 1.2.0
# add poetry to your PATH environment variable

# create virtual environment with poetry
cd $code_path
poetry install -E timm
```

This installs the `ocl` package and the CLI scripts used for running experiments in a poetry-managed virtual environment. Activate the virtual environment with `poetry shell` before running the experiments.

## Running experiments

Experiments are defined in the folder `configs/experiment` and can be run
by setting the experiment variable. For example, to run AdaSlot on the MOVi-E, MOVi-C, COCO, and CLEVR10 datasets:

```bash
poetry shell

# MOVi-E: train the fixed-slot baseline, train AdaSlot from its checkpoint, then evaluate
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT

# MOVi-C
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT

# COCO
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT

# CLEVR10
python -m ocl.cli.train +experiment=slot_attention/clevr10.yaml
python -m ocl.cli.train +experiment=slot_attention/clevr10_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT
python -m ocl.cli.eval +experiment=slot_attention/clevr10_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT
```

The results are saved in a timestamped subdirectory of `outputs/<experiment_name>`, i.e. `outputs/<experiment_name>/<date>_<time>`. The prefix path `outputs` can be configured using the `experiment.root_output_path` variable.

## Citation

Please cite our paper if you find this repo useful!

```bibtex
@inproceedings{fan2024adaptive,
  title={Adaptive slot attention: Object discovery with dynamic slot number},
  author={Fan, Ke and Bai, Zechen and Xiao, Tianjun and He, Tong and Horn, Max and Fu, Yanwei and Locatello, Francesco and Zhang, Zheng},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={23062--23071},
  year={2024}
}
```

Related project that this work builds upon:

```bibtex
@misc{oclf,
  author = {Max Horn and Maximilian Seitzer and Andrii Zadaianchuk and Zixu Zhao and Dominik Zietlow and Florian Wenzel and Tianjun Xiao},
  title = {Object Centric Learning Framework (version 0.1)},
  year = {2023},
  url = {https://github.com/amazon-science/object-centric-learning-framework},
}
```

1 change: 1 addition & 0 deletions configs/__init__.py
@@ -0,0 +1 @@
# Hydra needs this file to recognize the config folder when using hydra.main from console scripts
10 changes: 10 additions & 0 deletions configs/dataset/clevr.yaml
@@ -0,0 +1,10 @@
# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
defaults:
- webdataset

train_shards: "/home/ubuntu/clevr_with_masks_new_splits/train/shard-{000000..000114}.tar"
train_size: 70000
val_shards: "/home/ubuntu/clevr_with_masks_new_splits/val/shard-{000000..000024}.tar"
val_size: 15000
test_shards: "/home/ubuntu/clevr_with_masks_new_splits/test/shard-{000000..000024}.tar"
test_size: 15000
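The `*_shards` entries above use webdataset-style brace notation to name a range of tar shards. As an illustrative sketch (not the library's actual code), the pattern expands like this:

```python
import re

def expand_shards(pattern: str) -> list:
    """Expand one webdataset-style brace range, e.g. "shard-{000000..000002}.tar"."""
    m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if not m:
        return [pattern]
    lo, hi = m.group(1), m.group(2)
    width = len(lo)  # zero-padding width is taken from the bounds
    return [
        pattern[: m.start()] + str(i).zfill(width) + pattern[m.end():]
        for i in range(int(lo), int(hi) + 1)
    ]

shards = expand_shards("shard-{000000..000002}.tar")
# shards == ["shard-000000.tar", "shard-000001.tar", "shard-000002.tar"]
```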
18 changes: 18 additions & 0 deletions configs/dataset/clevr6.yaml
@@ -0,0 +1,18 @@
# @package _global_
# Image dataset containing instances from CLEVR with at most 6 objects in each scene.
defaults:
  - /dataset/clevr@dataset
  - /plugins/[email protected]_clevr6_subset
  - _self_

dataset:
  # Values derived from running `bin/compute_dataset_size.py`
  train_size: 26240
  val_size: 5553
  test_size: 5600

plugins:
  01_clevr6_subset:
    predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
    fields:
      - visibility
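The `lambda_fn` resolver turns the quoted config string into a callable that filters samples by visible-object count. A minimal stand-in of that behavior (assuming `visibility` arrives as an array-like with a `.sum()` method, as with numpy arrays):

```python
# Stand-in for an array type with .sum(); the real pipeline passes numpy arrays.
class Visibility(list):
    def sum(self):
        return sum(self)

# What the ${lambda_fn:...} resolver effectively produces from the config string:
predicate = eval("lambda visibility: visibility.sum() < 7")

keep = predicate(Visibility([1] * 6))  # 6 visible objects: kept in CLEVR6
drop = predicate(Visibility([1] * 8))  # 8 visible objects: filtered out
```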
18 changes: 18 additions & 0 deletions configs/dataset/clevr6_old_splits.yaml
@@ -0,0 +1,18 @@
# @package _global_
# Image dataset containing instances from CLEVR with at most 6 objects in each scene.
defaults:
  - clevr_old_splits@dataset
  - /plugins/[email protected]_clevr6_subset
  - _self_

dataset:
  # Values derived from running `bin/compute_dataset_size.py`
  train_size: 29948
  val_size: 3674
  test_size: 3771

plugins:
  01_clevr6_subset:
    predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}"
    fields:
      - visibility
10 changes: 10 additions & 0 deletions configs/dataset/clevr_old_splits.yaml
@@ -0,0 +1,10 @@
# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets .
defaults:
- webdataset

train_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/train/shard-{000000..000131}.tar"}
train_size: 80000
val_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/val/shard-{000000..000016}.tar"}
val_size: 10000
test_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/test/shard-{000000..000016}.tar"}
test_size: 10000
11 changes: 11 additions & 0 deletions configs/dataset/coco.yaml
@@ -0,0 +1,11 @@
# The coco2017 dataset with instance, stuff and caption annotations.
defaults:
- webdataset

train_shards: "/home/ubuntu/coco2017/train/shard-{000000..000412}.tar"
train_size: 118287
val_shards: "/home/ubuntu/coco2017/val/shard-{000000..000017}.tar"
val_size: 5000
test_shards: "/home/ubuntu/coco2017/test/shard-{000000..000126}.tar"
test_size: 40670
use_autopadding: true
12 changes: 12 additions & 0 deletions configs/dataset/coco_nocrowd.yaml
@@ -0,0 +1,12 @@
# The coco2017 dataset with instance, stuff and caption annotations.
# Validation dataset does not contain any crowd annotations.
defaults:
- webdataset

train_shards: ${dataset_prefix:"coco2017/train/shard-{000000..000412}.tar"}
train_size: 118287
val_shards: ${dataset_prefix:"coco2017/val_nocrowd/shard-{000000..000017}.tar"}
val_size: 5000
test_shards: ${dataset_prefix:"coco2017/test/shard-{000000..000126}.tar"}
test_size: 40670
use_autopadding: true
31 changes: 31 additions & 0 deletions configs/dataset/movi_c.yaml
@@ -0,0 +1,31 @@
# @package _global_
defaults:
  - webdataset@dataset
  - _self_

dataset:
  train_shards: "/home/ubuntu/movi_c/train/shard-{000000..000298}.tar"
  train_size: 9737
  val_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
  val_size: 250
  test_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar"
  test_size: 250
  use_autopadding: true

plugins:
  00_1_rename_fields:
    _target_: ocl.plugins.RenameFields
    train_mapping:
      video: image
    evaluation_mapping:
      video: image
      segmentations: mask
  00_2_adapt_mask_format:
    _target_: ocl.plugins.SingleElementPreprocessing
    training_transform: null
    evaluation_transform:
      _target_: ocl.preprocessing.IntegerToOneHotMask
      output_axis: -4
      max_instances: 10
      ignore_typical_background: false
    element_key: mask
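`IntegerToOneHotMask` converts an integer instance-id mask into one binary mask per instance. A pure-Python sketch of the idea (illustrative only; the repo's implementation works on arrays and honors `output_axis`):

```python
def integer_to_one_hot(mask, max_instances):
    """mask: 2D list of integer instance ids (0 = background).
    Returns max_instances + 1 binary masks, one channel per id."""
    return [
        [[1 if px == inst else 0 for px in row] for row in mask]
        for inst in range(max_instances + 1)
    ]

mask = [[0, 1],
        [2, 1]]
one_hot = integer_to_one_hot(mask, max_instances=2)
# one_hot[1] == [[0, 1], [0, 1]]  (binary mask for instance id 1)
```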
25 changes: 25 additions & 0 deletions configs/dataset/movi_c_image.yaml
@@ -0,0 +1,25 @@
# @package _global_
# Image dataset containing subsampled frames from MOVI_C dataset.
defaults:
  - /dataset/movi_c
  - /plugins/[email protected]_sample_frames
  - _self_

dataset:
  # Values derived from running `bin/compute_dataset_size.py`.
  train_size: 87633
  val_size: 6000
  test_size: 6000

plugins:
  02_sample_frames:
    n_frames_per_video: 9
    n_eval_frames_per_video: -1
    training_fields:
      - image
    evaluation_fields:
      - image
      - mask
    dim: 0
    seed: 457834752
    shuffle_buffer_size: 1000
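The `02_sample_frames` plugin draws `n_frames_per_video` frames from each training video, with `-1` meaning keep every frame (as at evaluation time). A hedged sketch of that sampling logic (function and variable names are illustrative, not the plugin's API):

```python
import random

def sample_frame_indices(num_frames, n_samples, rng):
    """Pick n_samples frame indices from a video; -1 keeps every frame."""
    if n_samples < 0:
        return list(range(num_frames))
    return sorted(rng.sample(range(num_frames), n_samples))

rng = random.Random(457834752)                  # seed taken from the config above
train_idx = sample_frame_indices(24, 9, rng)    # 9 of 24 frames for training
eval_idx = sample_frame_indices(24, -1, rng)    # all frames at evaluation
```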
30 changes: 30 additions & 0 deletions configs/dataset/movi_e.yaml
@@ -0,0 +1,30 @@
# @package _global_
defaults:
  - webdataset@dataset
  - _self_

dataset:
  train_shards: "/home/ubuntu/movi_e/train/shard-{000000..000679}.tar"
  train_size: 9749
  val_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
  val_size: 250
  test_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar"
  test_size: 250
  use_autopadding: true

plugins:
  00_1_rename_fields:
    _target_: ocl.plugins.RenameFields
    train_mapping:
      video: image
    evaluation_mapping:
      video: image
      segmentations: mask
  00_2_adapt_mask_format:
    _target_: ocl.plugins.SingleElementPreprocessing
    training_transform: null
    evaluation_transform:
      _target_: ocl.preprocessing.IntegerToOneHotMask
      output_axis: -4
      max_instances: 23
      ignore_typical_background: false
    element_key: mask
25 changes: 25 additions & 0 deletions configs/dataset/movi_e_image.yaml
@@ -0,0 +1,25 @@
# @package _global_
# Image dataset containing subsampled frames from MOVI_E dataset.
defaults:
  - /dataset/movi_e
  - /plugins/[email protected]_sample_frames
  - _self_

dataset:
  # Values derived from running `bin/compute_dataset_size.py`.
  train_size: 87741
  val_size: 6000
  test_size: 6000

plugins:
  02_sample_frames:
    n_frames_per_video: 9
    n_eval_frames_per_video: -1
    training_fields:
      - image
    evaluation_fields:
      - image
      - mask
    dim: 0
    seed: 457834752
    shuffle_buffer_size: 1000
1 change: 1 addition & 0 deletions configs/dataset/test.py
@@ -0,0 +1 @@
lambda visibility, color, pixel_coords: True
11 changes: 11 additions & 0 deletions configs/dataset/voc2012.yaml
@@ -0,0 +1,11 @@
# The PASCAL VOC 2012 dataset. Does not contain segmentation annotations.
defaults:
- webdataset

train_shards: ${dataset_prefix:"voc2012_detection/train/shard-{000000..000021}.tar"}
train_size: 5717
val_shards: ${dataset_prefix:"voc2012_detection/val/shard-{000000..000022}.tar"}
val_size: 5823
test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"}
test_size: 10991
use_autopadding: true
11 changes: 11 additions & 0 deletions configs/dataset/voc2012_trainaug.yaml
@@ -0,0 +1,11 @@
# The PASCAL VOC 2012 dataset in the trainaug variant with instance segmentation masks.
defaults:
- webdataset

train_shards: "/home/ubuntu/voc2012/trainaug/shard-{000000..000040}.tar"
train_size: 10582
val_shards: "/home/ubuntu/voc2012/val/shard-{000000..000011}.tar"
val_size: 1449
test_shards: null
test_size: null
use_autopadding: true