Commit
Zechen Bai committed Aug 24, 2024
1 parent 9796f57, commit 1285541
Showing 113 changed files with 21,735 additions and 9 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
**/.DS_Store |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,88 @@ | ||
## My Project | ||
# Official PyTorch Implementation of Adaptive Slot Attention: Object Discovery with Dynamic Slot Number | ||
[![ArXiv](https://img.shields.io/badge/ArXiv-2406.09196-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.09196)[![HomePage](https://img.shields.io/badge/HomePage-Visit-blue.svg?logo=homeadvisor&logoColor=f5f5f5)](https://kfan21.github.io/AdaSlot/)![License](https://img.shields.io/badge/License-Apache%202.0-green.svg) | ||
> [**Adaptive Slot Attention: Object Discovery with Dynamic Slot Number**](https://arxiv.org/abs/2406.09196)<br> | ||
> [Ke Fan](https://kfan21.github.io/), [Zechen Bai](https://www.baizechen.site/), [Tianjun Xiao](http://tianjunxiao.com/), [Tong He](https://hetong007.github.io/), [Max Horn](https://expectationmax.github.io/), [Yanwei Fu†](http://yanweifu.github.io/), [Francesco Locatello](https://www.francescolocatello.com/), [Zheng Zhang](https://scholar.google.com/citations?hl=zh-CN&user=k0KiE4wAAAAJ) | ||
TODO: Fill this README out! | ||
|
||
Be sure to: | ||
This is the official implementation of the CVPR'24 paper [Adaptive Slot Attention: Object Discovery with Dynamic Slot Number](https://openaccess.thecvf.com/content/CVPR2024/html/Fan_Adaptive_Slot_Attention_Object_Discovery_with_Dynamic_Slot_Number_CVPR_2024_paper.html). | ||
|
||
* Change the title in this README | ||
* Edit your repository description on GitHub | ||
## Introduction | ||
|
||
## Security | ||
![framework](framework.png) | ||
|
||
See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. | ||
Object-centric learning (OCL) uses slots to extract object representations, enhancing flexibility and interpretability. Slot attention, a common OCL method, refines slot representations with attention mechanisms but requires predefined slot numbers, ignoring object variability. To address this, a novel complexity-aware object auto-encoder framework introduces adaptive slot attention (AdaSlot), dynamically determining the optimal slot count based on data content through a discrete slot sampling module. A masked slot decoder suppresses unselected slots during decoding. Extensive testing shows this framework matches or exceeds fixed-slot models, adapting slot numbers based on instance complexity and promising further research opportunities. | ||
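The select-then-mask idea in the paragraph above can be illustrated with a toy sketch. It is not the repository's implementation: the function names, the two-class Gumbel-max sampling, and the fallback that always keeps one slot are assumptions made for illustration, and plain Python lists stand in for tensors.

```python
import math
import random


def sample_keep_mask(slot_logits, rng):
    """Sample a hard keep (1) / drop (0) decision for every slot.

    slot_logits: list of (drop_logit, keep_logit) pairs, one per slot.
    The Gumbel-max trick is an assumed stand-in for the paper's
    discrete slot sampling module.
    """
    mask = []
    for drop_logit, keep_logit in slot_logits:
        # Gumbel(0, 1) noise; the uniform sample is kept away from 0 and 1.
        g_drop = -math.log(-math.log(rng.uniform(1e-9, 1 - 1e-9)))
        g_keep = -math.log(-math.log(rng.uniform(1e-9, 1 - 1e-9)))
        mask.append(1 if keep_logit + g_keep > drop_logit + g_drop else 0)
    if sum(mask) == 0:
        # Assumption: keep at least one slot so decoding stays defined.
        best = max(range(len(slot_logits)), key=lambda i: slot_logits[i][1])
        mask[best] = 1
    return mask


def masked_mixture_weights(alpha_logits, mask):
    """Softmax over per-slot alpha logits, restricted to kept slots.

    Dropped slots get weight 0, so they cannot contribute to the
    reconstruction at this location. Expects at least one kept slot.
    """
    kept = [a for a, m in zip(alpha_logits, mask) if m]
    shift = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(a - shift) if m else 0.0 for a, m in zip(alpha_logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
slot_logits = [(0.0, 2.0), (1.5, -1.0), (0.0, 0.5), (3.0, -3.0)]
mask = sample_keep_mask(slot_logits, rng)
weights = masked_mixture_weights([0.2, 1.0, -0.3, 0.7], mask)
print("kept slots:", sum(mask), "weights:", weights)
```

The number of kept slots varies per input because the mask is sampled per instance, which is the property the paragraph refers to as adapting the slot count to instance complexity.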
|
||
## License | ||
## Development Setup | ||
|
||
This project is licensed under the Apache-2.0 License. | ||
Installing AdaSlot requires Python 3.8 or later. Installation can be done using [poetry](https://python-poetry.org/docs/#installation). After installing `poetry`, check out the repo and set up a development environment: | ||
|
||
```bash | ||
# install python3.8 | ||
sudo apt update | ||
sudo apt install software-properties-common | ||
sudo add-apt-repository ppa:deadsnakes/ppa | ||
sudo apt install python3.8 | ||
|
||
# install poetry with python3.8 | ||
curl -sSL https://install.python-poetry.org | python3.8 - --version 1.2.0 | ||
# then add poetry to your PATH environment variable | ||
|
||
# create virtual environment with poetry | ||
cd $code_path | ||
poetry install -E timm | ||
``` | ||
|
||
This installs the `ocl` package and the CLI scripts used for running experiments in a poetry-managed virtual environment. Activate the virtual environment with `poetry shell` before running the experiments. | ||
|
||
## Running experiments | ||
|
||
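Experiments are defined in the folder `configs/experiment` and can be run | ||
by setting the `experiment` variable. For example, the commands below cover MOVi-E, MOVi-C, COCO, and CLEVR10; each group trains a base model, trains its AdaSlot variant from the base checkpoint, and then evaluates the AdaSlot checkpoint: | ||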
|
||
```bash | ||
poetry shell | ||
|
||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16.yaml | ||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT | ||
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_e_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT | ||
|
||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16.yaml | ||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT | ||
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/movi_c_feat_rec_vitb16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT | ||
|
||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16.yaml | ||
python -m ocl.cli.train +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT | ||
python -m ocl.cli.eval +experiment=projects/bridging/dinosaur/coco_feat_rec_dino_base16_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT | ||
|
||
python -m ocl.cli.train +experiment=slot_attention/clevr10.yaml | ||
python -m ocl.cli.train +experiment=slot_attention/clevr10_adaslot.yaml +load_model_weight=PATH-TO-KMAX-SLOT-CHECKPOINT | ||
python -m ocl.cli.eval +experiment=slot_attention/clevr10_adaslot_eval.yaml ++load_checkpoint=PATH-TO-ADASLOT-CHECKPOINT | ||
``` | ||
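The commands in the block above follow one pattern per experiment prefix. A minimal sketch that assembles the three command lines; `build_commands` is a hypothetical helper that is not part of the `ocl` package, and the checkpoint arguments are the same placeholders used above.

```python
def build_commands(prefix, kmax_checkpoint, adaslot_checkpoint):
    """Return the train, AdaSlot-train, and eval command lines for one experiment prefix."""
    return [
        # 1. Train the base model.
        "python -m ocl.cli.train +experiment={}.yaml".format(prefix),
        # 2. Train the AdaSlot variant, loading weights from the first run.
        "python -m ocl.cli.train +experiment={}_adaslot.yaml "
        "+load_model_weight={}".format(prefix, kmax_checkpoint),
        # 3. Evaluate the AdaSlot checkpoint.
        "python -m ocl.cli.eval +experiment={}_adaslot_eval.yaml "
        "++load_checkpoint={}".format(prefix, adaslot_checkpoint),
    ]


for command in build_commands(
    "slot_attention/clevr10",
    "PATH-TO-KMAX-SLOT-CHECKPOINT",
    "PATH-TO-ADASLOT-CHECKPOINT",
):
    print(command)
```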
|
||
The results are saved in a timestamped subdirectory of `outputs/<experiment_name>`, i.e. `outputs/<experiment_name>/<date>_<time>`. The prefix path `outputs` can be configured using the `experiment.root_output_path` variable. | ||
|
||
## Citation | ||
|
||
Please cite our paper if you find this repo useful! | ||
|
||
```bibtex | ||
@inproceedings{fan2024adaptive, | ||
title={Adaptive slot attention: Object discovery with dynamic slot number}, | ||
author={Fan, Ke and Bai, Zechen and Xiao, Tianjun and He, Tong and Horn, Max and Fu, Yanwei and Locatello, Francesco and Zhang, Zheng}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={23062--23071}, | ||
year={2024} | ||
} | ||
``` | ||
|
||
This work builds on the following related project: | ||
|
||
```bibtex | ||
@misc{oclf, | ||
author = {Max Horn and Maximilian Seitzer and Andrii Zadaianchuk and Zixu Zhao and Dominik Zietlow and Florian Wenzel and Tianjun Xiao}, | ||
title = {Object Centric Learning Framework (version 0.1)}, | ||
year = {2023}, | ||
url = {https://github.com/amazon-science/object-centric-learning-framework}, | ||
} | ||
``` | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Hydra needs this file to recognize the config folder when using hydra.main from console scripts |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets . | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: "/home/ubuntu/clevr_with_masks_new_splits/train/shard-{000000..000114}.tar" | ||
train_size: 70000 | ||
val_shards: "/home/ubuntu/clevr_with_masks_new_splits/val/shard-{000000..000024}.tar" | ||
val_size: 15000 | ||
test_shards: "/home/ubuntu/clevr_with_masks_new_splits/test/shard-{000000..000024}.tar" | ||
test_size: 15000 |
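The shard paths above use brace notation for a zero-padded numeric range. The sketch below only illustrates which files such a pattern enumerates; `expand_shards` is a hypothetical helper and is not how the `ocl` package or the webdataset library resolves shards internally.

```python
import re


def expand_shards(pattern):
    """Expand one numeric `{start..end}` range, preserving zero padding."""
    match = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
    if match is None:
        return [pattern]  # nothing to expand
    start, end = match.group(1), match.group(2)
    width = len(start)  # padding width taken from the range start
    return [
        pattern[: match.start()] + str(i).zfill(width) + pattern[match.end():]
        for i in range(int(start), int(end) + 1)
    ]


train = expand_shards(
    "/home/ubuntu/clevr_with_masks_new_splits/train/shard-{000000..000114}.tar"
)
print(len(train), train[0], train[-1])
```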
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# @package _global_ | ||
# Image dataset containing instances from CLEVR with at most 6 objects in each scene. | ||
defaults: | ||
- /dataset/clevr@dataset | ||
- /plugins/[email protected]_clevr6_subset | ||
- _self_ | ||
|
||
dataset: | ||
# Values derived from running `bin/compute_dataset_size.py` | ||
train_size: 26240 | ||
val_size: 5553 | ||
test_size: 5600 | ||
|
||
plugins: | ||
01_clevr6_subset: | ||
predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}" | ||
fields: | ||
- visibility |
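The `predicate` above keeps only samples whose `visibility` field sums to less than 7. A minimal sketch of that filter on made-up samples; plain Python lists stand in for arrays (so `sum(visibility)` replaces `visibility.sum()`), and `keep_sample` and the sample dictionaries are hypothetical.

```python
def keep_sample(sample, predicate, fields):
    """Call the predicate with the listed fields of one sample."""
    return predicate(*(sample[name] for name in fields))


# Same condition as the config, written for plain lists.
predicate = lambda visibility: sum(visibility) < 7

samples = [
    {"visibility": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]},  # sum 3: kept
    {"visibility": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]},  # sum 8: dropped
]
subset = [s for s in samples if keep_sample(s, predicate, ["visibility"])]
print(len(subset), "of", len(samples), "samples kept")
```

Because the filter removes samples, the split sizes must be recomputed for the subset, which is why `train_size`, `val_size`, and `test_size` are overridden in this config.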
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# @package _global_ | ||
# Image dataset containing instances from CLEVR with at most 6 objects in each scene. | ||
defaults: | ||
- clevr_old_splits@dataset | ||
- /plugins/[email protected]_clevr6_subset | ||
- _self_ | ||
|
||
dataset: | ||
# Values derived from running `bin/compute_dataset_size.py` | ||
train_size: 29948 | ||
val_size: 3674 | ||
test_size: 3771 | ||
|
||
plugins: | ||
01_clevr6_subset: | ||
predicate: "${lambda_fn:'lambda visibility: visibility.sum() < 7'}" | ||
fields: | ||
- visibility |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Image dataset CLEVR based on https://github.com/deepmind/multi_object_datasets . | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/train/shard-{000000..000131}.tar"} | ||
train_size: 80000 | ||
val_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/val/shard-{000000..000016}.tar"} | ||
val_size: 10000 | ||
test_shards: ${s3_pipe:"s3://multi-object-webdatasets/clevr_with_masks/test/shard-{000000..000016}.tar"} | ||
test_size: 10000 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# The coco2017 dataset with instance, stuff and caption annotations. | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: "/home/ubuntu/coco2017/train/shard-{000000..000412}.tar" | ||
train_size: 118287 | ||
val_shards: "/home/ubuntu/coco2017/val/shard-{000000..000017}.tar" | ||
val_size: 5000 | ||
test_shards: "/home/ubuntu/coco2017/test/shard-{000000..000126}.tar" | ||
test_size: 40670 | ||
use_autopadding: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# The coco2017 dataset with instance, stuff and caption annotations. | ||
# Validation dataset does not contain any crowd annotations. | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: ${dataset_prefix:"coco2017/train/shard-{000000..000412}.tar"} | ||
train_size: 118287 | ||
val_shards: ${dataset_prefix:"coco2017/val_nocrowd/shard-{000000..000017}.tar"} | ||
val_size: 5000 | ||
test_shards: ${dataset_prefix:"coco2017/test/shard-{000000..000126}.tar"} | ||
test_size: 40670 | ||
use_autopadding: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# @package _global_ | ||
defaults: | ||
- webdataset@dataset | ||
- _self_ | ||
|
||
dataset: | ||
train_shards: "/home/ubuntu/movi_c/train/shard-{000000..000298}.tar" | ||
train_size: 9737 | ||
val_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar" | ||
val_size: 250 | ||
test_shards: "/home/ubuntu/movi_c/val/shard-{000000..000007}.tar" | ||
test_size: 250 | ||
use_autopadding: true | ||
|
||
plugins: | ||
00_1_rename_fields: | ||
_target_: ocl.plugins.RenameFields | ||
train_mapping: | ||
video: image | ||
evaluation_mapping: | ||
video: image | ||
segmentations: mask | ||
00_2_adapt_mask_format: | ||
_target_: ocl.plugins.SingleElementPreprocessing | ||
training_transform: null | ||
evaluation_transform: | ||
_target_: ocl.preprocessing.IntegerToOneHotMask | ||
output_axis: -4 | ||
max_instances: 10 | ||
ignore_typical_background: false | ||
element_key: mask |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# @package _global_ | ||
# Image dataset containing subsampled frames from MOVI_C dataset. | ||
defaults: | ||
- /dataset/movi_c | ||
- /plugins/[email protected]_sample_frames | ||
- _self_ | ||
|
||
dataset: | ||
# Values derived from running `bin/compute_dataset_size.py`. | ||
train_size: 87633 | ||
val_size: 6000 | ||
test_size: 6000 | ||
|
||
plugins: | ||
02_sample_frames: | ||
n_frames_per_video: 9 | ||
n_eval_frames_per_video: -1 | ||
training_fields: | ||
- image | ||
evaluation_fields: | ||
- image | ||
- mask | ||
dim: 0 | ||
seed: 457834752 | ||
shuffle_buffer_size: 1000 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# @package _global_ | ||
defaults: | ||
- webdataset@dataset | ||
- _self_ | ||
|
||
dataset: | ||
train_shards: "/home/ubuntu/movi_e/train/shard-{000000..000679}.tar" | ||
train_size: 9749 | ||
val_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar" | ||
val_size: 250 | ||
test_shards: "/home/ubuntu/movi_e/val/shard-{000000..000017}.tar" | ||
test_size: 250 | ||
use_autopadding: true | ||
plugins: | ||
00_1_rename_fields: | ||
_target_: ocl.plugins.RenameFields | ||
train_mapping: | ||
video: image | ||
evaluation_mapping: | ||
video: image | ||
segmentations: mask | ||
00_2_adapt_mask_format: | ||
_target_: ocl.plugins.SingleElementPreprocessing | ||
training_transform: null | ||
evaluation_transform: | ||
_target_: ocl.preprocessing.IntegerToOneHotMask | ||
output_axis: -4 | ||
max_instances: 23 | ||
ignore_typical_background: false | ||
element_key: mask |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# @package _global_ | ||
# Image dataset containing subsampled frames from MOVI_E dataset. | ||
defaults: | ||
- /dataset/movi_e | ||
- /plugins/[email protected]_sample_frames | ||
- _self_ | ||
|
||
dataset: | ||
# Values derived from running `bin/compute_dataset_size.py`. | ||
train_size: 87741 | ||
val_size: 6000 | ||
test_size: 6000 | ||
|
||
plugins: | ||
02_sample_frames: | ||
n_frames_per_video: 9 | ||
n_eval_frames_per_video: -1 | ||
training_fields: | ||
- image | ||
evaluation_fields: | ||
- image | ||
- mask | ||
dim: 0 | ||
seed: 457834752 | ||
shuffle_buffer_size: 1000 |
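The image-level sizes in the two MOVi image configs are consistent with the video counts in the corresponding video configs (9737 and 9749 training videos, 250 validation videos). The check below is a sketch under two assumptions: that `n_eval_frames_per_video: -1` means every frame is used, and that each video has 24 frames, which matches 6000 / 250. `frame_counts` is a hypothetical helper.

```python
FRAMES_PER_VIDEO = 24  # assumption, consistent with val_size 6000 / 250 videos


def frame_counts(n_train_videos, n_eval_videos, n_frames_per_video):
    """Return (train_size, eval_size) in frames for a subsampled video dataset."""
    train_size = n_train_videos * n_frames_per_video  # sampled frames per video
    eval_size = n_eval_videos * FRAMES_PER_VIDEO      # all frames per video
    return train_size, eval_size


print("MOVi-C:", frame_counts(9737, 250, 9))
print("MOVi-E:", frame_counts(9749, 250, 9))
```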
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
lambda visibility, color, pixel_coords: True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# The PASCAL VOC 2012 dataset. Does not contain segmentation annotations. | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: ${dataset_prefix:"voc2012_detection/train/shard-{000000..000021}.tar"} | ||
train_size: 5717 | ||
val_shards: ${dataset_prefix:"voc2012_detection/val/shard-{000000..000022}.tar"} | ||
val_size: 5823 | ||
test_shards: ${dataset_prefix:"voc2012_detection/test/shard-{000000..000041}.tar"} | ||
test_size: 10991 | ||
use_autopadding: true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# The PASCAL VOC 2012 dataset in the trainaug variant with instance segmentation masks. | ||
defaults: | ||
- webdataset | ||
|
||
train_shards: "/home/ubuntu/voc2012/trainaug/shard-{000000..000040}.tar" | ||
train_size: 10582 | ||
val_shards: "/home/ubuntu/voc2012/val/shard-{000000..000011}.tar" | ||
val_size: 1449 | ||
test_shards: null | ||
test_size: null | ||
use_autopadding: true |